Richard Piazza 2017-10-04 10:27:18 -04:00
commit ebd28aca9d
20 changed files with 1029 additions and 600 deletions

View File

@ -8,6 +8,8 @@ from .common import (TLP_AMBER, TLP_GREEN, TLP_RED, TLP_WHITE, CustomMarking,
MarkingDefinition, StatementMarking, TLPMarking) MarkingDefinition, StatementMarking, TLPMarking)
from .core import Bundle, _register_type, parse from .core import Bundle, _register_type, parse
from .environment import Environment, ObjectFactory from .environment import Environment, ObjectFactory
from .markings import (add_markings, clear_markings, get_markings, is_marked,
remove_markings, set_markings)
from .observables import (URL, AlternateDataStream, ArchiveExt, Artifact, from .observables import (URL, AlternateDataStream, ArchiveExt, Artifact,
AutonomousSystem, CustomObservable, Directory, AutonomousSystem, CustomObservable, Directory,
DomainName, EmailAddress, EmailMessage, DomainName, EmailAddress, EmailMessage,

View File

@ -3,6 +3,7 @@
from collections import OrderedDict from collections import OrderedDict
from .base import _STIXBase from .base import _STIXBase
from .markings import MarkingsMixin
from .properties import (HashesProperty, IDProperty, ListProperty, Property, from .properties import (HashesProperty, IDProperty, ListProperty, Property,
ReferenceProperty, SelectorProperty, StringProperty, ReferenceProperty, SelectorProperty, StringProperty,
TimestampProperty, TypeProperty) TimestampProperty, TypeProperty)
@ -76,7 +77,7 @@ class MarkingProperty(Property):
raise ValueError("must be a Statement, TLP Marking or a registered marking.") raise ValueError("must be a Statement, TLP Marking or a registered marking.")
class MarkingDefinition(_STIXBase): class MarkingDefinition(_STIXBase, MarkingsMixin):
_type = 'marking-definition' _type = 'marking-definition'
_properties = OrderedDict() _properties = OrderedDict()
_properties.update([ _properties.update([

View File

@ -1,7 +1,7 @@
import copy import copy
from .core import parse as _parse from .core import parse as _parse
from .sources import CompositeDataSource, DataSource, DataStore from .sources import CompositeDataSource, DataStore
class ObjectFactory(object): class ObjectFactory(object):
@ -132,17 +132,15 @@ class Environment(object):
def add_filters(self, *args, **kwargs): def add_filters(self, *args, **kwargs):
try: try:
return self.source.add_filters(*args, **kwargs) return self.source.filters.update(*args, **kwargs)
except AttributeError: except AttributeError:
raise AttributeError('Environment has no data source') raise AttributeError('Environment has no data source')
add_filters.__doc__ = DataSource.add_filters.__doc__
def add_filter(self, *args, **kwargs): def add_filter(self, *args, **kwargs):
try: try:
return self.source.add_filter(*args, **kwargs) return self.source.filters.add(*args, **kwargs)
except AttributeError: except AttributeError:
raise AttributeError('Environment has no data source') raise AttributeError('Environment has no data source')
add_filter.__doc__ = DataSource.add_filter.__doc__
def add(self, *args, **kwargs): def add(self, *args, **kwargs):
try: try:

View File

@ -212,3 +212,16 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa
result = result or object_markings.is_marked(obj, object_marks) result = result or object_markings.is_marked(obj, object_marks)
return result return result
class MarkingsMixin():
pass
# Note that all of these methods will return a new object because of immutability
MarkingsMixin.get_markings = get_markings
MarkingsMixin.set_markings = set_markings
MarkingsMixin.remove_markings = remove_markings
MarkingsMixin.add_markings = add_markings
MarkingsMixin.clear_markings = clear_markings
MarkingsMixin.is_marked = is_marked

View File

@ -88,6 +88,7 @@ def remove_markings(obj, marking, selectors):
""" """
selectors = utils.convert_to_list(selectors) selectors = utils.convert_to_list(selectors)
marking = utils.convert_to_marking_list(marking)
utils.validate(obj, selectors) utils.validate(obj, selectors)
granular_markings = obj.get("granular_markings") granular_markings = obj.get("granular_markings")
@ -97,12 +98,9 @@ def remove_markings(obj, marking, selectors):
granular_markings = utils.expand_markings(granular_markings) granular_markings = utils.expand_markings(granular_markings)
if isinstance(marking, list): to_remove = []
to_remove = [] for m in marking:
for m in marking: to_remove.append({"marking_ref": m, "selectors": selectors})
to_remove.append({"marking_ref": m, "selectors": selectors})
else:
to_remove = [{"marking_ref": marking, "selectors": selectors}]
remove = utils.build_granular_marking(to_remove).get("granular_markings") remove = utils.build_granular_marking(to_remove).get("granular_markings")
@ -140,14 +138,12 @@ def add_markings(obj, marking, selectors):
""" """
selectors = utils.convert_to_list(selectors) selectors = utils.convert_to_list(selectors)
marking = utils.convert_to_marking_list(marking)
utils.validate(obj, selectors) utils.validate(obj, selectors)
if isinstance(marking, list): granular_marking = []
granular_marking = [] for m in marking:
for m in marking: granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)})
granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)})
else:
granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}]
if obj.get("granular_markings"): if obj.get("granular_markings"):
granular_marking.extend(obj.get("granular_markings")) granular_marking.extend(obj.get("granular_markings"))
@ -244,7 +240,7 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa
raise TypeError("Required argument 'selectors' must be provided") raise TypeError("Required argument 'selectors' must be provided")
selectors = utils.convert_to_list(selectors) selectors = utils.convert_to_list(selectors)
marking = utils.convert_to_list(marking) marking = utils.convert_to_marking_list(marking)
utils.validate(obj, selectors) utils.validate(obj, selectors)
granular_markings = obj.get("granular_markings", []) granular_markings = obj.get("granular_markings", [])

View File

@ -31,7 +31,7 @@ def add_markings(obj, marking):
A new version of the given SDO or SRO with specified markings added. A new version of the given SDO or SRO with specified markings added.
""" """
marking = utils.convert_to_list(marking) marking = utils.convert_to_marking_list(marking)
object_markings = set(obj.get("object_marking_refs", []) + marking) object_markings = set(obj.get("object_marking_refs", []) + marking)
@ -55,7 +55,7 @@ def remove_markings(obj, marking):
A new version of the given SDO or SRO with specified markings removed. A new version of the given SDO or SRO with specified markings removed.
""" """
marking = utils.convert_to_list(marking) marking = utils.convert_to_marking_list(marking)
object_markings = obj.get("object_marking_refs", []) object_markings = obj.get("object_marking_refs", [])
@ -121,7 +121,7 @@ def is_marked(obj, marking=None):
provided marking refs match, True is returned. provided marking refs match, True is returned.
""" """
marking = utils.convert_to_list(marking) marking = utils.convert_to_marking_list(marking)
object_markings = obj.get("object_marking_refs", []) object_markings = obj.get("object_marking_refs", [])
if marking: if marking:

View File

@ -37,6 +37,12 @@ def _validate_selector(obj, selector):
return True return True
def _get_marking_id(marking):
if type(marking).__name__ is 'MarkingDefinition': # avoid circular import
return marking.id
return marking
def validate(obj, selectors): def validate(obj, selectors):
"""Given an SDO or SRO, check that each selector is valid.""" """Given an SDO or SRO, check that each selector is valid."""
if selectors: if selectors:
@ -57,6 +63,15 @@ def convert_to_list(data):
return [data] return [data]
def convert_to_marking_list(data):
"""Convert input into a list of marking identifiers."""
if data is not None:
if isinstance(data, list):
return [_get_marking_id(x) for x in data]
else:
return [_get_marking_id(data)]
def compress_markings(granular_markings): def compress_markings(granular_markings):
""" """
Compress granular markings list. If there is more than one marking Compress granular markings list. If there is more than one marking

View File

@ -6,6 +6,7 @@ import stix2
from .base import _STIXBase from .base import _STIXBase
from .common import ExternalReference, GranularMarking, KillChainPhase from .common import ExternalReference, GranularMarking, KillChainPhase
from .markings import MarkingsMixin
from .observables import ObservableProperty from .observables import ObservableProperty
from .properties import (BooleanProperty, IDProperty, IntegerProperty, from .properties import (BooleanProperty, IDProperty, IntegerProperty,
ListProperty, PatternProperty, ReferenceProperty, ListProperty, PatternProperty, ReferenceProperty,
@ -13,7 +14,11 @@ from .properties import (BooleanProperty, IDProperty, IntegerProperty,
from .utils import NOW from .utils import NOW
class AttackPattern(_STIXBase): class STIXDomainObject(_STIXBase, MarkingsMixin):
pass
class AttackPattern(STIXDomainObject):
_type = 'attack-pattern' _type = 'attack-pattern'
_properties = OrderedDict() _properties = OrderedDict()
@ -34,7 +39,7 @@ class AttackPattern(_STIXBase):
]) ])
class Campaign(_STIXBase): class Campaign(STIXDomainObject):
_type = 'campaign' _type = 'campaign'
_properties = OrderedDict() _properties = OrderedDict()
@ -58,7 +63,7 @@ class Campaign(_STIXBase):
]) ])
class CourseOfAction(_STIXBase): class CourseOfAction(STIXDomainObject):
_type = 'course-of-action' _type = 'course-of-action'
_properties = OrderedDict() _properties = OrderedDict()
@ -78,7 +83,7 @@ class CourseOfAction(_STIXBase):
]) ])
class Identity(_STIXBase): class Identity(STIXDomainObject):
_type = 'identity' _type = 'identity'
_properties = OrderedDict() _properties = OrderedDict()
@ -101,7 +106,7 @@ class Identity(_STIXBase):
]) ])
class Indicator(_STIXBase): class Indicator(STIXDomainObject):
_type = 'indicator' _type = 'indicator'
_properties = OrderedDict() _properties = OrderedDict()
@ -125,7 +130,7 @@ class Indicator(_STIXBase):
]) ])
class IntrusionSet(_STIXBase): class IntrusionSet(STIXDomainObject):
_type = 'intrusion-set' _type = 'intrusion-set'
_properties = OrderedDict() _properties = OrderedDict()
@ -152,7 +157,7 @@ class IntrusionSet(_STIXBase):
]) ])
class Malware(_STIXBase): class Malware(STIXDomainObject):
_type = 'malware' _type = 'malware'
_properties = OrderedDict() _properties = OrderedDict()
@ -173,7 +178,7 @@ class Malware(_STIXBase):
]) ])
class ObservedData(_STIXBase): class ObservedData(STIXDomainObject):
_type = 'observed-data' _type = 'observed-data'
_properties = OrderedDict() _properties = OrderedDict()
@ -195,7 +200,7 @@ class ObservedData(_STIXBase):
]) ])
class Report(_STIXBase): class Report(STIXDomainObject):
_type = 'report' _type = 'report'
_properties = OrderedDict() _properties = OrderedDict()
@ -217,7 +222,7 @@ class Report(_STIXBase):
]) ])
class ThreatActor(_STIXBase): class ThreatActor(STIXDomainObject):
_type = 'threat-actor' _type = 'threat-actor'
_properties = OrderedDict() _properties = OrderedDict()
@ -245,7 +250,7 @@ class ThreatActor(_STIXBase):
]) ])
class Tool(_STIXBase): class Tool(STIXDomainObject):
_type = 'tool' _type = 'tool'
_properties = OrderedDict() _properties = OrderedDict()
@ -267,7 +272,7 @@ class Tool(_STIXBase):
]) ])
class Vulnerability(_STIXBase): class Vulnerability(STIXDomainObject):
_type = 'vulnerability' _type = 'vulnerability'
_properties = OrderedDict() _properties = OrderedDict()
@ -316,7 +321,7 @@ def CustomObject(type='x-custom-type', properties=None):
def custom_builder(cls): def custom_builder(cls):
class _Custom(cls, _STIXBase): class _Custom(cls, STIXDomainObject):
_type = type _type = type
_properties = OrderedDict() _properties = OrderedDict()
_properties.update([ _properties.update([

View File

@ -7,21 +7,12 @@ Classes:
DataSource DataSource
CompositeDataSource CompositeDataSource
TODO:Test everything
Notes:
add_filter(), remove_filter(), deduplicate() - if these functions remain
the exact same for DataSource, DataSink, CompositeDataSource etc... -> just
make those functions an interface to inherit?
""" """
import uuid import uuid
from six import iteritems from stix2.utils import deduplicate
from stix2.sources.filters import (FILTER_OPS, FILTER_VALUE_TYPES,
STIX_COMMON_FIELDS, STIX_COMMON_FILTERS_MAP)
def make_id(): def make_id():
@ -29,13 +20,21 @@ def make_id():
class DataStore(object): class DataStore(object):
""" """An implementer will create a concrete subclass from
An implementer will create a concrete subclass from this class for the specific DataStore.
this abstract class for the specific data store.
Args:
source (DataSource): An existing DataSource to use
as this DataStore's DataSource component
sink (DataSink): An existing DataSink to use
as this DataStore's DataSink component
Attributes: Attributes:
id (str): A unique UUIDv4 to identify this DataStore. id (str): A unique UUIDv4 to identify this DataStore.
source (DataStore): An object that implements DataStore class.
source (DataSource): An object that implements DataSource class.
sink (DataSink): An object that implements DataSink class. sink (DataSink): An object that implements DataSink class.
""" """
@ -47,14 +46,13 @@ class DataStore(object):
def get(self, stix_id): def get(self, stix_id):
"""Retrieve the most recent version of a single STIX object by ID. """Retrieve the most recent version of a single STIX object by ID.
Notes: Translate get() call to the appropriate DataSource call.
Translate API get() call to the appropriate DataSource call.
Args: Args:
stix_id (str): the id of the STIX 2.0 object to retrieve. stix_id (str): the id of the STIX object to retrieve.
Returns: Returns:
stix_obj (dictionary): the single most recent version of the STIX stix_obj: the single most recent version of the STIX
object specified by the "id". object specified by the "id".
""" """
@ -63,15 +61,13 @@ class DataStore(object):
def all_versions(self, stix_id): def all_versions(self, stix_id):
"""Retrieve all versions of a single STIX object by ID. """Retrieve all versions of a single STIX object by ID.
Implement: Implement: Translate all_versions() call to the appropriate DataSource call
Translate all_versions() call to the appropriate DataSource call
Args: Args:
stix_id (str): the id of the STIX 2.0 object to retrieve. stix_id (str): the id of the STIX object to retrieve.
Returns: Returns:
stix_objs (list): a list of STIX objects (where each object is a stix_objs (list): a list of STIX objects
STIX object)
""" """
return self.source.all_versions(stix_id) return self.source.all_versions(stix_id)
@ -79,17 +75,15 @@ class DataStore(object):
def query(self, query): def query(self, query):
"""Retrieve STIX objects matching a set of filters. """Retrieve STIX objects matching a set of filters.
Notes: Implement: Specific data source API calls, processing,
Implement the specific data source API calls, processing, functionality required for retrieving query from the data source.
functionality required for retrieving query from the data source.
Args: Args:
query (list): a list of filters (which collectively are the query) query (list): a list of filters (which collectively are the query)
to conduct search on. to conduct search on.
Returns: Returns:
stix_objs (list): a list of STIX objects (where each object is a stix_objs (list): a list of STIX objects
STIX object)
""" """
return self.source.query(query=query) return self.source.query(query=query)
@ -97,21 +91,17 @@ class DataStore(object):
def add(self, stix_objs): def add(self, stix_objs):
"""Store STIX objects. """Store STIX objects.
Notes: Translates add() to the appropriate DataSink call.
Translate add() to the appropriate DataSink call().
Args: Args:
stix_objs (list): a list of STIX objects (where each object is a stix_objs (list): a list of STIX objects
STIX object)
""" """
return self.sink.add(stix_objs) return self.sink.add(stix_objs)
class DataSink(object): class DataSink(object):
""" """An implementer will create a concrete subclass from
Abstract class for defining a data sink. Intended for subclassing into this class for the specific DataSink.
different sink components.
Attributes: Attributes:
id (str): A unique UUIDv4 to identify this DataSink. id (str): A unique UUIDv4 to identify this DataSink.
@ -123,9 +113,8 @@ class DataSink(object):
def add(self, stix_objs): def add(self, stix_objs):
"""Store STIX objects. """Store STIX objects.
Notes: Implement: Specific data sink API calls, processing,
Implement the specific data sink API calls, processing, functionality required for adding data to the sink
functionality required for adding data to the sink
Args: Args:
stix_objs (list): a list of STIX objects (where each object is a stix_objs (list): a list of STIX objects (where each object is a
@ -136,13 +125,13 @@ class DataSink(object):
class DataSource(object): class DataSource(object):
""" """An implementer will create a concrete subclass from
Abstract class for defining a data source. Intended for subclassing into this class for the specific DataSource.
different source components.
Attributes: Attributes:
id (str): A unique UUIDv4 to identify this DataSource. id (str): A unique UUIDv4 to identify this DataSource.
filters (set): A collection of filters present in this DataSource.
_filters (set): A collection of filters attached to this DataSource.
""" """
def __init__(self): def __init__(self):
@ -151,179 +140,76 @@ class DataSource(object):
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None):
""" """
Fill: Implement: Specific data source API calls, processing,
Implement the specific data source API calls, processing, functionality required for retrieving data from the data source
functionality required for retrieving data from the data source
Args: Args:
stix_id (str): the id of the STIX 2.0 object to retrieve. Should stix_id (str): the id of the STIX 2.0 object to retrieve. Should
return a single object, the most recent version of the object return a single object, the most recent version of the object
specified by the "id". specified by the "id".
_composite_filters (list): list of filters passed along from _composite_filters (set): set of filters passed from the parent
the Composite Data Filter. the CompositeDataSource, not user supplied
Returns: Returns:
stix_obj (dictionary): the STIX object to be returned stix_obj: the STIX object
""" """
raise NotImplementedError() raise NotImplementedError()
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None):
""" """
Notes: Implement: Similar to get() except returns list of all object versions of
Similar to get() except returns list of all object versions of the specified "id". In addition, implement the specific data
the specified "id". In addition, implement the specific data source API calls, processing, functionality required for retrieving
source API calls, processing, functionality required for retrieving data from the data source.
data from the data source.
Args: Args:
stix_id (str): The id of the STIX 2.0 object to retrieve. Should stix_id (str): The id of the STIX 2.0 object to retrieve. Should
return a list of objects, all the versions of the object return a list of objects, all the versions of the object
specified by the "id". specified by the "id".
_composite_filters (list): list of filters passed from the _composite_filters (set): set of filters passed from the parent
Composite Data Source CompositeDataSource, not user supplied
Returns: Returns:
stix_objs (list): a list of STIX objects (where each object is a stix_objs (list): a list of STIX objects
STIX object)
""" """
raise NotImplementedError() raise NotImplementedError()
def query(self, query, _composite_filters=None): def query(self, query, _composite_filters=None):
""" """
Fill: Implement:Implement the specific data source API calls, processing,
-implement the specific data source API calls, processing, functionality required for retrieving query from the data source
functionality required for retrieving query from the data source
Args: Args:
query (list): a list of filters (which collectively are the query) query (list): a list of filters (which collectively are the query)
to conduct search on to conduct search on
_composite_filters (list): a list of filters passed from the _composite_filters (set): a set of filters passed from the parent
Composite Data Source CompositeDataSource, not user supplied
Returns: Returns:
stix_objs (list): a list of STIX objects
""" """
raise NotImplementedError() raise NotImplementedError()
def add_filters(self, filters):
"""Add multiple filters to be applied to all queries for STIX objects.
Args:
filters (list): list of filters (dict) to add to the Data Source.
"""
for filter in filters:
self.add_filter(filter)
def add_filter(self, filter):
"""Add a filter to be applied to all queries for STIX objects.
Args:
filter: filter to add to the Data Source.
"""
# check filter field is a supported STIX 2.0 common field
if filter.field not in STIX_COMMON_FIELDS:
raise ValueError("Filter 'field' is not a STIX 2.0 common property. Currently only STIX object common properties supported")
# check filter operator is supported
if filter.op not in FILTER_OPS:
raise ValueError("Filter operation (from 'op' field) not supported")
# check filter value type is supported
if type(filter.value) not in FILTER_VALUE_TYPES:
raise ValueError("Filter 'value' type is not supported. The type(value) must be python immutable type or dictionary")
self.filters.add(filter)
def apply_common_filters(self, stix_objs, query):
"""Evaluate filters against a set of STIX 2.0 objects.
Supports only STIX 2.0 common property fields
Args:
stix_objs (list): list of STIX objects to apply the query to
query (list): list of filters (combined form complete query)
Returns:
(list): list of STIX objects that successfully evaluate against
the query.
"""
filtered_stix_objs = []
# evaluate objects against filter
for stix_obj in stix_objs:
clean = True
for filter_ in query:
# skip filter as filter was identified (when added) as
# not a common filter
if filter_.field not in STIX_COMMON_FIELDS:
raise ValueError("Error, field: {0} is not supported for filtering on.".format(filter_.field))
# For properties like granular_markings and external_references
# need to break the first property from the string.
if "." in filter_.field:
field = filter_.field.split(".")[0]
else:
field = filter_.field
# check filter "field" is in STIX object - if cant be
# applied due to STIX object, STIX object is discarded
# (i.e. did not make it through the filter)
if field not in stix_obj.keys():
clean = False
break
match = STIX_COMMON_FILTERS_MAP[filter_.field.split('.')[0]](filter_, stix_obj)
if not match:
clean = False
break
elif match == -1:
raise ValueError("Error, filter operator: {0} not supported for specified field: {1}".format(filter_.op, filter_.field))
# if object unmarked after all filters, add it
if clean:
filtered_stix_objs.append(stix_obj)
return filtered_stix_objs
def deduplicate(self, stix_obj_list):
"""Deduplicate a list of STIX objects to a unique set
Reduces a set of STIX objects to unique set by looking
at 'id' and 'modified' fields - as a unique object version
is determined by the combination of those fields
Args:
stix_obj_list (list): list of STIX objects (dicts)
Returns:
A list with a unique set of the passed list of STIX objects.
"""
unique_objs = {}
for obj in stix_obj_list:
unique_objs[(obj['id'], obj['modified'])] = obj
return list(unique_objs.values())
class CompositeDataSource(DataSource): class CompositeDataSource(DataSource):
"""Controller for all the defined/configured STIX Data Sources. """Controller for all the attached DataSources.
E.g. a user can define n Data Sources - creating Data Source (objects) A user can have a single CompositeDataSource as an interface
for each. There is only one instance of this for any Python STIX 2.0 the a set of DataSources. When an API call is made to the
application. CompositeDataSource, it is delegated to each of the (real)
DataSources that are attached to it.
DataSources can be attached to CompositeDataSource for a variety
of reasons, e.g. common filters, organization, less API calls.
Attributes: Attributes:
name (str): The name that identifies this CompositeDataSource.
data_sources (dict): A dictionary of DataSource objects; to be data_sources (dict): A dictionary of DataSource objects; to be
controlled and used by the Data Source Controller object. controlled and used by the Data Source Controller object.
@ -332,49 +218,52 @@ class CompositeDataSource(DataSource):
"""Create a new STIX Data Source. """Create a new STIX Data Source.
Args: Args:
name (str): A string containing the name to attach in the
CompositeDataSource instance.
""" """
super(CompositeDataSource, self).__init__() super(CompositeDataSource, self).__init__()
self.data_sources = {} self.data_sources = []
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None):
"""Retrieve STIX object by 'id' """Retrieve STIX object by STIX ID
Federated retrieve method-iterates through all STIX data sources Federated retrieve method, iterates through all DataSources
defined in the "data_sources" parameter. Each data source has a defined in the "data_sources" parameter. Each data source has a
specific API retrieve-like function and associated parameters. This specific API retrieve-like function and associated parameters. This
function does a federated retrieval and consolidation of the data function does a federated retrieval and consolidation of the data
returned from all the STIX data sources. returned from all the STIX data sources.
Notes: A composite data source will pass its attached filters to
A composite data source will pass its attached filters to each configured data source, pushing filtering to them to handle.
each configured data source, pushing filtering to them to handle.
Args: Args:
stix_id (str): the id of the STIX object to retrieve. stix_id (str): the id of the STIX object to retrieve.
_composite_filters (list): a list of filters passed from the _composite_filters (list): a list of filters passed from a
Composite Data Source CompositeDataSource (i.e. if this CompositeDataSource is attached
to another parent CompositeDataSource), not user supplied
Returns: Returns:
stix_obj (dict): the STIX object to be returned. stix_obj: the STIX object to be returned.
""" """
if not self.get_all_data_sources(): if not self.has_data_sources():
raise AttributeError('CompositeDataSource has no data sources') raise AttributeError('CompositeDataSource has no data sources')
all_data = [] all_data = []
all_filters = set()
all_filters.update(self.filters)
if _composite_filters:
all_filters.update(_composite_filters)
# for every configured Data Source, call its retrieve handler # for every configured Data Source, call its retrieve handler
for ds_id, ds in iteritems(self.data_sources): for ds in self.data_sources:
data = ds.get(stix_id=stix_id, _composite_filters=list(self.filters)) data = ds.get(stix_id=stix_id, _composite_filters=all_filters)
all_data.append(data) all_data.append(data)
# remove duplicate versions # remove duplicate versions
if len(all_data) > 0: if len(all_data) > 0:
all_data = self.deduplicate(all_data) all_data = deduplicate(all_data)
# reduce to most recent version # reduce to most recent version
stix_obj = sorted(all_data, key=lambda k: k['modified'], reverse=True)[0] stix_obj = sorted(all_data, key=lambda k: k['modified'], reverse=True)[0]
@ -382,128 +271,149 @@ class CompositeDataSource(DataSource):
return stix_obj return stix_obj
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None):
"""Retrieve STIX objects by 'id' """Retrieve STIX objects by STIX ID
Federated all_versions retrieve method - iterates through all STIX data Federated all_versions retrieve method - iterates through all DataSources
sources defined in "data_sources" defined in "data_sources"
Notes: A composite data source will pass its attached filters to
A composite data source will pass its attached filters to each configured data source, pushing filtering to them to handle
each configured data source, pushing filtering to them to handle
Args: Args:
stix_id (str): id of the STIX objects to retrieve stix_id (str): id of the STIX objects to retrieve
_composite_filters (list): a list of filters passed from the _composite_filters (list): a list of filters passed from a
Composite Data Source CompositeDataSource (i.e. if this CompositeDataSource is attached
to a parent CompositeDataSource), not user supplied
Returns: Returns:
all_data (list): list of STIX objects that have the specified id all_data (list): list of STIX objects that have the specified id
""" """
if not self.get_all_data_sources(): if not self.has_data_sources():
raise AttributeError('CompositeDataSource has no data sources') raise AttributeError('CompositeDataSource has no data sources')
all_data = [] all_data = []
all_filters = self.filters all_filters = set()
all_filters.update(self.filters)
if _composite_filters: if _composite_filters:
all_filters = set(self.filters).update(_composite_filters) all_filters.update(_composite_filters)
# retrieve STIX objects from all configured data sources # retrieve STIX objects from all configured data sources
for ds_id, ds in iteritems(self.data_sources): for ds in self.data_sources:
data = ds.all_versions(stix_id=stix_id, _composite_filters=list(all_filters)) data = ds.all_versions(stix_id=stix_id, _composite_filters=all_filters)
all_data.extend(data) all_data.extend(data)
# remove exact duplicates (where duplicates are STIX 2.0 objects # remove exact duplicates (where duplicates are STIX 2.0 objects
# with the same 'id' and 'modified' values) # with the same 'id' and 'modified' values)
if len(all_data) > 0: if len(all_data) > 0:
all_data = self.deduplicate(all_data) all_data = deduplicate(all_data)
return all_data return all_data
def query(self, query=None, _composite_filters=None): def query(self, query=None, _composite_filters=None):
"""Federate the query to all Data Sources attached to the """Retrieve STIX objects that match query
Federate the query to all DataSources attached to the
Composite Data Source. Composite Data Source.
Args: Args:
query (list): list of filters to search on. query (list): list of filters to search on
_composite_filters (list): a list of filters passed from the _composite_filters (list): a list of filters passed from a
Composite Data Source CompositeDataSource (i.e. if this CompositeDataSource is attached
to a parent CompositeDataSource), not user supplied
Returns: Returns:
all_data (list): list of STIX objects to be returned all_data (list): list of STIX objects to be returned
""" """
if not self.get_all_data_sources(): if not self.has_data_sources():
raise AttributeError('CompositeDataSource has no data sources') raise AttributeError('CompositeDataSource has no data sources')
if not query: if not query:
# dont mess with the query (i.e. convert to a set, as thats done
# within the specific DataSources that are called)
query = [] query = []
all_data = [] all_data = []
all_filters = self.filters
all_filters = set()
all_filters.update(self.filters)
if _composite_filters: if _composite_filters:
all_filters = set(self.filters).update(_composite_filters) all_filters.update(_composite_filters)
# federate query to all attached data sources, # federate query to all attached data sources,
# pass composite filters to id # pass composite filters to id
for ds_id, ds in iteritems(self.data_sources): for ds in self.data_sources:
data = ds.query(query=query, _composite_filters=list(all_filters)) data = ds.query(query=query, _composite_filters=all_filters)
all_data.extend(data) all_data.extend(data)
# remove exact duplicates (where duplicates are STIX 2.0 # remove exact duplicates (where duplicates are STIX 2.0
# objects with the same 'id' and 'modified' values) # objects with the same 'id' and 'modified' values)
if len(all_data) > 0: if len(all_data) > 0:
all_data = self.deduplicate(all_data) all_data = deduplicate(all_data)
return all_data return all_data
def add_data_source(self, data_sources): def add_data_source(self, data_source):
"""Add/attach Data Source to the Composite Data Source instance """Attach a DataSource to CompositeDataSource instance
Args: Args:
data_sources (list): a list of Data Source objects to attach data_source (DataSource): a stix2.DataSource to attach
to the Composite Data Source to the CompositeDataSource
""" """
if not isinstance(data_sources, list): if issubclass(data_source.__class__, DataSource):
data_sources = [data_sources] if data_source.id not in [ds_.id for ds_ in self.data_sources]:
# check DataSource not already attached CompositeDataSource
self.data_sources.append(data_source)
else:
raise TypeError("DataSource (to be added) is not of type stix2.DataSource. DataSource type is '%s'" % type(data_source))
return
def add_data_sources(self, data_sources):
"""Attach list of DataSources to CompositeDataSource instance
Args:
data_sources (list): stix2.DataSources to attach to
CompositeDataSource
"""
for ds in data_sources: for ds in data_sources:
if issubclass(ds.__class__, DataSource): self.add_data_source(ds)
if ds.id in self.data_sources:
# data source already attached to Composite Data Source
continue
# add data source to Composite Data Source
# (its id will be its key identifier)
self.data_sources[ds.id] = ds
else:
# the Data Source object is not a proper subclass
# of DataSource Abstract Class
# TODO: maybe log error?
continue
return return
def remove_data_source(self, data_source_ids): def remove_data_source(self, data_source_id):
"""Remove/detach Data Source from the Composite Data Source instance """Remove DataSource from the CompositeDataSource instance
Args: Args:
data_source_ids (list): a list of Data Source identifiers. data_source_id (str): DataSource IDs.
""" """
for id in data_source_ids: def _match(ds_id, candidate_ds_id):
if id in self.data_sources: return ds_id == candidate_ds_id
del self.data_sources[id]
else: self.data_sources[:] = [ds for ds in self.data_sources if not _match(ds.id, data_source_id)]
raise ValueError("DataSource 'id' not found in CompositeDataSource collection.")
return return
def remove_data_sources(self, data_source_ids):
"""Remove DataSources from the CompositeDataSource instance
Args:
data_source_ids (list): DataSource IDs
"""
for ds_id in data_source_ids:
self.remove_data_source(ds_id)
return
def has_data_sources(self):
return len(self.data_sources)
def get_all_data_sources(self): def get_all_data_sources(self):
"""Return all attached Data Sources return self.data_sources
"""
return self.data_sources.values()

View File

@ -12,71 +12,148 @@ TODO: Test everything
import json import json
import os import os
from stix2 import Bundle from stix2.base import _STIXBase
from stix2.core import Bundle, parse
from stix2.sources import DataSink, DataSource, DataStore from stix2.sources import DataSink, DataSource, DataStore
from stix2.sources.filters import Filter from stix2.sources.filters import Filter, apply_common_filters
from stix2.utils import deduplicate
class FileSystemStore(DataStore): class FileSystemStore(DataStore):
"""FileSystemStore
Provides an interface to an file directory of STIX objects.
FileSystemStore is a wrapper around a paired FileSystemSink
and FileSystemSource.
Args:
stix_dir (str): path to directory of STIX objects
Attributes:
source (FileSystemSource): FuleSystemSource
sink (FileSystemSink): FileSystemSink
""" """
""" def __init__(self, stix_dir):
def __init__(self, stix_dir="stix_data"):
super(FileSystemStore, self).__init__() super(FileSystemStore, self).__init__()
self.source = FileSystemSource(stix_dir=stix_dir) self.source = FileSystemSource(stix_dir=stix_dir)
self.sink = FileSystemSink(stix_dir=stix_dir) self.sink = FileSystemSink(stix_dir=stix_dir)
class FileSystemSink(DataSink): class FileSystemSink(DataSink):
""" """FileSystemSink
"""
def __init__(self, stix_dir="stix_data"):
super(FileSystemSink, self).__init__()
self.stix_dir = os.path.abspath(stix_dir)
# check directory path exists Provides an interface for adding/pushing STIX objects
if not os.path.exists(self.stix_dir): to file directory of STIX objects.
print("Error: directory path for STIX data does not exist")
Can be paired with a FileSystemSource, together as the two
components of a FileSystemStore.
Args:
stix_dir (str): path to directory of STIX objects
"""
def __init__(self, stix_dir):
super(FileSystemSink, self).__init__()
self._stix_dir = os.path.abspath(stix_dir)
if not os.path.exists(self._stix_dir):
raise ValueError("directory path for STIX data does not exist")
@property @property
def stix_dir(self): def stix_dir(self):
return self.stix_dir return self._stix_dir
@stix_dir.setter def add(self, stix_data=None):
def stix_dir(self, dir): """add STIX objects to file directory
self.stix_dir = dir
def add(self, stix_objs=None): Args:
stix_data (STIX object OR dict OR str OR list): valid STIX 2.0 content
in a STIX object(or list of), dict (or list of), or a STIX 2.0
json encoded string
TODO: Bundlify STIX content or no? When dumping to disk.
""" """
Q: bundlify or no? def _check_path_and_write(stix_dir, stix_obj):
""" path = os.path.join(stix_dir, stix_obj["type"], stix_obj["id"] + ".json")
if not stix_objs:
stix_objs = [] if not os.path.exists(os.path.dirname(path)):
for stix_obj in stix_objs: os.makedirs(os.path.dirname(path))
path = os.path.join(self.stix_dir, stix_obj["type"], stix_obj["id"])
json.dump(Bundle([stix_obj]), open(path, 'w+'), indent=4) with open(path, "w") as f:
# Bundle() can take dict or STIX obj as argument
f.write(str(Bundle(stix_obj)))
if isinstance(stix_data, _STIXBase):
# adding python STIX object
_check_path_and_write(self._stix_dir, stix_data)
elif isinstance(stix_data, dict):
if stix_data["type"] == "bundle":
# adding json-formatted Bundle - extracting STIX objects
for stix_obj in stix_data["objects"]:
self.add(stix_obj)
else:
# adding json-formatted STIX
_check_path_and_write(self._stix_dir, stix_data)
elif isinstance(stix_data, str):
# adding json encoded string of STIX content
stix_data = parse(stix_data)
if stix_data["type"] == "bundle":
for stix_obj in stix_data:
self.add(stix_obj)
else:
self.add(stix_data)
elif isinstance(stix_data, list):
# if list, recurse call on individual STIX objects
for stix_obj in stix_data:
self.add(stix_obj)
else:
raise ValueError("stix_data must be a STIX object(or list of, json formatted STIX(or list of) or a json formatted STIX bundle")
class FileSystemSource(DataSource): class FileSystemSource(DataSource):
""" """FileSystemSource
"""
def __init__(self, stix_dir="stix_data"):
super(FileSystemSource, self).__init__()
self.stix_dir = os.path.abspath(stix_dir)
# check directory path exists Provides an interface for searching/retrieving
if not os.path.exists(self.stix_dir): STIX objects from a STIX object file directory.
Can be paired with a FileSystemSink, together as the two
components of a FileSystemStore.
Args:
stix_dir (str): path to directory of STIX objects
"""
def __init__(self, stix_dir):
super(FileSystemSource, self).__init__()
self._stix_dir = os.path.abspath(stix_dir)
if not os.path.exists(self._stix_dir):
print("Error: directory path for STIX data does not exist") print("Error: directory path for STIX data does not exist")
@property @property
def stix_dir(self): def stix_dir(self):
return self.stix_dir return self._stix_dir
@stix_dir.setter
def stix_dir(self, dir):
self.stix_dir = dir
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None):
""" """retrieve STIX object from file directory via STIX ID
Args:
stix_id (str): The STIX ID of the STIX object to be retrieved.
composite_filters (set): set of filters passed from the parent
CompositeDataSource, not user supplied
Returns:
(STIX object): STIX object that has the supplied STIX ID.
The STIX object is loaded from its json file, parsed into
a python STIX object and then returned
""" """
query = [Filter("id", "=", stix_id)] query = [Filter("id", "=", stix_id)]
@ -84,30 +161,63 @@ class FileSystemSource(DataSource):
stix_obj = sorted(all_data, key=lambda k: k['modified'])[0] stix_obj = sorted(all_data, key=lambda k: k['modified'])[0]
return stix_obj return parse(stix_obj)
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None):
""" """retrieve STIX object from file directory via STIX ID, all versions
Notes:
Since FileSystem sources/sinks don't handle multiple versions Note: Since FileSystem sources/sinks don't handle multiple versions
of a STIX object, this operation is unnecessary. Pass call to get(). of a STIX object, this operation is unnecessary. Pass call to get().
Args:
stix_id (str): The STIX ID of the STIX objects to be retrieved.
composite_filters (set): set of filters passed from the parent
CompositeDataSource, not user supplied
Returns:
(list): of STIX objects that has the supplied STIX ID.
The STIX objects are loaded from their json files, parsed into
a python STIX objects and then returned
""" """
return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)] return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)]
def query(self, query=None, _composite_filters=None): def query(self, query=None, _composite_filters=None):
""" """search and retrieve STIX objects based on the complete query
A "complete query" includes the filters from the query, the filters
attached to MemorySource, and any filters passed from a
CompositeDataSource (i.e. _composite_filters)
Args:
query (list): list of filters to search on
composite_filters (set): set of filters passed from the
CompositeDataSource, not user supplied
Returns:
(list): list of STIX objects that matches the supplied
query. The STIX objects are loaded from their json files,
parsed into a python STIX objects and then returned.
""" """
all_data = [] all_data = []
if query is None: if query is None:
query = [] query = set()
else:
if not isinstance(query, list):
# make sure dont make set from a Filter object,
# need to make a set from a list of Filter objects (even if just one Filter)
query = list(query)
query = set(query)
# combine all query filters # combine all query filters
if self.filters: if self.filters:
query.extend(self.filters.values()) query.update(self.filters)
if _composite_filters: if _composite_filters:
query.extend(_composite_filters) query.update(_composite_filters)
# extract any filters that are for "type" or "id" , as we can then do # extract any filters that are for "type" or "id" , as we can then do
# filtering before reading in the STIX objects. A STIX 'type' filter # filtering before reading in the STIX objects. A STIX 'type' filter
@ -125,12 +235,12 @@ class FileSystemSource(DataSource):
for filter in file_filters: for filter in file_filters:
if filter.field == "type": if filter.field == "type":
if filter.op == "=": if filter.op == "=":
include_paths.append(os.path.join(self.stix_dir, filter.value)) include_paths.append(os.path.join(self._stix_dir, filter.value))
elif filter.op == "!=": elif filter.op == "!=":
declude_paths.append(os.path.join(self.stix_dir, filter.value)) declude_paths.append(os.path.join(self._stix_dir, filter.value))
else: else:
# have to walk entire STIX directory # have to walk entire STIX directory
include_paths.append(self.stix_dir) include_paths.append(self._stix_dir)
# if a user specifies a "type" filter like "type = <stix-object_type>", # if a user specifies a "type" filter like "type = <stix-object_type>",
# the filter is reducing the search space to single stix object types # the filter is reducing the search space to single stix object types
@ -144,7 +254,7 @@ class FileSystemSource(DataSource):
# user has specified types that are not wanted (i.e. "!=") # user has specified types that are not wanted (i.e. "!=")
# so query will look in all STIX directories that are not # so query will look in all STIX directories that are not
# the specified type. Compile correct dir paths # the specified type. Compile correct dir paths
for dir in os.listdir(self.stix_dir): for dir in os.listdir(self._stix_dir):
if os.path.abspath(dir) not in declude_paths: if os.path.abspath(dir) not in declude_paths:
include_paths.append(os.path.abspath(dir)) include_paths.append(os.path.abspath(dir))
@ -153,36 +263,50 @@ class FileSystemSource(DataSource):
if "id" in [filter.field for filter in file_filters]: if "id" in [filter.field for filter in file_filters]:
for filter in file_filters: for filter in file_filters:
if filter.field == "id" and filter.op == "=": if filter.field == "id" and filter.op == "=":
id = filter.value id_ = filter.value
break break
else: else:
id = None id_ = None
else: else:
id = None id_ = None
# now iterate through all STIX objs # now iterate through all STIX objs
for path in include_paths: for path in include_paths:
for root, dirs, files in os.walk(path): for root, dirs, files in os.walk(path):
for file in files: for file_ in files:
if id: if id_:
if id == file.split(".")[0]: if id_ == file_.split(".")[0]:
# since ID is specified in one of filters, can evaluate against filename first without loading # since ID is specified in one of filters, can evaluate against filename first without loading
stix_obj = json.load(file)["objects"] stix_obj = json.load(open(os.path.join(root, file_)))["objects"][0]
# check against other filters, add if match # check against other filters, add if match
all_data.extend(self.apply_common_filters([stix_obj], query)) matches = [stix_obj_ for stix_obj_ in apply_common_filters([stix_obj], query)]
all_data.extend(matches)
else: else:
# have to load into memory regardless to evaluate other filters # have to load into memory regardless to evaluate other filters
stix_obj = json.load(file)["objects"] stix_obj = json.load(open(os.path.join(root, file_)))["objects"][0]
all_data.extend(self.apply_common_filters([stix_obj], query)) matches = [stix_obj_ for stix_obj_ in apply_common_filters([stix_obj], query)]
all_data.extend(matches)
all_data = self.deduplicate(all_data) all_data = deduplicate(all_data)
return all_data
# parse python STIX objects from the STIX object dicts
stix_objs = [parse(stix_obj_dict) for stix_obj_dict in all_data]
return stix_objs
def _parse_file_filters(self, query): def _parse_file_filters(self, query):
"""utility method to extract STIX common filters
that can used to possibly speed up querying STIX objects
from the file system
Extracts filters that are for the "id" and "type" field of
a STIX object. As the file directory is organized by STIX
object type with filenames that are equivalent to the STIX
object ID, these filters can be used first to reduce the
search space of a FileSystemStore(or FileSystemSink)
""" """
""" file_filters = set()
file_filters = [] for filter_ in query:
for filter in query: if filter_.field == "id" or filter_.field == "type":
if filter.field == "id" or filter.field == "type": file_filters.add(filter_)
file_filters.append(filter)
return file_filters return file_filters

View File

@ -4,10 +4,6 @@ Filters for Python STIX 2.0 DataSources, DataSinks, DataStores
Classes: Classes:
Filter Filter
TODO: The script at the bottom of the module works (to capture
all the callable filter methods), however it causes this module
to be imported by itself twice. Not sure how big of deal that is,
or if cleaner solution possible.
""" """
import collections import collections
@ -15,6 +11,8 @@ import types
# Currently, only STIX 2.0 common SDO fields (that are not complex objects) # Currently, only STIX 2.0 common SDO fields (that are not complex objects)
# are supported for filtering on # are supported for filtering on
"""Supported STIX properties"""
STIX_COMMON_FIELDS = [ STIX_COMMON_FIELDS = [
"created", "created",
"created_by_ref", "created_by_ref",
@ -30,32 +28,140 @@ STIX_COMMON_FIELDS = [
"modified", "modified",
"object_marking_refs", "object_marking_refs",
"revoked", "revoked",
"type", "type"
"granular_markings"
] ]
# Supported filter operations """Supported filter operations"""
FILTER_OPS = ['=', '!=', 'in', '>', '<', '>=', '<='] FILTER_OPS = ['=', '!=', 'in', '>', '<', '>=', '<=']
# Supported filter value types """Supported filter value types"""
FILTER_VALUE_TYPES = [bool, dict, float, int, list, str, tuple] FILTER_VALUE_TYPES = [bool, dict, float, int, list, str, tuple]
# filter lookup map - STIX 2 common fields -> filter method # filter lookup map - STIX 2 common fields -> filter method
STIX_COMMON_FILTERS_MAP = {} STIX_COMMON_FILTERS_MAP = {}
def _check_filter_components(field, op, value):
"""check filter meets minimum validity
Note: Currently can create Filters that are not valid
STIX2 object common properties, as filter.field value
is not checked, only filter.op, filter.value are checked
here. They are just ignored when
applied within the DataSource API. For example, a user
can add a TAXII Filter, that is extracted and sent to
a TAXII endpoint within TAXIICollection and not applied
locally (within this API).
"""
if op not in FILTER_OPS:
# check filter operator is supported
raise ValueError("Filter operator '%s' not supported for specified field: '%s'" % (op, field))
if type(value) not in FILTER_VALUE_TYPES:
# check filter value type is supported
raise TypeError("Filter value type '%s' is not supported. The type must be a python immutable type or dictionary" % type(value))
return True
class Filter(collections.namedtuple("Filter", ['field', 'op', 'value'])): class Filter(collections.namedtuple("Filter", ['field', 'op', 'value'])):
"""STIX 2 filters that support the querying functionality of STIX 2
DataStores and DataSources.
Initialized like a python tuple
Args:
field (str): filter field name, corresponds to STIX 2 object property
op (str): operator of the filter
value (str): filter field value
Example:
Filter("id", "=", "malware--0f862b01-99da-47cc-9bdb-db4a86a95bb1")
"""
__slots__ = () __slots__ = ()
def __new__(cls, field, op, value): def __new__(cls, field, op, value):
# If value is a list, convert it to a tuple so it is hashable. # If value is a list, convert it to a tuple so it is hashable.
if isinstance(value, list): if isinstance(value, list):
value = tuple(value) value = tuple(value)
_check_filter_components(field, op, value)
self = super(Filter, cls).__new__(cls, field, op, value) self = super(Filter, cls).__new__(cls, field, op, value)
return self return self
@property
def common(self):
"""return whether Filter is valid STIX2 Object common property
Note: The Filter operator and Filter value type are checked when
the filter is created, thus only leaving the Filter field to be
checked to make sure a valid STIX2 Object common property.
Note: Filters that are not valid STIX2 Object common property
Filters are still allowed to be created for extended usage of
Filter. (e.g. TAXII specific filters can be created, which are
then extracted and sent to TAXII endpoint.)
"""
return self.field in STIX_COMMON_FIELDS
def apply_common_filters(stix_objs, query):
"""Evaluate filters against a set of STIX 2.0 objects.
Supports only STIX 2.0 common property fields
Args:
stix_objs (list): list of STIX objects to apply the query to
query (set): set of filters (combined form complete query)
Returns:
(generator): of STIX objects that successfully evaluate against
the query.
"""
for stix_obj in stix_objs:
clean = True
for filter_ in query:
if not filter_.common:
# skip filter as it is not a STIX2 Object common property filter
continue
if "." in filter_.field:
# For properties like granular_markings and external_references
# need to extract the first property from the string.
field = filter_.field.split(".")[0]
else:
field = filter_.field
if field not in stix_obj.keys():
# check filter "field" is in STIX object - if cant be
# applied to STIX object, STIX object is discarded
# (i.e. did not make it through the filter)
clean = False
break
match = STIX_COMMON_FILTERS_MAP[filter_.field.split('.')[0]](filter_, stix_obj)
if not match:
clean = False
break
elif match == -1:
raise ValueError("Error, filter operator: {0} not supported for specified field: {1}".format(filter_.op, filter_.field))
# if object unmarked after all filters, add it
if clean:
yield stix_obj
"""Base type filters"""
# primitive type filters
def _all_filter(filter_, stix_obj_field): def _all_filter(filter_, stix_obj_field):
"""all filter operations (for filters whose value type can be applied to any operation type)""" """all filter operations (for filters whose value type can be applied to any operation type)"""
@ -78,7 +184,7 @@ def _all_filter(filter_, stix_obj_field):
def _id_filter(filter_, stix_obj_id): def _id_filter(filter_, stix_obj_id):
"""base filter types""" """base STIX id filter"""
if filter_.op == "=": if filter_.op == "=":
return stix_obj_id == filter_.value return stix_obj_id == filter_.value
elif filter_.op == "!=": elif filter_.op == "!=":
@ -88,6 +194,7 @@ def _id_filter(filter_, stix_obj_id):
def _boolean_filter(filter_, stix_obj_field): def _boolean_filter(filter_, stix_obj_field):
"""base boolean filter"""
if filter_.op == "=": if filter_.op == "=":
return stix_obj_field == filter_.value return stix_obj_field == filter_.value
elif filter_.op == "!=": elif filter_.op == "!=":
@ -97,19 +204,25 @@ def _boolean_filter(filter_, stix_obj_field):
def _string_filter(filter_, stix_obj_field): def _string_filter(filter_, stix_obj_field):
"""base string filter"""
return _all_filter(filter_, stix_obj_field) return _all_filter(filter_, stix_obj_field)
def _timestamp_filter(filter_, stix_obj_timestamp): def _timestamp_filter(filter_, stix_obj_timestamp):
"""base STIX 2 timestamp filter"""
return _all_filter(filter_, stix_obj_timestamp) return _all_filter(filter_, stix_obj_timestamp)
# STIX 2.0 Common Property filters
# The naming of these functions is important as """STIX 2.0 Common Property Filters
# they are used to index a mapping dictionary from
# STIX common field names to these filter functions. The naming of these functions is important as
# they are used to index a mapping dictionary from
# REQUIRED naming scheme: STIX common field names to these filter functions.
# "check_<STIX field name>_filter"
REQUIRED naming scheme:
"check_<STIX field name>_filter"
"""
def check_created_filter(filter_, stix_obj): def check_created_filter(filter_, stix_obj):
@ -124,13 +237,15 @@ def check_external_references_filter(filter_, stix_obj):
""" """
STIX object's can have a list of external references STIX object's can have a list of external references
external_references properties: external_references properties supported:
external_references.source_name (string) external_references.source_name (string)
external_references.description (string) external_references.description (string)
external_references.url (string) external_references.url (string)
external_references.hashes (hash, but for filtering purposes, a string)
external_references.external_id (string) external_references.external_id (string)
external_references properties not supported:
external_references.hashes
""" """
for er in stix_obj["external_references"]: for er in stix_obj["external_references"]:
# grab er property name from filter field # grab er property name from filter field

View File

@ -6,7 +6,8 @@ Classes:
MemorySink MemorySink
MemorySource MemorySource
TODO: Test everything.
TODO: Run through tests again, lot of changes.
TODO: Use deduplicate() calls only when memory corpus is dirty (been added to) TODO: Use deduplicate() calls only when memory corpus is dirty (been added to)
can save a lot of time for successive queries can save a lot of time for successive queries
@ -18,49 +19,87 @@ Notes:
""" """
import collections
import json import json
import os import os
from stix2 import Bundle from stix2.base import _STIXBase
from stix2.core import Bundle, parse
from stix2.sources import DataSink, DataSource, DataStore from stix2.sources import DataSink, DataSource, DataStore
from stix2.sources.filters import Filter from stix2.sources.filters import Filter, apply_common_filters
def _add(store, stix_data=None): def _add(store, stix_data=None):
"""Adds stix objects to MemoryStore/Source/Sink.""" """Adds STIX objects to MemoryStore/Sink.
if isinstance(stix_data, collections.Mapping):
# stix objects are in a bundle Adds STIX objects to an in-memory dictionary for fast lookup.
# make dictionary of the objects for easy lookup Recursive function, breaks down STIX Bundles and lists.
for stix_obj in stix_data["objects"]:
store.data[stix_obj["id"]] = stix_obj Args:
stix_data (list OR dict OR STIX object): STIX objects to be added
"""
if isinstance(stix_data, _STIXBase):
# adding a python STIX object
store._data[stix_data["id"]] = stix_data
elif isinstance(stix_data, dict):
if stix_data["type"] == "bundle":
# adding a json bundle - so just grab STIX objects
for stix_obj in stix_data["objects"]:
_add(store, stix_obj)
else:
# adding a json STIX object
store._data[stix_data["id"]] = stix_data
elif isinstance(stix_data, str):
# adding json encoded string of STIX content
stix_data = parse(stix_data)
if stix_data["type"] == "bundle":
# recurse on each STIX object in bundle
for stix_obj in stix_data:
_add(store, stix_obj)
else:
_add(store, stix_data)
elif isinstance(stix_data, list): elif isinstance(stix_data, list):
# stix objects are in a list # STIX objects are in a list- recurse on each object
for stix_obj in stix_data: for stix_obj in stix_data:
store.data[stix_obj["id"]] = stix_obj _add(store, stix_obj)
else: else:
raise ValueError("stix_data must be in bundle format or raw list") raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle")
class MemoryStore(DataStore): class MemoryStore(DataStore):
""" """Provides an interface to an in-memory dictionary
""" of STIX objects. MemoryStore is a wrapper around a paired
def __init__(self, stix_data=None): MemorySink and MemorySource
"""
Notes:
It doesn't make sense to create a MemoryStore by passing
in existing MemorySource and MemorySink because there could
be data concurrency issues. Just as easy to create new MemoryStore.
""" Note: It doesn't make sense to create a MemoryStore by passing
in existing MemorySource and MemorySink because there could
be data concurrency issues. As well, just as easy to create new MemoryStore.
Args:
stix_data (list OR dict OR STIX object): STIX content to be added
Attributes:
_data (dict): the in-memory dict that holds STIX objects
source (MemorySource): MemorySource
sink (MemorySink): MemorySink
"""
def __init__(self, stix_data=None):
super(MemoryStore, self).__init__() super(MemoryStore, self).__init__()
self.data = {} self._data = {}
if stix_data: if stix_data:
_add(self, stix_data) _add(self, stix_data)
self.source = MemorySource(stix_data=self.data, _store=True) self.source = MemorySource(stix_data=self._data, _store=True)
self.sink = MemorySink(stix_data=self.data, _store=True) self.sink = MemorySink(stix_data=self._data, _store=True)
def save_to_file(self, file_path): def save_to_file(self, file_path):
return self.sink.save_to_file(file_path=file_path) return self.sink.save_to_file(file_path=file_path)
@ -70,64 +109,107 @@ class MemoryStore(DataStore):
class MemorySink(DataSink): class MemorySink(DataSink):
""" """Provides an interface for adding/pushing STIX objects
""" to an in-memory dictionary.
def __init__(self, stix_data=None, _store=False):
"""
Args:
stix_data (dictionary OR list): valid STIX 2.0 content in
bundle or a list.
_store (bool): if the MemorySink is a part of a DataStore,
in which case "stix_data" is a direct reference to
shared memory with DataSource.
""" Designed to be paired with a MemorySource, together as the two
components of a MemoryStore.
Args:
stix_data (dict OR list): valid STIX 2.0 content in
bundle or a list.
_store (bool): if the MemorySink is a part of a DataStore,
in which case "stix_data" is a direct reference to
shared memory with DataSource. Not user supplied
Attributes:
_data (dict): the in-memory dict that holds STIX objects.
If apart of a MemoryStore, dict is shared between with
a MemorySource
"""
def __init__(self, stix_data=None, _store=False):
super(MemorySink, self).__init__() super(MemorySink, self).__init__()
self.data = {} self._data = {}
if _store: if _store:
self.data = stix_data self._data = stix_data
elif stix_data: elif stix_data:
self.add(stix_data) _add(self, stix_data)
def add(self, stix_data): def add(self, stix_data):
""" """add STIX objects to in-memory dictionary maintained by
the MemorySink (MemoryStore)
see "_add()" for args documentation
""" """
_add(self, stix_data) _add(self, stix_data)
def save_to_file(self, file_path): def save_to_file(self, file_path):
"""write SITX objects in in-memory dictionary to json file, as a STIX Bundle
Args:
file_path (str): file path to write STIX data to
""" """
""" file_path = os.path.abspath(file_path)
json.dump(Bundle(self.data.values()), file_path, indent=4) if not os.path.exists(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
with open(file_path, "w") as f:
f.write(str(Bundle(self._data.values())))
class MemorySource(DataSource): class MemorySource(DataSource):
"""Provides an interface for searching/retrieving
STIX objects from an in-memory dictionary.
Designed to be paired with a MemorySink, together as the two
components of a MemoryStore.
Args:
stix_data (dict OR list OR STIX object): valid STIX 2.0 content in
bundle or list.
_store (bool): if the MemorySource is a part of a DataStore,
in which case "stix_data" is a direct reference to shared
memory with DataSink. Not user supplied
Attributes:
_data (dict): the in-memory dict that holds STIX objects.
If apart of a MemoryStore, dict is shared between with
a MemorySink
"""
def __init__(self, stix_data=None, _store=False): def __init__(self, stix_data=None, _store=False):
"""
Args:
stix_data (dictionary OR list): valid STIX 2.0 content in
bundle or list.
_store (bool): if the MemorySource is a part of a DataStore,
in which case "stix_data" is a direct reference to shared
memory with DataSink.
"""
super(MemorySource, self).__init__() super(MemorySource, self).__init__()
self.data = {} self._data = {}
if _store: if _store:
self.data = stix_data self._data = stix_data
elif stix_data: elif stix_data:
_add(self, stix_data) _add(self, stix_data)
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None):
"""retrieve STIX object from in-memory dict via STIX ID
Args:
stix_id (str): The STIX ID of the STIX object to be retrieved.
composite_filters (set): set of filters passed from the parent
CompositeDataSource, not user supplied
Returns:
(dict OR STIX object): STIX object that has the supplied
ID. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory
as they are supplied (either as python dictionary or STIX object), it
is returned in the same form as it as added
""" """
"""
if _composite_filters is None: if _composite_filters is None:
# if get call is only based on 'id', no need to search, just retrieve from dict # if get call is only based on 'id', no need to search, just retrieve from dict
try: try:
stix_obj = self.data[stix_id] stix_obj = self._data[stix_id]
except KeyError: except KeyError:
stix_obj = None stix_obj = None
return stix_obj return stix_obj
@ -143,44 +225,75 @@ class MemorySource(DataSource):
return stix_obj return stix_obj
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None):
""" """retrieve STIX objects from in-memory dict via STIX ID, all versions of it
Notes:
Since Memory sources/sinks don't handle multiple versions of a Note: Since Memory sources/sinks don't handle multiple versions of a
STIX object, this operation is unnecessary. Translate call to get(). STIX object, this operation is unnecessary. Translate call to get().
Args: Args:
stix_id (str): The id of the STIX 2.0 object to retrieve. Should stix_id (str): The STIX ID of the STIX 2 object to retrieve.
return a list of objects, all the versions of the object
specified by the "id". composite_filters (set): set of filters passed from the parent
CompositeDataSource, not user supplied
Returns: Returns:
(list): STIX object that matched ``stix_id``. (list): list of STIX objects that has the supplied ID. As the
MemoryStore(i.e. MemorySink) adds STIX objects to memory as they
are supplied (either as python dictionary or STIX object), it
is returned in the same form as it as added
""" """
return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)] return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)]
def query(self, query=None, _composite_filters=None): def query(self, query=None, _composite_filters=None):
""" """search and retrieve STIX objects based on the complete query
A "complete query" includes the filters from the query, the filters
attached to MemorySource, and any filters passed from a
CompositeDataSource (i.e. _composite_filters)
Args:
query (list): list of filters to search on
composite_filters (set): set of filters passed from the
CompositeDataSource, not user supplied
Returns:
(list): list of STIX objects that matches the supplied
query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory
as they are supplied (either as python dictionary or STIX object), it
is returned in the same form as it as added
""" """
if query is None: if query is None:
query = [] query = set()
else:
if not isinstance(query, list):
# make sure dont make set from a Filter object,
# need to make a set from a list of Filter objects (even if just one Filter)
query = list(query)
query = set(query)
# combine all query filters # combine all query filters
if self.filters: if self.filters:
query.extend(list(self.filters)) query.update(self.filters)
if _composite_filters: if _composite_filters:
query.extend(_composite_filters) query.update(_composite_filters)
# Apply STIX common property filters. # Apply STIX common property filters.
all_data = self.apply_common_filters(self.data.values(), query) all_data = [stix_obj for stix_obj in apply_common_filters(self._data.values(), query)]
return all_data return all_data
def load_from_file(self, file_path): def load_from_file(self, file_path):
""" """load STIX data from json file
File format is expected to be a single json
STIX object or json STIX bundle
Args:
file_path (str): file path to load STIX data from
""" """
file_path = os.path.abspath(file_path) file_path = os.path.abspath(file_path)
stix_data = json.load(open(file_path, "r")) stix_data = json.load(open(file_path, "r"))
_add(self, stix_data)
for stix_obj in stix_data["objects"]:
self.data[stix_obj["id"]] = stix_obj

View File

@ -10,83 +10,144 @@ TODO: Test everything
""" """
import json from stix2.base import _STIXBase
from stix2.core import Bundle, parse
from stix2.sources import DataSink, DataSource, DataStore, make_id from stix2.sources import DataSink, DataSource, DataStore
from stix2.sources.filters import Filter from stix2.sources.filters import Filter, apply_common_filters
from stix2.utils import deduplicate
TAXII_FILTERS = ['added_after', 'id', 'type', 'version'] TAXII_FILTERS = ['added_after', 'id', 'type', 'version']
class TAXIICollectionStore(DataStore): class TAXIICollectionStore(DataStore):
""" """Provides an interface to a local/remote TAXII Collection
of STIX data. TAXIICollectionStore is a wrapper
around a paired TAXIICollectionSink and TAXIICollectionSource.
Args:
collection (taxii2.Collection): TAXII Collection instance
""" """
def __init__(self, collection): def __init__(self, collection):
"""
Create a new TAXII Collection Data store
Args:
collection (taxii2.Collection): Collection instance
"""
super(TAXIICollectionStore, self).__init__() super(TAXIICollectionStore, self).__init__()
self.source = TAXIICollectionSource(collection) self.source = TAXIICollectionSource(collection)
self.sink = TAXIICollectionSink(collection) self.sink = TAXIICollectionSink(collection)
class TAXIICollectionSink(DataSink): class TAXIICollectionSink(DataSink):
""" """Provides an interface for pushing STIX objects to a local/remote
TAXII Collection endpoint.
Args:
collection (taxii2.Collection): TAXII2 Collection instance
""" """
def __init__(self, collection): def __init__(self, collection):
super(TAXIICollectionSink, self).__init__() super(TAXIICollectionSink, self).__init__()
self.collection = collection self.collection = collection
def add(self, stix_obj): def add(self, stix_data):
""" """add/push STIX content to TAXII Collection endpoint
"""
self.collection.add_objects(self.create_bundle([json.loads(str(stix_obj))]))
@staticmethod Args:
def create_bundle(objects): stix_data (STIX object OR dict OR str OR list): valid STIX 2.0 content
return dict(id="bundle--%s" % make_id(), in a STIX object (or Bundle), STIX onject dict (or Bundle dict), or a STIX 2.0
objects=objects, json encoded string, or list of any of the following
spec_version="2.0",
type="bundle") """
if isinstance(stix_data, _STIXBase):
# adding python STIX object
bundle = dict(Bundle(stix_data))
elif isinstance(stix_data, dict):
# adding python dict (of either Bundle or STIX obj)
if stix_data["type"] == "bundle":
bundle = stix_data
else:
bundle = dict(Bundle(stix_data))
elif isinstance(stix_data, list):
# adding list of something - recurse on each
for obj in stix_data:
self.add(obj)
elif isinstance(stix_data, str):
# adding json encoded string of STIX content
stix_data = parse(stix_data)
if stix_data["type"] == "bundle":
bundle = dict(stix_data)
else:
bundle = dict(Bundle(stix_data))
else:
raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle")
self.collection.add_objects(bundle)
class TAXIICollectionSource(DataSource): class TAXIICollectionSource(DataSource):
""" """Provides an interface for searching/retrieving STIX objects
from a local/remote TAXII Collection endpoint.
Args:
collection (taxii2.Collection): TAXII Collection instance
""" """
def __init__(self, collection): def __init__(self, collection):
super(TAXIICollectionSource, self).__init__() super(TAXIICollectionSource, self).__init__()
self.collection = collection self.collection = collection
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None):
""" """retrieve STIX object from local/remote STIX Collection
endpoint.
Args:
stix_id (str): The STIX ID of the STIX object to be retrieved.
composite_filters (set): set of filters passed from the parent
CompositeDataSource, not user supplied
Returns:
(STIX object): STIX object that has the supplied STIX ID.
The STIX object is received from TAXII has dict, parsed into
a python STIX object and then returned
""" """
# combine all query filters # combine all query filters
query = [] query = set()
if self.filters: if self.filters:
query.extend(self.filters.values()) query.update(self.filters)
if _composite_filters: if _composite_filters:
query.extend(_composite_filters) query.update(_composite_filters)
# separate taxii query terms (can be done remotely) # separate taxii query terms (can be done remotely)
taxii_filters = self._parse_taxii_filters(query) taxii_filters = self._parse_taxii_filters(query)
stix_objs = self.collection.get_object(stix_id, taxii_filters)["objects"] stix_objs = self.collection.get_object(stix_id, taxii_filters)["objects"]
stix_obj = self.apply_common_filters(stix_objs, query) stix_obj = [stix_obj for stix_obj in apply_common_filters(stix_objs, query)]
if len(stix_obj) > 0: if len(stix_obj):
stix_obj = stix_obj[0] stix_obj = stix_obj[0]
else: else:
stix_obj = None stix_obj = None
return stix_obj return parse(stix_obj)
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None):
""" """retrieve STIX object from local/remote TAXII Collection
endpoint, all versions of it
Args:
stix_id (str): The STIX ID of the STIX objects to be retrieved.
composite_filters (set): set of filters passed from the parent
CompositeDataSource, not user supplied
Returns:
(see query() as all_versions() is just a wrapper)
""" """
# make query in TAXII query format since 'id' is TAXII field # make query in TAXII query format since 'id' is TAXII field
query = [ query = [
@ -99,16 +160,39 @@ class TAXIICollectionSource(DataSource):
return all_data return all_data
def query(self, query=None, _composite_filters=None): def query(self, query=None, _composite_filters=None):
"""search and retreive STIX objects based on the complete query
A "complete query" includes the filters from the query, the filters
attached to MemorySource, and any filters passed from a
CompositeDataSource (i.e. _composite_filters)
Args:
query (list): list of filters to search on
composite_filters (set): set of filters passed from the
CompositeDataSource, not user supplied
Returns:
(list): list of STIX objects that matches the supplied
query. The STIX objects are received from TAXII as dicts,
parsed into python STIX objects and then returned.
""" """
"""
if query is None: if query is None:
query = [] query = set()
else:
if not isinstance(query, list):
# make sure dont make set from a Filter object,
# need to make a set from a list of Filter objects (even if just one Filter)
query = list(query)
query = set(query)
# combine all query filters # combine all query filters
if self.filters: if self.filters:
query.extend(self.filters.values()) query.update(self.filters)
if _composite_filters: if _composite_filters:
query.extend(_composite_filters) query.update(_composite_filters)
# separate taxii query terms (can be done remotely) # separate taxii query terms (can be done remotely)
taxii_filters = self._parse_taxii_filters(query) taxii_filters = self._parse_taxii_filters(query)
@ -117,12 +201,15 @@ class TAXIICollectionSource(DataSource):
all_data = self.collection.get_objects(filters=taxii_filters)["objects"] all_data = self.collection.get_objects(filters=taxii_filters)["objects"]
# deduplicate data (before filtering as reduces wasted filtering) # deduplicate data (before filtering as reduces wasted filtering)
all_data = self.deduplicate(all_data) all_data = deduplicate(all_data)
# apply local (composite and data source filters) # apply local (CompositeDataSource, TAXIICollectionSource and query filters)
all_data = self.apply_common_filters(all_data, query) all_data = [stix_obj for stix_obj in apply_common_filters(all_data, query)]
return all_data # parse python STIX objects from the STIX object dicts
stix_objs = [parse(stix_obj_dict) for stix_obj_dict in all_data]
return stix_objs
def _parse_taxii_filters(self, query): def _parse_taxii_filters(self, query):
"""Parse out TAXII filters that the TAXII server can filter on. """Parse out TAXII filters that the TAXII server can filter on.
@ -142,6 +229,7 @@ class TAXIICollectionSource(DataSource):
for 'requests.get()'. for 'requests.get()'.
""" """
params = {} params = {}
for filter_ in query: for filter_ in query:

View File

@ -4,13 +4,18 @@ from collections import OrderedDict
from .base import _STIXBase from .base import _STIXBase
from .common import ExternalReference, GranularMarking from .common import ExternalReference, GranularMarking
from .markings import MarkingsMixin
from .properties import (BooleanProperty, IDProperty, IntegerProperty, from .properties import (BooleanProperty, IDProperty, IntegerProperty,
ListProperty, ReferenceProperty, StringProperty, ListProperty, ReferenceProperty, StringProperty,
TimestampProperty, TypeProperty) TimestampProperty, TypeProperty)
from .utils import NOW from .utils import NOW
class Relationship(_STIXBase): class STIXRelationshipObject(_STIXBase, MarkingsMixin):
pass
class Relationship(STIXRelationshipObject):
_type = 'relationship' _type = 'relationship'
_properties = OrderedDict() _properties = OrderedDict()
@ -45,7 +50,7 @@ class Relationship(_STIXBase):
super(Relationship, self).__init__(**kwargs) super(Relationship, self).__init__(**kwargs)
class Sighting(_STIXBase): class Sighting(STIXRelationshipObject):
_type = 'sighting' _type = 'sighting'
_properties = OrderedDict() _properties = OrderedDict()
_properties.update([ _properties.update([

View File

@ -3,8 +3,9 @@ from taxii2client import Collection
from stix2.sources import (CompositeDataSource, DataSink, DataSource, from stix2.sources import (CompositeDataSource, DataSink, DataSource,
DataStore, make_id, taxii) DataStore, make_id, taxii)
from stix2.sources.filters import Filter from stix2.sources.filters import Filter, apply_common_filters
from stix2.sources.memory import MemorySource, MemoryStore from stix2.sources.memory import MemorySource, MemoryStore
from stix2.utils import deduplicate
COLLECTION_URL = 'https://example.com/api1/collections/91a7b528-80eb-42ed-a74d-c6fbd5a26116/' COLLECTION_URL = 'https://example.com/api1/collections/91a7b528-80eb-42ed-a74d-c6fbd5a26116/'
@ -206,39 +207,37 @@ def test_add_get_remove_filter(ds):
Filter('id', '!=', 'stix object id'), Filter('id', '!=', 'stix object id'),
Filter('labels', 'in', ["heartbleed", "malicious-activity"]), Filter('labels', 'in', ["heartbleed", "malicious-activity"]),
] ]
invalid_filters = [
Filter('description', '=', 'not supported field - just place holder'), # Invalid filters - wont pass creation
Filter('modified', '*', 'not supported operator - just place holder'), # these filters will not be allowed to be created
Filter('created', '=', object()), # check proper errors are raised when trying to create them
]
with pytest.raises(ValueError) as excinfo:
# create Filter that has an operator that is not allowed
Filter('modified', '*', 'not supported operator - just place holder')
assert str(excinfo.value) == "Filter operator '*' not supported for specified field: 'modified'"
with pytest.raises(TypeError) as excinfo:
# create Filter that has a value type that is not allowed
Filter('created', '=', object())
# On Python 2, the type of object() is `<type 'object'>` On Python 3, it's `<class 'object'>`.
assert str(excinfo.value).startswith("Filter value type")
assert str(excinfo.value).endswith("is not supported. The type must be a python immutable type or dictionary")
assert len(ds.filters) == 0 assert len(ds.filters) == 0
ds.add_filter(valid_filters[0]) ds.filters.add(valid_filters[0])
assert len(ds.filters) == 1 assert len(ds.filters) == 1
# Addin the same filter again will have no effect since `filters` uses a set # Addin the same filter again will have no effect since `filters` uses a set
ds.add_filter(valid_filters[0]) ds.filters.add(valid_filters[0])
assert len(ds.filters) == 1 assert len(ds.filters) == 1
ds.add_filter(valid_filters[1]) ds.filters.add(valid_filters[1])
assert len(ds.filters) == 2 assert len(ds.filters) == 2
ds.add_filter(valid_filters[2]) ds.filters.add(valid_filters[2])
assert len(ds.filters) == 3 assert len(ds.filters) == 3
# TODO: make better error messages
with pytest.raises(ValueError) as excinfo:
ds.add_filter(invalid_filters[0])
assert str(excinfo.value) == "Filter 'field' is not a STIX 2.0 common property. Currently only STIX object common properties supported"
with pytest.raises(ValueError) as excinfo:
ds.add_filter(invalid_filters[1])
assert str(excinfo.value) == "Filter operation (from 'op' field) not supported"
with pytest.raises(ValueError) as excinfo:
ds.add_filter(invalid_filters[2])
assert str(excinfo.value) == "Filter 'value' type is not supported. The type(value) must be python immutable type or dictionary"
assert set(valid_filters) == ds.filters assert set(valid_filters) == ds.filters
# remove # remove
@ -246,7 +245,7 @@ def test_add_get_remove_filter(ds):
assert len(ds.filters) == 2 assert len(ds.filters) == 2
ds.add_filters(valid_filters) ds.filters.update(valid_filters)
def test_apply_common_filters(ds): def test_apply_common_filters(ds):
@ -320,7 +319,6 @@ def test_apply_common_filters(ds):
Filter("created", ">", "2015-01-01T01:00:00.000Z"), Filter("created", ">", "2015-01-01T01:00:00.000Z"),
Filter("revoked", "=", True), Filter("revoked", "=", True),
Filter("revoked", "!=", True), Filter("revoked", "!=", True),
Filter("revoked", "?", False),
Filter("object_marking_refs", "=", "marking-definition--613f2e26-407d-48c7-9eca-b8e91df99dc9"), Filter("object_marking_refs", "=", "marking-definition--613f2e26-407d-48c7-9eca-b8e91df99dc9"),
Filter("granular_markings.selectors", "in", "relationship_type"), Filter("granular_markings.selectors", "in", "relationship_type"),
Filter("granular_markings.marking_ref", "=", "marking-definition--5e57c739-391a-4eb3-b6be-7d15ca92d5ed"), Filter("granular_markings.marking_ref", "=", "marking-definition--5e57c739-391a-4eb3-b6be-7d15ca92d5ed"),
@ -332,7 +330,7 @@ def test_apply_common_filters(ds):
] ]
# "Return any object whose type is not relationship" # "Return any object whose type is not relationship"
resp = ds.apply_common_filters(stix_objs, [filters[0]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[0]])]
ids = [r['id'] for r in resp] ids = [r['id'] for r in resp]
assert stix_objs[0]['id'] in ids assert stix_objs[0]['id'] in ids
assert stix_objs[1]['id'] in ids assert stix_objs[1]['id'] in ids
@ -340,138 +338,109 @@ def test_apply_common_filters(ds):
assert len(ids) == 3 assert len(ids) == 3
# "Return any object that matched id relationship--2f9a9aa9-108a-4333-83e2-4fb25add0463" # "Return any object that matched id relationship--2f9a9aa9-108a-4333-83e2-4fb25add0463"
resp = ds.apply_common_filters(stix_objs, [filters[1]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[1]])]
assert resp[0]['id'] == stix_objs[2]['id'] assert resp[0]['id'] == stix_objs[2]['id']
assert len(resp) == 1 assert len(resp) == 1
# "Return any object that contains remote-access-trojan in labels" # "Return any object that contains remote-access-trojan in labels"
resp = ds.apply_common_filters(stix_objs, [filters[2]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[2]])]
assert resp[0]['id'] == stix_objs[0]['id'] assert resp[0]['id'] == stix_objs[0]['id']
assert len(resp) == 1 assert len(resp) == 1
# "Return any object created after 2015-01-01T01:00:00.000Z" # "Return any object created after 2015-01-01T01:00:00.000Z"
resp = ds.apply_common_filters(stix_objs, [filters[3]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[3]])]
assert resp[0]['id'] == stix_objs[0]['id'] assert resp[0]['id'] == stix_objs[0]['id']
assert len(resp) == 2 assert len(resp) == 2
# "Return any revoked object" # "Return any revoked object"
resp = ds.apply_common_filters(stix_objs, [filters[4]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[4]])]
assert resp[0]['id'] == stix_objs[2]['id'] assert resp[0]['id'] == stix_objs[2]['id']
assert len(resp) == 1 assert len(resp) == 1
# "Return any object whose not revoked" # "Return any object whose not revoked"
# Note that if 'revoked' property is not present in object. # Note that if 'revoked' property is not present in object.
# Currently we can't use such an expression to filter for... :( # Currently we can't use such an expression to filter for... :(
resp = ds.apply_common_filters(stix_objs, [filters[5]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[5]])]
assert len(resp) == 0 assert len(resp) == 0
# Assert unknown operator for _boolean() raises exception.
with pytest.raises(ValueError) as excinfo:
ds.apply_common_filters(stix_objs, [filters[6]])
assert str(excinfo.value) == ("Error, filter operator: {0} not supported "
"for specified field: {1}"
.format(filters[6].op, filters[6].field))
# "Return any object that matches marking-definition--613f2e26-407d-48c7-9eca-b8e91df99dc9 in object_marking_refs" # "Return any object that matches marking-definition--613f2e26-407d-48c7-9eca-b8e91df99dc9 in object_marking_refs"
resp = ds.apply_common_filters(stix_objs, [filters[7]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[6]])]
assert resp[0]['id'] == stix_objs[2]['id'] assert resp[0]['id'] == stix_objs[2]['id']
assert len(resp) == 1 assert len(resp) == 1
# "Return any object that contains relationship_type in their selectors AND # "Return any object that contains relationship_type in their selectors AND
# also has marking-definition--5e57c739-391a-4eb3-b6be-7d15ca92d5ed in marking_ref" # also has marking-definition--5e57c739-391a-4eb3-b6be-7d15ca92d5ed in marking_ref"
resp = ds.apply_common_filters(stix_objs, [filters[8], filters[9]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[7], filters[8]])]
assert resp[0]['id'] == stix_objs[2]['id'] assert resp[0]['id'] == stix_objs[2]['id']
assert len(resp) == 1 assert len(resp) == 1
# "Return any object that contains CVE-2014-0160,CVE-2017-6608 in their external_id" # "Return any object that contains CVE-2014-0160,CVE-2017-6608 in their external_id"
resp = ds.apply_common_filters(stix_objs, [filters[10]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[9]])]
assert resp[0]['id'] == stix_objs[3]['id'] assert resp[0]['id'] == stix_objs[3]['id']
assert len(resp) == 1 assert len(resp) == 1
# "Return any object that matches created_by_ref identity--00000000-0000-0000-0000-b8e91df99dc9" # "Return any object that matches created_by_ref identity--00000000-0000-0000-0000-b8e91df99dc9"
resp = ds.apply_common_filters(stix_objs, [filters[11]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[10]])]
assert len(resp) == 1 assert len(resp) == 1
# "Return any object that matches marking-definition--613f2e26-0000-0000-0000-b8e91df99dc9 in object_marking_refs" (None) # "Return any object that matches marking-definition--613f2e26-0000-0000-0000-b8e91df99dc9 in object_marking_refs" (None)
resp = ds.apply_common_filters(stix_objs, [filters[12]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[11]])]
assert len(resp) == 0 assert len(resp) == 0
# "Return any object that contains description in its selectors" (None) # "Return any object that contains description in its selectors" (None)
resp = ds.apply_common_filters(stix_objs, [filters[13]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[12]])]
assert len(resp) == 0 assert len(resp) == 0
# "Return any object that object that matches CVE in source_name" (None, case sensitive) # "Return any object that object that matches CVE in source_name" (None, case sensitive)
resp = ds.apply_common_filters(stix_objs, [filters[14]]) resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[13]])]
assert len(resp) == 0 assert len(resp) == 0
def test_filters0(ds): def test_filters0(ds):
# "Return any object modified before 2017-01-28T13:49:53.935Z" # "Return any object modified before 2017-01-28T13:49:53.935Z"
resp = ds.apply_common_filters(STIX_OBJS2, [Filter("modified", "<", "2017-01-28T13:49:53.935Z")]) resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("modified", "<", "2017-01-28T13:49:53.935Z")])]
assert resp[0]['id'] == STIX_OBJS2[1]['id'] assert resp[0]['id'] == STIX_OBJS2[1]['id']
assert len(resp) == 2 assert len(resp) == 2
def test_filters1(ds): def test_filters1(ds):
# "Return any object modified after 2017-01-28T13:49:53.935Z" # "Return any object modified after 2017-01-28T13:49:53.935Z"
resp = ds.apply_common_filters(STIX_OBJS2, [Filter("modified", ">", "2017-01-28T13:49:53.935Z")]) resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("modified", ">", "2017-01-28T13:49:53.935Z")])]
assert resp[0]['id'] == STIX_OBJS2[0]['id'] assert resp[0]['id'] == STIX_OBJS2[0]['id']
assert len(resp) == 1 assert len(resp) == 1
def test_filters2(ds): def test_filters2(ds):
# "Return any object modified after or on 2017-01-28T13:49:53.935Z" # "Return any object modified after or on 2017-01-28T13:49:53.935Z"
resp = ds.apply_common_filters(STIX_OBJS2, [Filter("modified", ">=", "2017-01-27T13:49:53.935Z")]) resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("modified", ">=", "2017-01-27T13:49:53.935Z")])]
assert resp[0]['id'] == STIX_OBJS2[0]['id'] assert resp[0]['id'] == STIX_OBJS2[0]['id']
assert len(resp) == 3 assert len(resp) == 3
def test_filters3(ds): def test_filters3(ds):
# "Return any object modified before or on 2017-01-28T13:49:53.935Z" # "Return any object modified before or on 2017-01-28T13:49:53.935Z"
resp = ds.apply_common_filters(STIX_OBJS2, [Filter("modified", "<=", "2017-01-27T13:49:53.935Z")]) resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("modified", "<=", "2017-01-27T13:49:53.935Z")])]
assert resp[0]['id'] == STIX_OBJS2[1]['id'] assert resp[0]['id'] == STIX_OBJS2[1]['id']
assert len(resp) == 2 assert len(resp) == 2
def test_filters4(ds): def test_filters4(ds):
fltr4 = Filter("modified", "?", "2017-01-27T13:49:53.935Z") # Assert invalid Filter cannot be created
# Assert unknown operator for _all() raises exception.
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
ds.apply_common_filters(STIX_OBJS2, [fltr4]) Filter("modified", "?", "2017-01-27T13:49:53.935Z")
assert str(excinfo.value) == ("Error, filter operator: {0} not supported " assert str(excinfo.value) == ("Filter operator '?' not supported "
"for specified field: {1}").format(fltr4.op, fltr4.field) "for specified field: 'modified'")
def test_filters5(ds): def test_filters5(ds):
# "Return any object whose id is not indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f" # "Return any object whose id is not indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f"
resp = ds.apply_common_filters(STIX_OBJS2, [Filter("id", "!=", "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")]) resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("id", "!=", "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")])]
assert resp[0]['id'] == STIX_OBJS2[0]['id'] assert resp[0]['id'] == STIX_OBJS2[0]['id']
assert len(resp) == 1 assert len(resp) == 1
def test_filters6(ds):
fltr6 = Filter("id", "?", "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")
# Assert unknown operator for _id() raises exception.
with pytest.raises(ValueError) as excinfo:
ds.apply_common_filters(STIX_OBJS2, [fltr6])
assert str(excinfo.value) == ("Error, filter operator: {0} not supported "
"for specified field: {1}").format(fltr6.op, fltr6.field)
def test_filters7(ds):
fltr7 = Filter("notacommonproperty", "=", "bar")
# Assert unknown field raises exception.
with pytest.raises(ValueError) as excinfo:
ds.apply_common_filters(STIX_OBJS2, [fltr7])
assert str(excinfo.value) == ("Error, field: {0} is not supported for "
"filtering on.").format(fltr7.field)
def test_deduplicate(ds): def test_deduplicate(ds):
unique = ds.deduplicate(STIX_OBJS1) unique = deduplicate(STIX_OBJS1)
# Only 3 objects are unique # Only 3 objects are unique
# 2 id's vary # 2 id's vary
@ -494,17 +463,19 @@ def test_add_remove_composite_datasource():
ds2 = DataSource() ds2 = DataSource()
ds3 = DataSink() ds3 = DataSink()
cds.add_data_source([ds1, ds2, ds1, ds3]) with pytest.raises(TypeError) as excinfo:
cds.add_data_sources([ds1, ds2, ds1, ds3])
assert str(excinfo.value) == ("DataSource (to be added) is not of type "
"stix2.DataSource. DataSource type is '<class 'stix2.sources.DataSink'>'")
cds.add_data_sources([ds1, ds2, ds1])
assert len(cds.get_all_data_sources()) == 2 assert len(cds.get_all_data_sources()) == 2
cds.remove_data_source([ds1.id, ds2.id]) cds.remove_data_sources([ds1.id, ds2.id])
assert len(cds.get_all_data_sources()) == 0 assert len(cds.get_all_data_sources()) == 0
with pytest.raises(ValueError):
cds.remove_data_source([ds3.id])
def test_composite_datasource_operations(): def test_composite_datasource_operations():
BUNDLE1 = dict(id="bundle--%s" % make_id(), BUNDLE1 = dict(id="bundle--%s" % make_id(),
@ -515,7 +486,7 @@ def test_composite_datasource_operations():
ds1 = MemorySource(stix_data=BUNDLE1) ds1 = MemorySource(stix_data=BUNDLE1)
ds2 = MemorySource(stix_data=STIX_OBJS2) ds2 = MemorySource(stix_data=STIX_OBJS2)
cds.add_data_source([ds1, ds2]) cds.add_data_sources([ds1, ds2])
indicators = cds.all_versions("indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f") indicators = cds.all_versions("indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f")

View File

@ -150,13 +150,11 @@ def test_environment_no_datastore():
env.query(INDICATOR_ID) env.query(INDICATOR_ID)
assert 'Environment has no data source' in str(excinfo.value) assert 'Environment has no data source' in str(excinfo.value)
with pytest.raises(AttributeError) as excinfo:
env.add_filters(INDICATOR_ID)
assert 'Environment has no data source' in str(excinfo.value)
with pytest.raises(AttributeError) as excinfo: def test_environment_add_filters():
env.add_filter(INDICATOR_ID) env = stix2.Environment(factory=stix2.ObjectFactory())
assert 'Environment has no data source' in str(excinfo.value) env.add_filters([INDICATOR_ID])
env.add_filter(INDICATOR_ID)
def test_environment_datastore_and_no_object_factory(): def test_environment_datastore_and_no_object_factory():

View File

@ -1,7 +1,7 @@
import pytest import pytest
from stix2 import Malware, markings from stix2 import TLP_RED, Malware, markings
from .constants import MALWARE_MORE_KWARGS as MALWARE_KWARGS_CONST from .constants import MALWARE_MORE_KWARGS as MALWARE_KWARGS_CONST
from .constants import MARKING_IDS from .constants import MARKING_IDS
@ -45,6 +45,7 @@ def test_add_marking_mark_one_selector_multiple_refs():
}, },
], ],
**MALWARE_KWARGS), **MALWARE_KWARGS),
MARKING_IDS[0],
), ),
( (
MALWARE_KWARGS, MALWARE_KWARGS,
@ -56,13 +57,26 @@ def test_add_marking_mark_one_selector_multiple_refs():
}, },
], ],
**MALWARE_KWARGS), **MALWARE_KWARGS),
MARKING_IDS[0],
),
(
Malware(**MALWARE_KWARGS),
Malware(
granular_markings=[
{
"selectors": ["description", "name"],
"marking_ref": TLP_RED.id,
},
],
**MALWARE_KWARGS),
TLP_RED,
), ),
]) ])
def test_add_marking_mark_multiple_selector_one_refs(data): def test_add_marking_mark_multiple_selector_one_refs(data):
before = data[0] before = data[0]
after = data[1] after = data[1]
before = markings.add_markings(before, [MARKING_IDS[0]], ["description", "name"]) before = markings.add_markings(before, data[2], ["description", "name"])
for m in before["granular_markings"]: for m in before["granular_markings"]:
assert m in after["granular_markings"] assert m in after["granular_markings"]
@ -347,36 +361,42 @@ def test_get_markings_positional_arguments_combinations(data):
assert set(markings.get_markings(data, "x.z.foo2", False, True)) == set(["10"]) assert set(markings.get_markings(data, "x.z.foo2", False, True)) == set(["10"])
@pytest.mark.parametrize("before", [ @pytest.mark.parametrize("data", [
Malware( (
granular_markings=[ Malware(
{ granular_markings=[
"selectors": ["description"], {
"marking_ref": MARKING_IDS[0] "selectors": ["description"],
}, "marking_ref": MARKING_IDS[0]
{ },
"selectors": ["description"], {
"marking_ref": MARKING_IDS[1] "selectors": ["description"],
}, "marking_ref": MARKING_IDS[1]
], },
**MALWARE_KWARGS ],
**MALWARE_KWARGS
),
[MARKING_IDS[0], MARKING_IDS[1]],
), ),
dict( (
granular_markings=[ dict(
{ granular_markings=[
"selectors": ["description"], {
"marking_ref": MARKING_IDS[0] "selectors": ["description"],
}, "marking_ref": MARKING_IDS[0]
{ },
"selectors": ["description"], {
"marking_ref": MARKING_IDS[1] "selectors": ["description"],
}, "marking_ref": MARKING_IDS[1]
], },
**MALWARE_KWARGS ],
**MALWARE_KWARGS
),
[MARKING_IDS[0], MARKING_IDS[1]],
), ),
]) ])
def test_remove_marking_remove_one_selector_with_multiple_refs(before): def test_remove_marking_remove_one_selector_with_multiple_refs(data):
before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"]) before = markings.remove_markings(data[0], data[1], ["description"])
assert "granular_markings" not in before assert "granular_markings" not in before

View File

@ -241,4 +241,14 @@ def test_marking_wrong_type_construction():
assert str(excinfo.value) == "Must supply a list, containing tuples. For example, [('property1', IntegerProperty())]" assert str(excinfo.value) == "Must supply a list, containing tuples. For example, [('property1', IntegerProperty())]"
# TODO: Add other examples def test_campaign_add_markings():
campaign = stix2.Campaign(
id="campaign--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f",
created_by_ref="identity--f431f809-377b-45e0-aa1c-6a4751cae5ff",
created="2016-04-06T20:03:00Z",
modified="2016-04-06T20:03:00Z",
name="Green Group Attacks Against Finance",
description="Campaign by Green Group against a series of targets in the financial services sector.",
)
campaign = campaign.add_markings(TLP_WHITE)
assert campaign.object_marking_refs[0] == TLP_WHITE.id

View File

@ -1,7 +1,7 @@
import pytest import pytest
from stix2 import Malware, exceptions, markings from stix2 import TLP_AMBER, Malware, exceptions, markings
from .constants import FAKE_TIME, MALWARE_ID, MARKING_IDS from .constants import FAKE_TIME, MALWARE_ID, MARKING_IDS
from .constants import MALWARE_KWARGS as MALWARE_KWARGS_CONST from .constants import MALWARE_KWARGS as MALWARE_KWARGS_CONST
@ -21,18 +21,26 @@ MALWARE_KWARGS.update({
Malware(**MALWARE_KWARGS), Malware(**MALWARE_KWARGS),
Malware(object_marking_refs=[MARKING_IDS[0]], Malware(object_marking_refs=[MARKING_IDS[0]],
**MALWARE_KWARGS), **MALWARE_KWARGS),
MARKING_IDS[0],
), ),
( (
MALWARE_KWARGS, MALWARE_KWARGS,
dict(object_marking_refs=[MARKING_IDS[0]], dict(object_marking_refs=[MARKING_IDS[0]],
**MALWARE_KWARGS), **MALWARE_KWARGS),
MARKING_IDS[0],
),
(
Malware(**MALWARE_KWARGS),
Malware(object_marking_refs=[TLP_AMBER.id],
**MALWARE_KWARGS),
TLP_AMBER,
), ),
]) ])
def test_add_markings_one_marking(data): def test_add_markings_one_marking(data):
before = data[0] before = data[0]
after = data[1] after = data[1]
before = markings.add_markings(before, MARKING_IDS[0], None) before = markings.add_markings(before, data[2], None)
for m in before["object_marking_refs"]: for m in before["object_marking_refs"]:
assert m in after["object_marking_refs"] assert m in after["object_marking_refs"]
@ -280,19 +288,28 @@ def test_remove_markings_object_level(data):
**MALWARE_KWARGS), **MALWARE_KWARGS),
Malware(object_marking_refs=[MARKING_IDS[1]], Malware(object_marking_refs=[MARKING_IDS[1]],
**MALWARE_KWARGS), **MALWARE_KWARGS),
[MARKING_IDS[0], MARKING_IDS[2]],
), ),
( (
dict(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], dict(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
**MALWARE_KWARGS), **MALWARE_KWARGS),
dict(object_marking_refs=[MARKING_IDS[1]], dict(object_marking_refs=[MARKING_IDS[1]],
**MALWARE_KWARGS), **MALWARE_KWARGS),
[MARKING_IDS[0], MARKING_IDS[2]],
),
(
Malware(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], TLP_AMBER.id],
**MALWARE_KWARGS),
Malware(object_marking_refs=[MARKING_IDS[1]],
**MALWARE_KWARGS),
[MARKING_IDS[0], TLP_AMBER],
), ),
]) ])
def test_remove_markings_multiple(data): def test_remove_markings_multiple(data):
before = data[0] before = data[0]
after = data[1] after = data[1]
before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[2]], None) before = markings.remove_markings(before, data[2], None)
assert before['object_marking_refs'] == after['object_marking_refs'] assert before['object_marking_refs'] == after['object_marking_refs']

View File

@ -33,6 +33,34 @@ class STIXdatetime(dt.datetime):
return "'%s'" % format_datetime(self) return "'%s'" % format_datetime(self)
def deduplicate(stix_obj_list):
"""Deduplicate a list of STIX objects to a unique set
Reduces a set of STIX objects to unique set by looking
at 'id' and 'modified' fields - as a unique object version
is determined by the combination of those fields
Note: Be aware, as can be seen in the implementation
of deduplicate(),that if the "stix_obj_list" argument has
multiple STIX objects of the same version, the last object
version found in the list will be the one that is returned.
()
Args:
stix_obj_list (list): list of STIX objects (dicts)
Returns:
A list with a unique set of the passed list of STIX objects.
"""
unique_objs = {}
for obj in stix_obj_list:
unique_objs[(obj['id'], obj['modified'])] = obj
return list(unique_objs.values())
def get_timestamp(): def get_timestamp():
return STIXdatetime.now(tz=pytz.UTC) return STIXdatetime.now(tz=pytz.UTC)