diff --git a/stix2/__init__.py b/stix2/__init__.py index 7be0904..53c2fb1 100644 --- a/stix2/__init__.py +++ b/stix2/__init__.py @@ -8,6 +8,8 @@ from .common import (TLP_AMBER, TLP_GREEN, TLP_RED, TLP_WHITE, CustomMarking, MarkingDefinition, StatementMarking, TLPMarking) from .core import Bundle, _register_type, parse from .environment import Environment, ObjectFactory +from .markings import (add_markings, clear_markings, get_markings, is_marked, + remove_markings, set_markings) from .observables import (URL, AlternateDataStream, ArchiveExt, Artifact, AutonomousSystem, CustomObservable, Directory, DomainName, EmailAddress, EmailMessage, diff --git a/stix2/common.py b/stix2/common.py index a2e6918..d7994c6 100644 --- a/stix2/common.py +++ b/stix2/common.py @@ -3,6 +3,7 @@ from collections import OrderedDict from .base import _STIXBase +from .markings import MarkingsMixin from .properties import (HashesProperty, IDProperty, ListProperty, Property, ReferenceProperty, SelectorProperty, StringProperty, TimestampProperty, TypeProperty) @@ -76,7 +77,7 @@ class MarkingProperty(Property): raise ValueError("must be a Statement, TLP Marking or a registered marking.") -class MarkingDefinition(_STIXBase): +class MarkingDefinition(_STIXBase, MarkingsMixin): _type = 'marking-definition' _properties = OrderedDict() _properties.update([ diff --git a/stix2/environment.py b/stix2/environment.py index 8e24c9b..c4816ee 100644 --- a/stix2/environment.py +++ b/stix2/environment.py @@ -1,7 +1,7 @@ import copy from .core import parse as _parse -from .sources import CompositeDataSource, DataSource, DataStore +from .sources import CompositeDataSource, DataStore class ObjectFactory(object): @@ -132,17 +132,15 @@ class Environment(object): def add_filters(self, *args, **kwargs): try: - return self.source.add_filters(*args, **kwargs) + return self.source.filters.update(*args, **kwargs) except AttributeError: raise AttributeError('Environment has no data source') - add_filters.__doc__ = 
DataSource.add_filters.__doc__ def add_filter(self, *args, **kwargs): try: - return self.source.add_filter(*args, **kwargs) + return self.source.filters.add(*args, **kwargs) except AttributeError: raise AttributeError('Environment has no data source') - add_filter.__doc__ = DataSource.add_filter.__doc__ def add(self, *args, **kwargs): try: diff --git a/stix2/markings/__init__.py b/stix2/markings/__init__.py index 4f72e4c..41c761d 100644 --- a/stix2/markings/__init__.py +++ b/stix2/markings/__init__.py @@ -212,3 +212,16 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa result = result or object_markings.is_marked(obj, object_marks) return result + + +class MarkingsMixin(): + pass + + +# Note that all of these methods will return a new object because of immutability +MarkingsMixin.get_markings = get_markings +MarkingsMixin.set_markings = set_markings +MarkingsMixin.remove_markings = remove_markings +MarkingsMixin.add_markings = add_markings +MarkingsMixin.clear_markings = clear_markings +MarkingsMixin.is_marked = is_marked diff --git a/stix2/markings/granular_markings.py b/stix2/markings/granular_markings.py index 7e9ccc7..5afd1cc 100644 --- a/stix2/markings/granular_markings.py +++ b/stix2/markings/granular_markings.py @@ -88,6 +88,7 @@ def remove_markings(obj, marking, selectors): """ selectors = utils.convert_to_list(selectors) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) granular_markings = obj.get("granular_markings") @@ -97,12 +98,9 @@ def remove_markings(obj, marking, selectors): granular_markings = utils.expand_markings(granular_markings) - if isinstance(marking, list): - to_remove = [] - for m in marking: - to_remove.append({"marking_ref": m, "selectors": selectors}) - else: - to_remove = [{"marking_ref": marking, "selectors": selectors}] + to_remove = [] + for m in marking: + to_remove.append({"marking_ref": m, "selectors": selectors}) remove = 
utils.build_granular_marking(to_remove).get("granular_markings") @@ -140,14 +138,12 @@ def add_markings(obj, marking, selectors): """ selectors = utils.convert_to_list(selectors) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) - if isinstance(marking, list): - granular_marking = [] - for m in marking: - granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)}) - else: - granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}] + granular_marking = [] + for m in marking: + granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)}) if obj.get("granular_markings"): granular_marking.extend(obj.get("granular_markings")) @@ -244,7 +240,7 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa raise TypeError("Required argument 'selectors' must be provided") selectors = utils.convert_to_list(selectors) - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) granular_markings = obj.get("granular_markings", []) diff --git a/stix2/markings/object_markings.py b/stix2/markings/object_markings.py index c39c036..a775ddc 100644 --- a/stix2/markings/object_markings.py +++ b/stix2/markings/object_markings.py @@ -31,7 +31,7 @@ def add_markings(obj, marking): A new version of the given SDO or SRO with specified markings added. """ - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = set(obj.get("object_marking_refs", []) + marking) @@ -55,7 +55,7 @@ def remove_markings(obj, marking): A new version of the given SDO or SRO with specified markings removed. """ - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = obj.get("object_marking_refs", []) @@ -121,7 +121,7 @@ def is_marked(obj, marking=None): provided marking refs match, True is returned. 
""" - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = obj.get("object_marking_refs", []) if marking: diff --git a/stix2/markings/utils.py b/stix2/markings/utils.py index d0d38bb..1154d19 100644 --- a/stix2/markings/utils.py +++ b/stix2/markings/utils.py @@ -37,6 +37,12 @@ def _validate_selector(obj, selector): return True +def _get_marking_id(marking): + if type(marking).__name__ is 'MarkingDefinition': # avoid circular import + return marking.id + return marking + + def validate(obj, selectors): """Given an SDO or SRO, check that each selector is valid.""" if selectors: @@ -57,6 +63,15 @@ def convert_to_list(data): return [data] +def convert_to_marking_list(data): + """Convert input into a list of marking identifiers.""" + if data is not None: + if isinstance(data, list): + return [_get_marking_id(x) for x in data] + else: + return [_get_marking_id(data)] + + def compress_markings(granular_markings): """ Compress granular markings list. 
If there is more than one marking diff --git a/stix2/sdo.py b/stix2/sdo.py index 77c781a..53f965d 100644 --- a/stix2/sdo.py +++ b/stix2/sdo.py @@ -6,6 +6,7 @@ import stix2 from .base import _STIXBase from .common import ExternalReference, GranularMarking, KillChainPhase +from .markings import MarkingsMixin from .observables import ObservableProperty from .properties import (BooleanProperty, IDProperty, IntegerProperty, ListProperty, PatternProperty, ReferenceProperty, @@ -13,7 +14,11 @@ from .properties import (BooleanProperty, IDProperty, IntegerProperty, from .utils import NOW -class AttackPattern(_STIXBase): +class STIXDomainObject(_STIXBase, MarkingsMixin): + pass + + +class AttackPattern(STIXDomainObject): _type = 'attack-pattern' _properties = OrderedDict() @@ -34,7 +39,7 @@ class AttackPattern(_STIXBase): ]) -class Campaign(_STIXBase): +class Campaign(STIXDomainObject): _type = 'campaign' _properties = OrderedDict() @@ -58,7 +63,7 @@ class Campaign(_STIXBase): ]) -class CourseOfAction(_STIXBase): +class CourseOfAction(STIXDomainObject): _type = 'course-of-action' _properties = OrderedDict() @@ -78,7 +83,7 @@ class CourseOfAction(_STIXBase): ]) -class Identity(_STIXBase): +class Identity(STIXDomainObject): _type = 'identity' _properties = OrderedDict() @@ -101,7 +106,7 @@ class Identity(_STIXBase): ]) -class Indicator(_STIXBase): +class Indicator(STIXDomainObject): _type = 'indicator' _properties = OrderedDict() @@ -125,7 +130,7 @@ class Indicator(_STIXBase): ]) -class IntrusionSet(_STIXBase): +class IntrusionSet(STIXDomainObject): _type = 'intrusion-set' _properties = OrderedDict() @@ -152,7 +157,7 @@ class IntrusionSet(_STIXBase): ]) -class Malware(_STIXBase): +class Malware(STIXDomainObject): _type = 'malware' _properties = OrderedDict() @@ -173,7 +178,7 @@ class Malware(_STIXBase): ]) -class ObservedData(_STIXBase): +class ObservedData(STIXDomainObject): _type = 'observed-data' _properties = OrderedDict() @@ -195,7 +200,7 @@ class ObservedData(_STIXBase): 
]) -class Report(_STIXBase): +class Report(STIXDomainObject): _type = 'report' _properties = OrderedDict() @@ -217,7 +222,7 @@ class Report(_STIXBase): ]) -class ThreatActor(_STIXBase): +class ThreatActor(STIXDomainObject): _type = 'threat-actor' _properties = OrderedDict() @@ -245,7 +250,7 @@ class ThreatActor(_STIXBase): ]) -class Tool(_STIXBase): +class Tool(STIXDomainObject): _type = 'tool' _properties = OrderedDict() @@ -267,7 +272,7 @@ class Tool(_STIXBase): ]) -class Vulnerability(_STIXBase): +class Vulnerability(STIXDomainObject): _type = 'vulnerability' _properties = OrderedDict() @@ -316,7 +321,7 @@ def CustomObject(type='x-custom-type', properties=None): def custom_builder(cls): - class _Custom(cls, _STIXBase): + class _Custom(cls, STIXDomainObject): _type = type _properties = OrderedDict() _properties.update([ diff --git a/stix2/sources/__init__.py b/stix2/sources/__init__.py index cb6e5b5..6fcd17b 100644 --- a/stix2/sources/__init__.py +++ b/stix2/sources/__init__.py @@ -7,21 +7,12 @@ Classes: DataSource CompositeDataSource -TODO:Test everything - -Notes: - add_filter(), remove_filter(), deduplicate() - if these functions remain - the exact same for DataSource, DataSink, CompositeDataSource etc... -> just - make those functions an interface to inherit? """ import uuid -from six import iteritems - -from stix2.sources.filters import (FILTER_OPS, FILTER_VALUE_TYPES, - STIX_COMMON_FIELDS, STIX_COMMON_FILTERS_MAP) +from stix2.utils import deduplicate def make_id(): @@ -29,13 +20,21 @@ def make_id(): class DataStore(object): - """ - An implementer will create a concrete subclass from - this abstract class for the specific data store. + """An implementer will create a concrete subclass from + this class for the specific DataStore. 
+ + Args: + source (DataSource): An existing DataSource to use + as this DataStore's DataSource component + + sink (DataSink): An existing DataSink to use + as this DataStore's DataSink component Attributes: id (str): A unique UUIDv4 to identify this DataStore. - source (DataStore): An object that implements DataStore class. + + source (DataSource): An object that implements DataSource class. + sink (DataSink): An object that implements DataSink class. """ @@ -47,14 +46,13 @@ class DataStore(object): def get(self, stix_id): """Retrieve the most recent version of a single STIX object by ID. - Notes: - Translate API get() call to the appropriate DataSource call. + Translate get() call to the appropriate DataSource call. Args: - stix_id (str): the id of the STIX 2.0 object to retrieve. + stix_id (str): the id of the STIX object to retrieve. Returns: - stix_obj (dictionary): the single most recent version of the STIX + stix_obj: the single most recent version of the STIX object specified by the "id". """ @@ -63,15 +61,13 @@ class DataStore(object): def all_versions(self, stix_id): """Retrieve all versions of a single STIX object by ID. - Implement: - Translate all_versions() call to the appropriate DataSource call + Implement: Translate all_versions() call to the appropriate DataSource call Args: - stix_id (str): the id of the STIX 2.0 object to retrieve. + stix_id (str): the id of the STIX object to retrieve. Returns: - stix_objs (list): a list of STIX objects (where each object is a - STIX object) + stix_objs (list): a list of STIX objects """ return self.source.all_versions(stix_id) @@ -79,17 +75,15 @@ class DataStore(object): def query(self, query): """Retrieve STIX objects matching a set of filters. - Notes: - Implement the specific data source API calls, processing, - functionality required for retrieving query from the data source. + Implement: Specific data source API calls, processing, + functionality required for retrieving query from the data source. 
Args: query (list): a list of filters (which collectively are the query) to conduct search on. Returns: - stix_objs (list): a list of STIX objects (where each object is a - STIX object) + stix_objs (list): a list of STIX objects """ return self.source.query(query=query) @@ -97,21 +91,17 @@ class DataStore(object): def add(self, stix_objs): """Store STIX objects. - Notes: - Translate add() to the appropriate DataSink call(). + Translates add() to the appropriate DataSink call. Args: - stix_objs (list): a list of STIX objects (where each object is a - STIX object) - + stix_objs (list): a list of STIX objects """ return self.sink.add(stix_objs) class DataSink(object): - """ - Abstract class for defining a data sink. Intended for subclassing into - different sink components. + """An implementer will create a concrete subclass from + this class for the specific DataSink. Attributes: id (str): A unique UUIDv4 to identify this DataSink. @@ -123,9 +113,8 @@ class DataSink(object): def add(self, stix_objs): """Store STIX objects. - Notes: - Implement the specific data sink API calls, processing, - functionality required for adding data to the sink + Implement: Specific data sink API calls, processing, + functionality required for adding data to the sink Args: stix_objs (list): a list of STIX objects (where each object is a @@ -136,13 +125,13 @@ class DataSink(object): class DataSource(object): - """ - Abstract class for defining a data source. Intended for subclassing into - different source components. + """An implementer will create a concrete subclass from + this class for the specific DataSource. Attributes: id (str): A unique UUIDv4 to identify this DataSource. - filters (set): A collection of filters present in this DataSource. + + _filters (set): A collection of filters attached to this DataSource. 
""" def __init__(self): @@ -151,179 +140,76 @@ class DataSource(object): def get(self, stix_id, _composite_filters=None): """ - Fill: - Implement the specific data source API calls, processing, - functionality required for retrieving data from the data source + Implement: Specific data source API calls, processing, + functionality required for retrieving data from the data source Args: stix_id (str): the id of the STIX 2.0 object to retrieve. Should return a single object, the most recent version of the object specified by the "id". - _composite_filters (list): list of filters passed along from - the Composite Data Filter. + _composite_filters (set): set of filters passed from the parent + the CompositeDataSource, not user supplied Returns: - stix_obj (dictionary): the STIX object to be returned + stix_obj: the STIX object """ raise NotImplementedError() def all_versions(self, stix_id, _composite_filters=None): """ - Notes: - Similar to get() except returns list of all object versions of - the specified "id". In addition, implement the specific data - source API calls, processing, functionality required for retrieving - data from the data source. + Implement: Similar to get() except returns list of all object versions of + the specified "id". In addition, implement the specific data + source API calls, processing, functionality required for retrieving + data from the data source. Args: stix_id (str): The id of the STIX 2.0 object to retrieve. Should return a list of objects, all the versions of the object specified by the "id". 
- _composite_filters (list): list of filters passed from the - Composite Data Source + _composite_filters (set): set of filters passed from the parent + CompositeDataSource, not user supplied Returns: - stix_objs (list): a list of STIX objects (where each object is a - STIX object) + stix_objs (list): a list of STIX objects """ raise NotImplementedError() def query(self, query, _composite_filters=None): """ - Fill: - -implement the specific data source API calls, processing, - functionality required for retrieving query from the data source + Implement:Implement the specific data source API calls, processing, + functionality required for retrieving query from the data source Args: query (list): a list of filters (which collectively are the query) to conduct search on - _composite_filters (list): a list of filters passed from the - Composite Data Source + _composite_filters (set): a set of filters passed from the parent + CompositeDataSource, not user supplied Returns: + stix_objs (list): a list of STIX objects """ raise NotImplementedError() - def add_filters(self, filters): - """Add multiple filters to be applied to all queries for STIX objects. - - Args: - filters (list): list of filters (dict) to add to the Data Source. - - """ - for filter in filters: - self.add_filter(filter) - - def add_filter(self, filter): - """Add a filter to be applied to all queries for STIX objects. - - Args: - filter: filter to add to the Data Source. - - """ - # check filter field is a supported STIX 2.0 common field - if filter.field not in STIX_COMMON_FIELDS: - raise ValueError("Filter 'field' is not a STIX 2.0 common property. Currently only STIX object common properties supported") - - # check filter operator is supported - if filter.op not in FILTER_OPS: - raise ValueError("Filter operation (from 'op' field) not supported") - - # check filter value type is supported - if type(filter.value) not in FILTER_VALUE_TYPES: - raise ValueError("Filter 'value' type is not supported. 
The type(value) must be python immutable type or dictionary") - - self.filters.add(filter) - - def apply_common_filters(self, stix_objs, query): - """Evaluate filters against a set of STIX 2.0 objects. - - Supports only STIX 2.0 common property fields - - Args: - stix_objs (list): list of STIX objects to apply the query to - query (list): list of filters (combined form complete query) - - Returns: - (list): list of STIX objects that successfully evaluate against - the query. - - """ - filtered_stix_objs = [] - - # evaluate objects against filter - for stix_obj in stix_objs: - clean = True - for filter_ in query: - # skip filter as filter was identified (when added) as - # not a common filter - if filter_.field not in STIX_COMMON_FIELDS: - raise ValueError("Error, field: {0} is not supported for filtering on.".format(filter_.field)) - - # For properties like granular_markings and external_references - # need to break the first property from the string. - if "." in filter_.field: - field = filter_.field.split(".")[0] - else: - field = filter_.field - - # check filter "field" is in STIX object - if cant be - # applied due to STIX object, STIX object is discarded - # (i.e. 
did not make it through the filter) - if field not in stix_obj.keys(): - clean = False - break - - match = STIX_COMMON_FILTERS_MAP[filter_.field.split('.')[0]](filter_, stix_obj) - if not match: - clean = False - break - elif match == -1: - raise ValueError("Error, filter operator: {0} not supported for specified field: {1}".format(filter_.op, filter_.field)) - - # if object unmarked after all filters, add it - if clean: - filtered_stix_objs.append(stix_obj) - - return filtered_stix_objs - - def deduplicate(self, stix_obj_list): - """Deduplicate a list of STIX objects to a unique set - - Reduces a set of STIX objects to unique set by looking - at 'id' and 'modified' fields - as a unique object version - is determined by the combination of those fields - - Args: - stix_obj_list (list): list of STIX objects (dicts) - - Returns: - A list with a unique set of the passed list of STIX objects. - - """ - unique_objs = {} - - for obj in stix_obj_list: - unique_objs[(obj['id'], obj['modified'])] = obj - - return list(unique_objs.values()) - class CompositeDataSource(DataSource): - """Controller for all the defined/configured STIX Data Sources. + """Controller for all the attached DataSources. - E.g. a user can define n Data Sources - creating Data Source (objects) - for each. There is only one instance of this for any Python STIX 2.0 - application. + A user can have a single CompositeDataSource as an interface + to a set of DataSources. When an API call is made to the + CompositeDataSource, it is delegated to each of the (real) + DataSources that are attached to it. + + DataSources can be attached to CompositeDataSource for a variety + of reasons, e.g. common filters, organization, fewer API calls. Attributes: - name (str): The name that identifies this CompositeDataSource. + data_sources (list): A list of DataSource objects; to be controlled and used by the Data Source Controller object. 
@@ -332,49 +218,52 @@ class CompositeDataSource(DataSource): """Create a new STIX Data Source. Args: - name (str): A string containing the name to attach in the - CompositeDataSource instance. """ super(CompositeDataSource, self).__init__() - self.data_sources = {} + self.data_sources = [] def get(self, stix_id, _composite_filters=None): - """Retrieve STIX object by 'id' + """Retrieve STIX object by STIX ID - Federated retrieve method-iterates through all STIX data sources + Federated retrieve method, iterates through all DataSources defined in the "data_sources" parameter. Each data source has a specific API retrieve-like function and associated parameters. This function does a federated retrieval and consolidation of the data returned from all the STIX data sources. - Notes: - A composite data source will pass its attached filters to - each configured data source, pushing filtering to them to handle. + A composite data source will pass its attached filters to + each configured data source, pushing filtering to them to handle. Args: stix_id (str): the id of the STIX object to retrieve. - _composite_filters (list): a list of filters passed from the - Composite Data Source + _composite_filters (list): a list of filters passed from a + CompositeDataSource (i.e. if this CompositeDataSource is attached + to another parent CompositeDataSource), not user supplied Returns: - stix_obj (dict): the STIX object to be returned. + stix_obj: the STIX object to be returned. 
""" - if not self.get_all_data_sources(): + if not self.has_data_sources(): raise AttributeError('CompositeDataSource has no data sources') all_data = [] + all_filters = set() + all_filters.update(self.filters) + + if _composite_filters: + all_filters.update(_composite_filters) # for every configured Data Source, call its retrieve handler - for ds_id, ds in iteritems(self.data_sources): - data = ds.get(stix_id=stix_id, _composite_filters=list(self.filters)) + for ds in self.data_sources: + data = ds.get(stix_id=stix_id, _composite_filters=all_filters) all_data.append(data) # remove duplicate versions if len(all_data) > 0: - all_data = self.deduplicate(all_data) + all_data = deduplicate(all_data) # reduce to most recent version stix_obj = sorted(all_data, key=lambda k: k['modified'], reverse=True)[0] @@ -382,128 +271,149 @@ class CompositeDataSource(DataSource): return stix_obj def all_versions(self, stix_id, _composite_filters=None): - """Retrieve STIX objects by 'id' + """Retrieve STIX objects by STIX ID - Federated all_versions retrieve method - iterates through all STIX data - sources defined in "data_sources" + Federated all_versions retrieve method - iterates through all DataSources + defined in "data_sources" - Notes: - A composite data source will pass its attached filters to - each configured data source, pushing filtering to them to handle + A composite data source will pass its attached filters to + each configured data source, pushing filtering to them to handle Args: stix_id (str): id of the STIX objects to retrieve - _composite_filters (list): a list of filters passed from the - Composite Data Source + _composite_filters (list): a list of filters passed from a + CompositeDataSource (i.e. 
if this CompositeDataSource is attached + to a parent CompositeDataSource), not user supplied Returns: all_data (list): list of STIX objects that have the specified id """ - if not self.get_all_data_sources(): + if not self.has_data_sources(): raise AttributeError('CompositeDataSource has no data sources') all_data = [] - all_filters = self.filters + all_filters = set() + + all_filters.update(self.filters) if _composite_filters: - all_filters = set(self.filters).update(_composite_filters) + all_filters.update(_composite_filters) # retrieve STIX objects from all configured data sources - for ds_id, ds in iteritems(self.data_sources): - data = ds.all_versions(stix_id=stix_id, _composite_filters=list(all_filters)) + for ds in self.data_sources: + data = ds.all_versions(stix_id=stix_id, _composite_filters=all_filters) all_data.extend(data) # remove exact duplicates (where duplicates are STIX 2.0 objects # with the same 'id' and 'modified' values) if len(all_data) > 0: - all_data = self.deduplicate(all_data) + all_data = deduplicate(all_data) return all_data def query(self, query=None, _composite_filters=None): - """Federate the query to all Data Sources attached to the + """Retrieve STIX objects that match query + + Federate the query to all DataSources attached to the Composite Data Source. Args: - query (list): list of filters to search on. + query (list): list of filters to search on - _composite_filters (list): a list of filters passed from the - Composite Data Source + _composite_filters (list): a list of filters passed from a + CompositeDataSource (i.e. if this CompositeDataSource is attached + to a parent CompositeDataSource), not user supplied Returns: all_data (list): list of STIX objects to be returned """ - if not self.get_all_data_sources(): + if not self.has_data_sources(): raise AttributeError('CompositeDataSource has no data sources') if not query: + # dont mess with the query (i.e. 
convert to a set, as thats done + # within the specific DataSources that are called) query = [] all_data = [] - all_filters = self.filters + + all_filters = set() + all_filters.update(self.filters) if _composite_filters: - all_filters = set(self.filters).update(_composite_filters) + all_filters.update(_composite_filters) # federate query to all attached data sources, # pass composite filters to id - for ds_id, ds in iteritems(self.data_sources): - data = ds.query(query=query, _composite_filters=list(all_filters)) + for ds in self.data_sources: + data = ds.query(query=query, _composite_filters=all_filters) all_data.extend(data) # remove exact duplicates (where duplicates are STIX 2.0 # objects with the same 'id' and 'modified' values) if len(all_data) > 0: - all_data = self.deduplicate(all_data) + all_data = deduplicate(all_data) return all_data - def add_data_source(self, data_sources): - """Add/attach Data Source to the Composite Data Source instance + def add_data_source(self, data_source): + """Attach a DataSource to CompositeDataSource instance Args: - data_sources (list): a list of Data Source objects to attach - to the Composite Data Source + data_source (DataSource): a stix2.DataSource to attach + to the CompositeDataSource """ - if not isinstance(data_sources, list): - data_sources = [data_sources] + if issubclass(data_source.__class__, DataSource): + if data_source.id not in [ds_.id for ds_ in self.data_sources]: + # check DataSource not already attached CompositeDataSource + self.data_sources.append(data_source) + else: + raise TypeError("DataSource (to be added) is not of type stix2.DataSource. 
DataSource type is '%s'" % type(data_source)) + + return + + def add_data_sources(self, data_sources): + """Attach list of DataSources to CompositeDataSource instance + + Args: + data_sources (list): stix2.DataSources to attach to + CompositeDataSource + """ for ds in data_sources: - if issubclass(ds.__class__, DataSource): - if ds.id in self.data_sources: - # data source already attached to Composite Data Source - continue - - # add data source to Composite Data Source - # (its id will be its key identifier) - self.data_sources[ds.id] = ds - else: - # the Data Source object is not a proper subclass - # of DataSource Abstract Class - # TODO: maybe log error? - continue - + self.add_data_source(ds) return - def remove_data_source(self, data_source_ids): - """Remove/detach Data Source from the Composite Data Source instance + def remove_data_source(self, data_source_id): + """Remove DataSource from the CompositeDataSource instance Args: - data_source_ids (list): a list of Data Source identifiers. + data_source_id (str): DataSource IDs. 
""" - for id in data_source_ids: - if id in self.data_sources: - del self.data_sources[id] - else: - raise ValueError("DataSource 'id' not found in CompositeDataSource collection.") + def _match(ds_id, candidate_ds_id): + return ds_id == candidate_ds_id + + self.data_sources[:] = [ds for ds in self.data_sources if not _match(ds.id, data_source_id)] + return + def remove_data_sources(self, data_source_ids): + """Remove DataSources from the CompositeDataSource instance + + Args: + data_source_ids (list): DataSource IDs + + """ + for ds_id in data_source_ids: + self.remove_data_source(ds_id) + return + + def has_data_sources(self): + return len(self.data_sources) + def get_all_data_sources(self): - """Return all attached Data Sources - - """ - return self.data_sources.values() + return self.data_sources diff --git a/stix2/sources/filesystem.py b/stix2/sources/filesystem.py index ade36c8..c45f281 100644 --- a/stix2/sources/filesystem.py +++ b/stix2/sources/filesystem.py @@ -12,71 +12,148 @@ TODO: Test everything import json import os -from stix2 import Bundle +from stix2.base import _STIXBase +from stix2.core import Bundle, parse from stix2.sources import DataSink, DataSource, DataStore -from stix2.sources.filters import Filter +from stix2.sources.filters import Filter, apply_common_filters +from stix2.utils import deduplicate class FileSystemStore(DataStore): + """FileSystemStore + + Provides an interface to a file directory of STIX objects. + FileSystemStore is a wrapper around a paired FileSystemSink + and FileSystemSource. 
+ + Args: + stix_dir (str): path to directory of STIX objects + + Attributes: + source (FileSystemSource): FileSystemSource + + sink (FileSystemSink): FileSystemSink + """ - """ - def __init__(self, stix_dir="stix_data"): + def __init__(self, stix_dir): super(FileSystemStore, self).__init__() self.source = FileSystemSource(stix_dir=stix_dir) self.sink = FileSystemSink(stix_dir=stix_dir) class FileSystemSink(DataSink): - """ - """ - def __init__(self, stix_dir="stix_data"): - super(FileSystemSink, self).__init__() - self.stix_dir = os.path.abspath(stix_dir) + """FileSystemSink - # check directory path exists - if not os.path.exists(self.stix_dir): - print("Error: directory path for STIX data does not exist") + Provides an interface for adding/pushing STIX objects + to file directory of STIX objects. + + Can be paired with a FileSystemSource, together as the two + components of a FileSystemStore. + + Args: + stix_dir (str): path to directory of STIX objects + + """ + def __init__(self, stix_dir): + super(FileSystemSink, self).__init__() + self._stix_dir = os.path.abspath(stix_dir) + + if not os.path.exists(self._stix_dir): + raise ValueError("directory path for STIX data does not exist") @property def stix_dir(self): - return self.stix_dir + return self._stix_dir - @stix_dir.setter - def stix_dir(self, dir): - self.stix_dir = dir + def add(self, stix_data=None): + """add STIX objects to file directory - def add(self, stix_objs=None): + Args: + stix_data (STIX object OR dict OR str OR list): valid STIX 2.0 content + in a STIX object(or list of), dict (or list of), or a STIX 2.0 + json encoded string + + TODO: Bundlify STIX content or no? When dumping to disk. """ - Q: bundlify or no? 
- """ - if not stix_objs: - stix_objs = [] - for stix_obj in stix_objs: - path = os.path.join(self.stix_dir, stix_obj["type"], stix_obj["id"]) - json.dump(Bundle([stix_obj]), open(path, 'w+'), indent=4) + def _check_path_and_write(stix_dir, stix_obj): + path = os.path.join(stix_dir, stix_obj["type"], stix_obj["id"] + ".json") + + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + + with open(path, "w") as f: + # Bundle() can take dict or STIX obj as argument + f.write(str(Bundle(stix_obj))) + + if isinstance(stix_data, _STIXBase): + # adding python STIX object + _check_path_and_write(self._stix_dir, stix_data) + + elif isinstance(stix_data, dict): + if stix_data["type"] == "bundle": + # adding json-formatted Bundle - extracting STIX objects + for stix_obj in stix_data["objects"]: + self.add(stix_obj) + else: + # adding json-formatted STIX + _check_path_and_write(self._stix_dir, stix_data) + + elif isinstance(stix_data, str): + # adding json encoded string of STIX content + stix_data = parse(stix_data) + if stix_data["type"] == "bundle": + for stix_obj in stix_data: + self.add(stix_obj) + else: + self.add(stix_data) + + elif isinstance(stix_data, list): + # if list, recurse call on individual STIX objects + for stix_obj in stix_data: + self.add(stix_obj) + + else: + raise ValueError("stix_data must be a STIX object(or list of, json formatted STIX(or list of) or a json formatted STIX bundle") class FileSystemSource(DataSource): - """ - """ - def __init__(self, stix_dir="stix_data"): - super(FileSystemSource, self).__init__() - self.stix_dir = os.path.abspath(stix_dir) + """FileSystemSource - # check directory path exists - if not os.path.exists(self.stix_dir): + Provides an interface for searching/retrieving + STIX objects from a STIX object file directory. + + Can be paired with a FileSystemSink, together as the two + components of a FileSystemStore. 
+ + Args: + stix_dir (str): path to directory of STIX objects + + """ + def __init__(self, stix_dir): + super(FileSystemSource, self).__init__() + self._stix_dir = os.path.abspath(stix_dir) + + if not os.path.exists(self._stix_dir): print("Error: directory path for STIX data does not exist") @property def stix_dir(self): - return self.stix_dir - - @stix_dir.setter - def stix_dir(self, dir): - self.stix_dir = dir + return self._stix_dir def get(self, stix_id, _composite_filters=None): - """ + """retrieve STIX object from file directory via STIX ID + + Args: + stix_id (str): The STIX ID of the STIX object to be retrieved. + + composite_filters (set): set of filters passed from the parent + CompositeDataSource, not user supplied + + Returns: + (STIX object): STIX object that has the supplied STIX ID. + The STIX object is loaded from its json file, parsed into + a python STIX object and then returned + """ query = [Filter("id", "=", stix_id)] @@ -84,30 +161,63 @@ class FileSystemSource(DataSource): stix_obj = sorted(all_data, key=lambda k: k['modified'])[0] - return stix_obj + return parse(stix_obj) def all_versions(self, stix_id, _composite_filters=None): - """ - Notes: - Since FileSystem sources/sinks don't handle multiple versions - of a STIX object, this operation is unnecessary. Pass call to get(). + """retrieve STIX object from file directory via STIX ID, all versions + + Note: Since FileSystem sources/sinks don't handle multiple versions + of a STIX object, this operation is unnecessary. Pass call to get(). + + Args: + stix_id (str): The STIX ID of the STIX objects to be retrieved. + + composite_filters (set): set of filters passed from the parent + CompositeDataSource, not user supplied + + Returns: + (list): of STIX objects that has the supplied STIX ID. 
+ The STIX objects are loaded from their json files, parsed into + a python STIX objects and then returned """ return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)] def query(self, query=None, _composite_filters=None): - """ + """search and retrieve STIX objects based on the complete query + + A "complete query" includes the filters from the query, the filters + attached to MemorySource, and any filters passed from a + CompositeDataSource (i.e. _composite_filters) + + Args: + query (list): list of filters to search on + + composite_filters (set): set of filters passed from the + CompositeDataSource, not user supplied + + Returns: + (list): list of STIX objects that matches the supplied + query. The STIX objects are loaded from their json files, + parsed into a python STIX objects and then returned. + """ all_data = [] if query is None: - query = [] + query = set() + else: + if not isinstance(query, list): + # make sure dont make set from a Filter object, + # need to make a set from a list of Filter objects (even if just one Filter) + query = list(query) + query = set(query) # combine all query filters if self.filters: - query.extend(self.filters.values()) + query.update(self.filters) if _composite_filters: - query.extend(_composite_filters) + query.update(_composite_filters) # extract any filters that are for "type" or "id" , as we can then do # filtering before reading in the STIX objects. 
A STIX 'type' filter @@ -125,12 +235,12 @@ class FileSystemSource(DataSource): for filter in file_filters: if filter.field == "type": if filter.op == "=": - include_paths.append(os.path.join(self.stix_dir, filter.value)) + include_paths.append(os.path.join(self._stix_dir, filter.value)) elif filter.op == "!=": - declude_paths.append(os.path.join(self.stix_dir, filter.value)) + declude_paths.append(os.path.join(self._stix_dir, filter.value)) else: # have to walk entire STIX directory - include_paths.append(self.stix_dir) + include_paths.append(self._stix_dir) # if a user specifies a "type" filter like "type = ", # the filter is reducing the search space to single stix object types @@ -144,7 +254,7 @@ class FileSystemSource(DataSource): # user has specified types that are not wanted (i.e. "!=") # so query will look in all STIX directories that are not # the specified type. Compile correct dir paths - for dir in os.listdir(self.stix_dir): + for dir in os.listdir(self._stix_dir): if os.path.abspath(dir) not in declude_paths: include_paths.append(os.path.abspath(dir)) @@ -153,36 +263,50 @@ class FileSystemSource(DataSource): if "id" in [filter.field for filter in file_filters]: for filter in file_filters: if filter.field == "id" and filter.op == "=": - id = filter.value + id_ = filter.value break else: - id = None + id_ = None else: - id = None + id_ = None # now iterate through all STIX objs for path in include_paths: for root, dirs, files in os.walk(path): - for file in files: - if id: - if id == file.split(".")[0]: + for file_ in files: + if id_: + if id_ == file_.split(".")[0]: # since ID is specified in one of filters, can evaluate against filename first without loading - stix_obj = json.load(file)["objects"] + stix_obj = json.load(open(os.path.join(root, file_)))["objects"][0] # check against other filters, add if match - all_data.extend(self.apply_common_filters([stix_obj], query)) + matches = [stix_obj_ for stix_obj_ in apply_common_filters([stix_obj], query)] + 
all_data.extend(matches) else: # have to load into memory regardless to evaluate other filters - stix_obj = json.load(file)["objects"] - all_data.extend(self.apply_common_filters([stix_obj], query)) + stix_obj = json.load(open(os.path.join(root, file_)))["objects"][0] + matches = [stix_obj_ for stix_obj_ in apply_common_filters([stix_obj], query)] + all_data.extend(matches) - all_data = self.deduplicate(all_data) - return all_data + all_data = deduplicate(all_data) + + # parse python STIX objects from the STIX object dicts + stix_objs = [parse(stix_obj_dict) for stix_obj_dict in all_data] + + return stix_objs def _parse_file_filters(self, query): + """utility method to extract STIX common filters + that can used to possibly speed up querying STIX objects + from the file system + + Extracts filters that are for the "id" and "type" field of + a STIX object. As the file directory is organized by STIX + object type with filenames that are equivalent to the STIX + object ID, these filters can be used first to reduce the + search space of a FileSystemStore(or FileSystemSink) """ - """ - file_filters = [] - for filter in query: - if filter.field == "id" or filter.field == "type": - file_filters.append(filter) + file_filters = set() + for filter_ in query: + if filter_.field == "id" or filter_.field == "type": + file_filters.add(filter_) return file_filters diff --git a/stix2/sources/filters.py b/stix2/sources/filters.py index 7758369..a565006 100644 --- a/stix2/sources/filters.py +++ b/stix2/sources/filters.py @@ -4,10 +4,6 @@ Filters for Python STIX 2.0 DataSources, DataSinks, DataStores Classes: Filter -TODO: The script at the bottom of the module works (to capture -all the callable filter methods), however it causes this module -to be imported by itself twice. Not sure how big of deal that is, -or if cleaner solution possible. 
""" import collections @@ -15,6 +11,8 @@ import types # Currently, only STIX 2.0 common SDO fields (that are not complex objects) # are supported for filtering on + +"""Supported STIX properties""" STIX_COMMON_FIELDS = [ "created", "created_by_ref", @@ -30,32 +28,140 @@ STIX_COMMON_FIELDS = [ "modified", "object_marking_refs", "revoked", - "type", - "granular_markings" + "type" ] -# Supported filter operations +"""Supported filter operations""" FILTER_OPS = ['=', '!=', 'in', '>', '<', '>=', '<='] -# Supported filter value types +"""Supported filter value types""" FILTER_VALUE_TYPES = [bool, dict, float, int, list, str, tuple] # filter lookup map - STIX 2 common fields -> filter method STIX_COMMON_FILTERS_MAP = {} +def _check_filter_components(field, op, value): + """check filter meets minimum validity + + Note: Currently can create Filters that are not valid + STIX2 object common properties, as filter.field value + is not checked, only filter.op, filter.value are checked + here. They are just ignored when + applied within the DataSource API. For example, a user + can add a TAXII Filter, that is extracted and sent to + a TAXII endpoint within TAXIICollection and not applied + locally (within this API). + """ + + if op not in FILTER_OPS: + # check filter operator is supported + raise ValueError("Filter operator '%s' not supported for specified field: '%s'" % (op, field)) + + if type(value) not in FILTER_VALUE_TYPES: + # check filter value type is supported + raise TypeError("Filter value type '%s' is not supported. The type must be a python immutable type or dictionary" % type(value)) + + return True + + class Filter(collections.namedtuple("Filter", ['field', 'op', 'value'])): + """STIX 2 filters that support the querying functionality of STIX 2 + DataStores and DataSources. 
+ + Initialized like a python tuple + + Args: + field (str): filter field name, corresponds to STIX 2 object property + + op (str): operator of the filter + + value (str): filter field value + + Example: + Filter("id", "=", "malware--0f862b01-99da-47cc-9bdb-db4a86a95bb1") + + """ __slots__ = () def __new__(cls, field, op, value): # If value is a list, convert it to a tuple so it is hashable. if isinstance(value, list): value = tuple(value) + + _check_filter_components(field, op, value) + self = super(Filter, cls).__new__(cls, field, op, value) return self + @property + def common(self): + """return whether Filter is valid STIX2 Object common property + + Note: The Filter operator and Filter value type are checked when + the filter is created, thus only leaving the Filter field to be + checked to make sure a valid STIX2 Object common property. + + Note: Filters that are not valid STIX2 Object common property + Filters are still allowed to be created for extended usage of + Filter. (e.g. TAXII specific filters can be created, which are + then extracted and sent to TAXII endpoint.) + """ + return self.field in STIX_COMMON_FIELDS + + +def apply_common_filters(stix_objs, query): + """Evaluate filters against a set of STIX 2.0 objects. + + Supports only STIX 2.0 common property fields + + Args: + stix_objs (list): list of STIX objects to apply the query to + + query (set): set of filters (combined form complete query) + + Returns: + (generator): of STIX objects that successfully evaluate against + the query. + + """ + + for stix_obj in stix_objs: + clean = True + for filter_ in query: + if not filter_.common: + # skip filter as it is not a STIX2 Object common property filter + continue + + if "." in filter_.field: + # For properties like granular_markings and external_references + # need to extract the first property from the string. 
+ field = filter_.field.split(".")[0] + else: + field = filter_.field + + if field not in stix_obj.keys(): + # check filter "field" is in STIX object - if cant be + # applied to STIX object, STIX object is discarded + # (i.e. did not make it through the filter) + clean = False + break + + match = STIX_COMMON_FILTERS_MAP[filter_.field.split('.')[0]](filter_, stix_obj) + + if not match: + clean = False + break + elif match == -1: + raise ValueError("Error, filter operator: {0} not supported for specified field: {1}".format(filter_.op, filter_.field)) + + # if object unmarked after all filters, add it + if clean: + yield stix_obj + + +"""Base type filters""" -# primitive type filters def _all_filter(filter_, stix_obj_field): """all filter operations (for filters whose value type can be applied to any operation type)""" @@ -78,7 +184,7 @@ def _all_filter(filter_, stix_obj_field): def _id_filter(filter_, stix_obj_id): - """base filter types""" + """base STIX id filter""" if filter_.op == "=": return stix_obj_id == filter_.value elif filter_.op == "!=": @@ -88,6 +194,7 @@ def _id_filter(filter_, stix_obj_id): def _boolean_filter(filter_, stix_obj_field): + """base boolean filter""" if filter_.op == "=": return stix_obj_field == filter_.value elif filter_.op == "!=": @@ -97,19 +204,25 @@ def _boolean_filter(filter_, stix_obj_field): def _string_filter(filter_, stix_obj_field): + """base string filter""" return _all_filter(filter_, stix_obj_field) def _timestamp_filter(filter_, stix_obj_timestamp): + """base STIX 2 timestamp filter""" return _all_filter(filter_, stix_obj_timestamp) -# STIX 2.0 Common Property filters -# The naming of these functions is important as -# they are used to index a mapping dictionary from -# STIX common field names to these filter functions. 
-# -# REQUIRED naming scheme: -# "check__filter" + +"""STIX 2.0 Common Property Filters + +The naming of these functions is important as +they are used to index a mapping dictionary from +STIX common field names to these filter functions. + +REQUIRED naming scheme: + "check__filter" + +""" def check_created_filter(filter_, stix_obj): @@ -124,13 +237,15 @@ def check_external_references_filter(filter_, stix_obj): """ STIX object's can have a list of external references - external_references properties: + external_references properties supported: external_references.source_name (string) external_references.description (string) external_references.url (string) - external_references.hashes (hash, but for filtering purposes, a string) external_references.external_id (string) + external_references properties not supported: + external_references.hashes + """ for er in stix_obj["external_references"]: # grab er property name from filter field diff --git a/stix2/sources/memory.py b/stix2/sources/memory.py index 95d053c..c9910a6 100644 --- a/stix2/sources/memory.py +++ b/stix2/sources/memory.py @@ -6,7 +6,8 @@ Classes: MemorySink MemorySource -TODO: Test everything. + +TODO: Run through tests again, lot of changes. 
TODO: Use deduplicate() calls only when memory corpus is dirty (been added to) can save a lot of time for successive queries @@ -18,49 +19,87 @@ Notes: """ -import collections import json import os -from stix2 import Bundle +from stix2.base import _STIXBase +from stix2.core import Bundle, parse from stix2.sources import DataSink, DataSource, DataStore -from stix2.sources.filters import Filter +from stix2.sources.filters import Filter, apply_common_filters def _add(store, stix_data=None): - """Adds stix objects to MemoryStore/Source/Sink.""" - if isinstance(stix_data, collections.Mapping): - # stix objects are in a bundle - # make dictionary of the objects for easy lookup - for stix_obj in stix_data["objects"]: - store.data[stix_obj["id"]] = stix_obj + """Adds STIX objects to MemoryStore/Sink. + + Adds STIX objects to an in-memory dictionary for fast lookup. + Recursive function, breaks down STIX Bundles and lists. + + Args: + stix_data (list OR dict OR STIX object): STIX objects to be added + """ + + if isinstance(stix_data, _STIXBase): + # adding a python STIX object + store._data[stix_data["id"]] = stix_data + + elif isinstance(stix_data, dict): + if stix_data["type"] == "bundle": + # adding a json bundle - so just grab STIX objects + for stix_obj in stix_data["objects"]: + _add(store, stix_obj) + else: + # adding a json STIX object + store._data[stix_data["id"]] = stix_data + + elif isinstance(stix_data, str): + # adding json encoded string of STIX content + stix_data = parse(stix_data) + if stix_data["type"] == "bundle": + # recurse on each STIX object in bundle + for stix_obj in stix_data: + _add(store, stix_obj) + else: + _add(store, stix_data) + elif isinstance(stix_data, list): - # stix objects are in a list + # STIX objects are in a list- recurse on each object for stix_obj in stix_data: - store.data[stix_obj["id"]] = stix_obj + _add(store, stix_obj) + else: - raise ValueError("stix_data must be in bundle format or raw list") + raise TypeError("stix_data 
must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle") class MemoryStore(DataStore): - """ - """ - def __init__(self, stix_data=None): - """ - Notes: - It doesn't make sense to create a MemoryStore by passing - in existing MemorySource and MemorySink because there could - be data concurrency issues. Just as easy to create new MemoryStore. + """Provides an interface to an in-memory dictionary + of STIX objects. MemoryStore is a wrapper around a paired + MemorySink and MemorySource - """ + Note: It doesn't make sense to create a MemoryStore by passing + in existing MemorySource and MemorySink because there could + be data concurrency issues. As well, just as easy to create new MemoryStore. + + Args: + stix_data (list OR dict OR STIX object): STIX content to be added + + Attributes: + _data (dict): the in-memory dict that holds STIX objects + + source (MemorySource): MemorySource + + sink (MemorySink): MemorySink + + """ + + def __init__(self, stix_data=None): super(MemoryStore, self).__init__() - self.data = {} + self._data = {} if stix_data: _add(self, stix_data) - self.source = MemorySource(stix_data=self.data, _store=True) - self.sink = MemorySink(stix_data=self.data, _store=True) + self.source = MemorySource(stix_data=self._data, _store=True) + self.sink = MemorySink(stix_data=self._data, _store=True) def save_to_file(self, file_path): return self.sink.save_to_file(file_path=file_path) @@ -70,64 +109,107 @@ class MemoryStore(DataStore): class MemorySink(DataSink): - """ - """ - def __init__(self, stix_data=None, _store=False): - """ - Args: - stix_data (dictionary OR list): valid STIX 2.0 content in - bundle or a list. - _store (bool): if the MemorySink is a part of a DataStore, - in which case "stix_data" is a direct reference to - shared memory with DataSource. + """Provides an interface for adding/pushing STIX objects + to an in-memory dictionary. 
- """ + Designed to be paired with a MemorySource, together as the two + components of a MemoryStore. + + Args: + stix_data (dict OR list): valid STIX 2.0 content in + bundle or a list. + + _store (bool): if the MemorySink is a part of a DataStore, + in which case "stix_data" is a direct reference to + shared memory with DataSource. Not user supplied + + Attributes: + _data (dict): the in-memory dict that holds STIX objects. + If apart of a MemoryStore, dict is shared between with + a MemorySource + """ + + def __init__(self, stix_data=None, _store=False): super(MemorySink, self).__init__() - self.data = {} + self._data = {} if _store: - self.data = stix_data + self._data = stix_data elif stix_data: - self.add(stix_data) + _add(self, stix_data) def add(self, stix_data): - """ + """add STIX objects to in-memory dictionary maintained by + the MemorySink (MemoryStore) + + see "_add()" for args documentation """ _add(self, stix_data) def save_to_file(self, file_path): + """write SITX objects in in-memory dictionary to json file, as a STIX Bundle + + Args: + file_path (str): file path to write STIX data to + """ - """ - json.dump(Bundle(self.data.values()), file_path, indent=4) + file_path = os.path.abspath(file_path) + if not os.path.exists(os.path.dirname(file_path)): + os.makedirs(os.path.dirname(file_path)) + with open(file_path, "w") as f: + f.write(str(Bundle(self._data.values()))) class MemorySource(DataSource): + """Provides an interface for searching/retrieving + STIX objects from an in-memory dictionary. + + Designed to be paired with a MemorySink, together as the two + components of a MemoryStore. + + Args: + stix_data (dict OR list OR STIX object): valid STIX 2.0 content in + bundle or list. + + _store (bool): if the MemorySource is a part of a DataStore, + in which case "stix_data" is a direct reference to shared + memory with DataSink. Not user supplied + + Attributes: + _data (dict): the in-memory dict that holds STIX objects. 
+ If apart of a MemoryStore, dict is shared between with + a MemorySink + """ def __init__(self, stix_data=None, _store=False): - """ - Args: - stix_data (dictionary OR list): valid STIX 2.0 content in - bundle or list. - _store (bool): if the MemorySource is a part of a DataStore, - in which case "stix_data" is a direct reference to shared - memory with DataSink. - - """ super(MemorySource, self).__init__() - self.data = {} + self._data = {} if _store: - self.data = stix_data + self._data = stix_data elif stix_data: _add(self, stix_data) def get(self, stix_id, _composite_filters=None): + """retrieve STIX object from in-memory dict via STIX ID + + Args: + stix_id (str): The STIX ID of the STIX object to be retrieved. + + composite_filters (set): set of filters passed from the parent + CompositeDataSource, not user supplied + + Returns: + (dict OR STIX object): STIX object that has the supplied + ID. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory + as they are supplied (either as python dictionary or STIX object), it + is returned in the same form as it as added """ - """ + if _composite_filters is None: # if get call is only based on 'id', no need to search, just retrieve from dict try: - stix_obj = self.data[stix_id] + stix_obj = self._data[stix_id] except KeyError: stix_obj = None return stix_obj @@ -143,44 +225,75 @@ class MemorySource(DataSource): return stix_obj def all_versions(self, stix_id, _composite_filters=None): - """ - Notes: - Since Memory sources/sinks don't handle multiple versions of a - STIX object, this operation is unnecessary. Translate call to get(). + """retrieve STIX objects from in-memory dict via STIX ID, all versions of it + + Note: Since Memory sources/sinks don't handle multiple versions of a + STIX object, this operation is unnecessary. Translate call to get(). Args: - stix_id (str): The id of the STIX 2.0 object to retrieve. Should - return a list of objects, all the versions of the object - specified by the "id". 
+ stix_id (str): The STIX ID of the STIX 2 object to retrieve. + + composite_filters (set): set of filters passed from the parent + CompositeDataSource, not user supplied Returns: - (list): STIX object that matched ``stix_id``. + (list): list of STIX objects that has the supplied ID. As the + MemoryStore(i.e. MemorySink) adds STIX objects to memory as they + are supplied (either as python dictionary or STIX object), it + is returned in the same form as it as added """ return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)] def query(self, query=None, _composite_filters=None): - """ + """search and retrieve STIX objects based on the complete query + + A "complete query" includes the filters from the query, the filters + attached to MemorySource, and any filters passed from a + CompositeDataSource (i.e. _composite_filters) + + Args: + query (list): list of filters to search on + + composite_filters (set): set of filters passed from the + CompositeDataSource, not user supplied + + Returns: + (list): list of STIX objects that matches the supplied + query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory + as they are supplied (either as python dictionary or STIX object), it + is returned in the same form as it as added + """ if query is None: - query = [] + query = set() + else: + if not isinstance(query, list): + # make sure dont make set from a Filter object, + # need to make a set from a list of Filter objects (even if just one Filter) + query = list(query) + query = set(query) # combine all query filters if self.filters: - query.extend(list(self.filters)) + query.update(self.filters) if _composite_filters: - query.extend(_composite_filters) + query.update(_composite_filters) # Apply STIX common property filters. 
- all_data = self.apply_common_filters(self.data.values(), query) + all_data = [stix_obj for stix_obj in apply_common_filters(self._data.values(), query)] return all_data def load_from_file(self, file_path): - """ + """load STIX data from json file + + File format is expected to be a single json + STIX object or json STIX bundle + + Args: + file_path (str): file path to load STIX data from """ file_path = os.path.abspath(file_path) stix_data = json.load(open(file_path, "r")) - - for stix_obj in stix_data["objects"]: - self.data[stix_obj["id"]] = stix_obj + _add(self, stix_data) diff --git a/stix2/sources/taxii.py b/stix2/sources/taxii.py index aa9bd11..63d5226 100644 --- a/stix2/sources/taxii.py +++ b/stix2/sources/taxii.py @@ -10,83 +10,144 @@ TODO: Test everything """ -import json - -from stix2.sources import DataSink, DataSource, DataStore, make_id -from stix2.sources.filters import Filter +from stix2.base import _STIXBase +from stix2.core import Bundle, parse +from stix2.sources import DataSink, DataSource, DataStore +from stix2.sources.filters import Filter, apply_common_filters +from stix2.utils import deduplicate TAXII_FILTERS = ['added_after', 'id', 'type', 'version'] class TAXIICollectionStore(DataStore): - """ + """Provides an interface to a local/remote TAXII Collection + of STIX data. TAXIICollectionStore is a wrapper + around a paired TAXIICollectionSink and TAXIICollectionSource. + + Args: + collection (taxii2.Collection): TAXII Collection instance """ def __init__(self, collection): - """ - Create a new TAXII Collection Data store - - Args: - collection (taxii2.Collection): Collection instance - - """ super(TAXIICollectionStore, self).__init__() self.source = TAXIICollectionSource(collection) self.sink = TAXIICollectionSink(collection) class TAXIICollectionSink(DataSink): - """ + """Provides an interface for pushing STIX objects to a local/remote + TAXII Collection endpoint. 
+ + Args: + collection (taxii2.Collection): TAXII2 Collection instance + """ def __init__(self, collection): super(TAXIICollectionSink, self).__init__() self.collection = collection - def add(self, stix_obj): - """ - """ - self.collection.add_objects(self.create_bundle([json.loads(str(stix_obj))])) + def add(self, stix_data): + """add/push STIX content to TAXII Collection endpoint - @staticmethod - def create_bundle(objects): - return dict(id="bundle--%s" % make_id(), - objects=objects, - spec_version="2.0", - type="bundle") + Args: + stix_data (STIX object OR dict OR str OR list): valid STIX 2.0 content + in a STIX object (or Bundle), STIX onject dict (or Bundle dict), or a STIX 2.0 + json encoded string, or list of any of the following + + """ + + if isinstance(stix_data, _STIXBase): + # adding python STIX object + bundle = dict(Bundle(stix_data)) + + elif isinstance(stix_data, dict): + # adding python dict (of either Bundle or STIX obj) + if stix_data["type"] == "bundle": + bundle = stix_data + else: + bundle = dict(Bundle(stix_data)) + + elif isinstance(stix_data, list): + # adding list of something - recurse on each + for obj in stix_data: + self.add(obj) + + elif isinstance(stix_data, str): + # adding json encoded string of STIX content + stix_data = parse(stix_data) + if stix_data["type"] == "bundle": + bundle = dict(stix_data) + else: + bundle = dict(Bundle(stix_data)) + + else: + raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle") + + self.collection.add_objects(bundle) class TAXIICollectionSource(DataSource): - """ + """Provides an interface for searching/retrieving STIX objects + from a local/remote TAXII Collection endpoint. 
+ + Args: + collection (taxii2.Collection): TAXII Collection instance + """ def __init__(self, collection): super(TAXIICollectionSource, self).__init__() self.collection = collection def get(self, stix_id, _composite_filters=None): - """ + """retrieve STIX object from local/remote STIX Collection + endpoint. + + Args: + stix_id (str): The STIX ID of the STIX object to be retrieved. + + composite_filters (set): set of filters passed from the parent + CompositeDataSource, not user supplied + + Returns: + (STIX object): STIX object that has the supplied STIX ID. + The STIX object is received from TAXII has dict, parsed into + a python STIX object and then returned + + """ # combine all query filters - query = [] + query = set() if self.filters: - query.extend(self.filters.values()) + query.update(self.filters) if _composite_filters: - query.extend(_composite_filters) + query.update(_composite_filters) # separate taxii query terms (can be done remotely) taxii_filters = self._parse_taxii_filters(query) stix_objs = self.collection.get_object(stix_id, taxii_filters)["objects"] - stix_obj = self.apply_common_filters(stix_objs, query) + stix_obj = [stix_obj for stix_obj in apply_common_filters(stix_objs, query)] - if len(stix_obj) > 0: + if len(stix_obj): stix_obj = stix_obj[0] else: stix_obj = None - return stix_obj + return parse(stix_obj) def all_versions(self, stix_id, _composite_filters=None): - """ + """retrieve STIX object from local/remote TAXII Collection + endpoint, all versions of it + + Args: + stix_id (str): The STIX ID of the STIX objects to be retrieved. 
+ + composite_filters (set): set of filters passed from the parent + CompositeDataSource, not user supplied + + Returns: + (see query() as all_versions() is just a wrapper) + """ # make query in TAXII query format since 'id' is TAXII field query = [ @@ -99,16 +160,39 @@ class TAXIICollectionSource(DataSource): return all_data def query(self, query=None, _composite_filters=None): + """search and retreive STIX objects based on the complete query + + A "complete query" includes the filters from the query, the filters + attached to MemorySource, and any filters passed from a + CompositeDataSource (i.e. _composite_filters) + + Args: + query (list): list of filters to search on + + composite_filters (set): set of filters passed from the + CompositeDataSource, not user supplied + + Returns: + (list): list of STIX objects that matches the supplied + query. The STIX objects are received from TAXII as dicts, + parsed into python STIX objects and then returned. + """ - """ + if query is None: - query = [] + query = set() + else: + if not isinstance(query, list): + # make sure dont make set from a Filter object, + # need to make a set from a list of Filter objects (even if just one Filter) + query = list(query) + query = set(query) # combine all query filters if self.filters: - query.extend(self.filters.values()) + query.update(self.filters) if _composite_filters: - query.extend(_composite_filters) + query.update(_composite_filters) # separate taxii query terms (can be done remotely) taxii_filters = self._parse_taxii_filters(query) @@ -117,12 +201,15 @@ class TAXIICollectionSource(DataSource): all_data = self.collection.get_objects(filters=taxii_filters)["objects"] # deduplicate data (before filtering as reduces wasted filtering) - all_data = self.deduplicate(all_data) + all_data = deduplicate(all_data) - # apply local (composite and data source filters) - all_data = self.apply_common_filters(all_data, query) + # apply local (CompositeDataSource, TAXIICollectionSource and 
query filters) + all_data = [stix_obj for stix_obj in apply_common_filters(all_data, query)] - return all_data + # parse python STIX objects from the STIX object dicts + stix_objs = [parse(stix_obj_dict) for stix_obj_dict in all_data] + + return stix_objs def _parse_taxii_filters(self, query): """Parse out TAXII filters that the TAXII server can filter on. @@ -142,6 +229,7 @@ class TAXIICollectionSource(DataSource): for 'requests.get()'. """ + params = {} for filter_ in query: diff --git a/stix2/sro.py b/stix2/sro.py index af483bc..4fa0465 100644 --- a/stix2/sro.py +++ b/stix2/sro.py @@ -4,13 +4,18 @@ from collections import OrderedDict from .base import _STIXBase from .common import ExternalReference, GranularMarking +from .markings import MarkingsMixin from .properties import (BooleanProperty, IDProperty, IntegerProperty, ListProperty, ReferenceProperty, StringProperty, TimestampProperty, TypeProperty) from .utils import NOW -class Relationship(_STIXBase): +class STIXRelationshipObject(_STIXBase, MarkingsMixin): + pass + + +class Relationship(STIXRelationshipObject): _type = 'relationship' _properties = OrderedDict() @@ -45,7 +50,7 @@ class Relationship(_STIXBase): super(Relationship, self).__init__(**kwargs) -class Sighting(_STIXBase): +class Sighting(STIXRelationshipObject): _type = 'sighting' _properties = OrderedDict() _properties.update([ diff --git a/stix2/test/test_data_sources.py b/stix2/test/test_data_sources.py index 1415c34..e34d603 100644 --- a/stix2/test/test_data_sources.py +++ b/stix2/test/test_data_sources.py @@ -3,8 +3,9 @@ from taxii2client import Collection from stix2.sources import (CompositeDataSource, DataSink, DataSource, DataStore, make_id, taxii) -from stix2.sources.filters import Filter +from stix2.sources.filters import Filter, apply_common_filters from stix2.sources.memory import MemorySource, MemoryStore +from stix2.utils import deduplicate COLLECTION_URL = 'https://example.com/api1/collections/91a7b528-80eb-42ed-a74d-c6fbd5a26116/' 
@@ -206,39 +207,37 @@ def test_add_get_remove_filter(ds): Filter('id', '!=', 'stix object id'), Filter('labels', 'in', ["heartbleed", "malicious-activity"]), ] - invalid_filters = [ - Filter('description', '=', 'not supported field - just place holder'), - Filter('modified', '*', 'not supported operator - just place holder'), - Filter('created', '=', object()), - ] + + # Invalid filters - won't pass creation + # these filters will not be allowed to be created + # check proper errors are raised when trying to create them + + with pytest.raises(ValueError) as excinfo: + # create Filter that has an operator that is not allowed + Filter('modified', '*', 'not supported operator - just place holder') + assert str(excinfo.value) == "Filter operator '*' not supported for specified field: 'modified'" + + with pytest.raises(TypeError) as excinfo: + # create Filter that has a value type that is not allowed + Filter('created', '=', object()) + # On Python 2, the type of object() is `<type 'object'>`. On Python 3, it's `<class 'object'>`. + assert str(excinfo.value).startswith("Filter value type") + assert str(excinfo.value).endswith("is not supported. The type must be a python immutable type or dictionary") assert len(ds.filters) == 0 - ds.add_filter(valid_filters[0]) + ds.filters.add(valid_filters[0]) assert len(ds.filters) == 1 # Addin the same filter again will have no effect since `filters` uses a set - ds.add_filter(valid_filters[0]) + ds.filters.add(valid_filters[0]) assert len(ds.filters) == 1 - ds.add_filter(valid_filters[1]) + ds.filters.add(valid_filters[1]) assert len(ds.filters) == 2 - ds.add_filter(valid_filters[2]) + ds.filters.add(valid_filters[2]) assert len(ds.filters) == 3 - # TODO: make better error messages - with pytest.raises(ValueError) as excinfo: - ds.add_filter(invalid_filters[0]) - assert str(excinfo.value) == "Filter 'field' is not a STIX 2.0 common property.
Currently only STIX object common properties supported" - - with pytest.raises(ValueError) as excinfo: - ds.add_filter(invalid_filters[1]) - assert str(excinfo.value) == "Filter operation (from 'op' field) not supported" - - with pytest.raises(ValueError) as excinfo: - ds.add_filter(invalid_filters[2]) - assert str(excinfo.value) == "Filter 'value' type is not supported. The type(value) must be python immutable type or dictionary" - assert set(valid_filters) == ds.filters # remove @@ -246,7 +245,7 @@ def test_add_get_remove_filter(ds): assert len(ds.filters) == 2 - ds.add_filters(valid_filters) + ds.filters.update(valid_filters) def test_apply_common_filters(ds): @@ -320,7 +319,6 @@ def test_apply_common_filters(ds): Filter("created", ">", "2015-01-01T01:00:00.000Z"), Filter("revoked", "=", True), Filter("revoked", "!=", True), - Filter("revoked", "?", False), Filter("object_marking_refs", "=", "marking-definition--613f2e26-407d-48c7-9eca-b8e91df99dc9"), Filter("granular_markings.selectors", "in", "relationship_type"), Filter("granular_markings.marking_ref", "=", "marking-definition--5e57c739-391a-4eb3-b6be-7d15ca92d5ed"), @@ -332,7 +330,7 @@ def test_apply_common_filters(ds): ] # "Return any object whose type is not relationship" - resp = ds.apply_common_filters(stix_objs, [filters[0]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[0]])] ids = [r['id'] for r in resp] assert stix_objs[0]['id'] in ids assert stix_objs[1]['id'] in ids @@ -340,138 +338,109 @@ def test_apply_common_filters(ds): assert len(ids) == 3 # "Return any object that matched id relationship--2f9a9aa9-108a-4333-83e2-4fb25add0463" - resp = ds.apply_common_filters(stix_objs, [filters[1]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[1]])] assert resp[0]['id'] == stix_objs[2]['id'] assert len(resp) == 1 # "Return any object that contains remote-access-trojan in labels" - resp = ds.apply_common_filters(stix_objs, [filters[2]]) + resp = 
[stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[2]])] assert resp[0]['id'] == stix_objs[0]['id'] assert len(resp) == 1 # "Return any object created after 2015-01-01T01:00:00.000Z" - resp = ds.apply_common_filters(stix_objs, [filters[3]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[3]])] assert resp[0]['id'] == stix_objs[0]['id'] assert len(resp) == 2 # "Return any revoked object" - resp = ds.apply_common_filters(stix_objs, [filters[4]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[4]])] assert resp[0]['id'] == stix_objs[2]['id'] assert len(resp) == 1 # "Return any object whose not revoked" # Note that if 'revoked' property is not present in object. # Currently we can't use such an expression to filter for... :( - resp = ds.apply_common_filters(stix_objs, [filters[5]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[5]])] assert len(resp) == 0 - # Assert unknown operator for _boolean() raises exception. 
- with pytest.raises(ValueError) as excinfo: - ds.apply_common_filters(stix_objs, [filters[6]]) - - assert str(excinfo.value) == ("Error, filter operator: {0} not supported " - "for specified field: {1}" - .format(filters[6].op, filters[6].field)) - # "Return any object that matches marking-definition--613f2e26-407d-48c7-9eca-b8e91df99dc9 in object_marking_refs" - resp = ds.apply_common_filters(stix_objs, [filters[7]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[6]])] assert resp[0]['id'] == stix_objs[2]['id'] assert len(resp) == 1 # "Return any object that contains relationship_type in their selectors AND # also has marking-definition--5e57c739-391a-4eb3-b6be-7d15ca92d5ed in marking_ref" - resp = ds.apply_common_filters(stix_objs, [filters[8], filters[9]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[7], filters[8]])] assert resp[0]['id'] == stix_objs[2]['id'] assert len(resp) == 1 # "Return any object that contains CVE-2014-0160,CVE-2017-6608 in their external_id" - resp = ds.apply_common_filters(stix_objs, [filters[10]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[9]])] assert resp[0]['id'] == stix_objs[3]['id'] assert len(resp) == 1 # "Return any object that matches created_by_ref identity--00000000-0000-0000-0000-b8e91df99dc9" - resp = ds.apply_common_filters(stix_objs, [filters[11]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[10]])] assert len(resp) == 1 # "Return any object that matches marking-definition--613f2e26-0000-0000-0000-b8e91df99dc9 in object_marking_refs" (None) - resp = ds.apply_common_filters(stix_objs, [filters[12]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[11]])] assert len(resp) == 0 # "Return any object that contains description in its selectors" (None) - resp = ds.apply_common_filters(stix_objs, [filters[13]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, 
[filters[12]])] assert len(resp) == 0 # "Return any object that object that matches CVE in source_name" (None, case sensitive) - resp = ds.apply_common_filters(stix_objs, [filters[14]]) + resp = [stix_obj for stix_obj in apply_common_filters(stix_objs, [filters[13]])] assert len(resp) == 0 def test_filters0(ds): # "Return any object modified before 2017-01-28T13:49:53.935Z" - resp = ds.apply_common_filters(STIX_OBJS2, [Filter("modified", "<", "2017-01-28T13:49:53.935Z")]) + resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("modified", "<", "2017-01-28T13:49:53.935Z")])] assert resp[0]['id'] == STIX_OBJS2[1]['id'] assert len(resp) == 2 def test_filters1(ds): # "Return any object modified after 2017-01-28T13:49:53.935Z" - resp = ds.apply_common_filters(STIX_OBJS2, [Filter("modified", ">", "2017-01-28T13:49:53.935Z")]) + resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("modified", ">", "2017-01-28T13:49:53.935Z")])] assert resp[0]['id'] == STIX_OBJS2[0]['id'] assert len(resp) == 1 def test_filters2(ds): # "Return any object modified after or on 2017-01-28T13:49:53.935Z" - resp = ds.apply_common_filters(STIX_OBJS2, [Filter("modified", ">=", "2017-01-27T13:49:53.935Z")]) + resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("modified", ">=", "2017-01-27T13:49:53.935Z")])] assert resp[0]['id'] == STIX_OBJS2[0]['id'] assert len(resp) == 3 def test_filters3(ds): # "Return any object modified before or on 2017-01-28T13:49:53.935Z" - resp = ds.apply_common_filters(STIX_OBJS2, [Filter("modified", "<=", "2017-01-27T13:49:53.935Z")]) + resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("modified", "<=", "2017-01-27T13:49:53.935Z")])] assert resp[0]['id'] == STIX_OBJS2[1]['id'] assert len(resp) == 2 def test_filters4(ds): - fltr4 = Filter("modified", "?", "2017-01-27T13:49:53.935Z") - # Assert unknown operator for _all() raises exception. 
+ # Assert invalid Filter cannot be created with pytest.raises(ValueError) as excinfo: - ds.apply_common_filters(STIX_OBJS2, [fltr4]) - assert str(excinfo.value) == ("Error, filter operator: {0} not supported " - "for specified field: {1}").format(fltr4.op, fltr4.field) + Filter("modified", "?", "2017-01-27T13:49:53.935Z") + assert str(excinfo.value) == ("Filter operator '?' not supported " + "for specified field: 'modified'") def test_filters5(ds): # "Return any object whose id is not indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f" - resp = ds.apply_common_filters(STIX_OBJS2, [Filter("id", "!=", "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")]) + resp = [stix_obj for stix_obj in apply_common_filters(STIX_OBJS2, [Filter("id", "!=", "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")])] assert resp[0]['id'] == STIX_OBJS2[0]['id'] assert len(resp) == 1 -def test_filters6(ds): - fltr6 = Filter("id", "?", "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f") - # Assert unknown operator for _id() raises exception. - with pytest.raises(ValueError) as excinfo: - ds.apply_common_filters(STIX_OBJS2, [fltr6]) - - assert str(excinfo.value) == ("Error, filter operator: {0} not supported " - "for specified field: {1}").format(fltr6.op, fltr6.field) - - -def test_filters7(ds): - fltr7 = Filter("notacommonproperty", "=", "bar") - # Assert unknown field raises exception. 
- with pytest.raises(ValueError) as excinfo: - ds.apply_common_filters(STIX_OBJS2, [fltr7]) - - assert str(excinfo.value) == ("Error, field: {0} is not supported for " - "filtering on.").format(fltr7.field) - - def test_deduplicate(ds): - unique = ds.deduplicate(STIX_OBJS1) + unique = deduplicate(STIX_OBJS1) # Only 3 objects are unique # 2 id's vary @@ -494,17 +463,19 @@ def test_add_remove_composite_datasource(): ds2 = DataSource() ds3 = DataSink() - cds.add_data_source([ds1, ds2, ds1, ds3]) + with pytest.raises(TypeError) as excinfo: + cds.add_data_sources([ds1, ds2, ds1, ds3]) + assert str(excinfo.value) == ("DataSource (to be added) is not of type " + "stix2.DataSource. DataSource type is ''") + + cds.add_data_sources([ds1, ds2, ds1]) assert len(cds.get_all_data_sources()) == 2 - cds.remove_data_source([ds1.id, ds2.id]) + cds.remove_data_sources([ds1.id, ds2.id]) assert len(cds.get_all_data_sources()) == 0 - with pytest.raises(ValueError): - cds.remove_data_source([ds3.id]) - def test_composite_datasource_operations(): BUNDLE1 = dict(id="bundle--%s" % make_id(), @@ -515,7 +486,7 @@ def test_composite_datasource_operations(): ds1 = MemorySource(stix_data=BUNDLE1) ds2 = MemorySource(stix_data=STIX_OBJS2) - cds.add_data_source([ds1, ds2]) + cds.add_data_sources([ds1, ds2]) indicators = cds.all_versions("indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f") diff --git a/stix2/test/test_environment.py b/stix2/test/test_environment.py index 0871bb5..81f2cda 100644 --- a/stix2/test/test_environment.py +++ b/stix2/test/test_environment.py @@ -150,13 +150,11 @@ def test_environment_no_datastore(): env.query(INDICATOR_ID) assert 'Environment has no data source' in str(excinfo.value) - with pytest.raises(AttributeError) as excinfo: - env.add_filters(INDICATOR_ID) - assert 'Environment has no data source' in str(excinfo.value) - with pytest.raises(AttributeError) as excinfo: - env.add_filter(INDICATOR_ID) - assert 'Environment has no data source' in str(excinfo.value) +def 
test_environment_add_filters(): + env = stix2.Environment(factory=stix2.ObjectFactory()) + env.add_filters([INDICATOR_ID]) + env.add_filter(INDICATOR_ID) def test_environment_datastore_and_no_object_factory(): diff --git a/stix2/test/test_granular_markings.py b/stix2/test/test_granular_markings.py index e910ad3..f8fc803 100644 --- a/stix2/test/test_granular_markings.py +++ b/stix2/test/test_granular_markings.py @@ -1,7 +1,7 @@ import pytest -from stix2 import Malware, markings +from stix2 import TLP_RED, Malware, markings from .constants import MALWARE_MORE_KWARGS as MALWARE_KWARGS_CONST from .constants import MARKING_IDS @@ -45,6 +45,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): }, ], **MALWARE_KWARGS), + MARKING_IDS[0], ), ( MALWARE_KWARGS, @@ -56,13 +57,26 @@ def test_add_marking_mark_one_selector_multiple_refs(): }, ], **MALWARE_KWARGS), + MARKING_IDS[0], + ), + ( + Malware(**MALWARE_KWARGS), + Malware( + granular_markings=[ + { + "selectors": ["description", "name"], + "marking_ref": TLP_RED.id, + }, + ], + **MALWARE_KWARGS), + TLP_RED, ), ]) def test_add_marking_mark_multiple_selector_one_refs(data): before = data[0] after = data[1] - before = markings.add_markings(before, [MARKING_IDS[0]], ["description", "name"]) + before = markings.add_markings(before, data[2], ["description", "name"]) for m in before["granular_markings"]: assert m in after["granular_markings"] @@ -347,36 +361,42 @@ def test_get_markings_positional_arguments_combinations(data): assert set(markings.get_markings(data, "x.z.foo2", False, True)) == set(["10"]) -@pytest.mark.parametrize("before", [ - Malware( - granular_markings=[ - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[0] - }, - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[1] - }, - ], - **MALWARE_KWARGS +@pytest.mark.parametrize("data", [ + ( + Malware( + granular_markings=[ + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[0] + }, + { + "selectors": ["description"], + 
"marking_ref": MARKING_IDS[1] + }, + ], + **MALWARE_KWARGS + ), + [MARKING_IDS[0], MARKING_IDS[1]], ), - dict( - granular_markings=[ - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[0] - }, - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[1] - }, - ], - **MALWARE_KWARGS + ( + dict( + granular_markings=[ + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[0] + }, + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[1] + }, + ], + **MALWARE_KWARGS + ), + [MARKING_IDS[0], MARKING_IDS[1]], ), ]) -def test_remove_marking_remove_one_selector_with_multiple_refs(before): - before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"]) +def test_remove_marking_remove_one_selector_with_multiple_refs(data): + before = markings.remove_markings(data[0], data[1], ["description"]) assert "granular_markings" not in before diff --git a/stix2/test/test_markings.py b/stix2/test/test_markings.py index 0c6069a..456bf92 100644 --- a/stix2/test/test_markings.py +++ b/stix2/test/test_markings.py @@ -241,4 +241,14 @@ def test_marking_wrong_type_construction(): assert str(excinfo.value) == "Must supply a list, containing tuples. 
For example, [('property1', IntegerProperty())]" -# TODO: Add other examples +def test_campaign_add_markings(): + campaign = stix2.Campaign( + id="campaign--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + created_by_ref="identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + created="2016-04-06T20:03:00Z", + modified="2016-04-06T20:03:00Z", + name="Green Group Attacks Against Finance", + description="Campaign by Green Group against a series of targets in the financial services sector.", + ) + campaign = campaign.add_markings(TLP_WHITE) + assert campaign.object_marking_refs[0] == TLP_WHITE.id diff --git a/stix2/test/test_object_markings.py b/stix2/test/test_object_markings.py index 36e8e4d..10949ab 100644 --- a/stix2/test/test_object_markings.py +++ b/stix2/test/test_object_markings.py @@ -1,7 +1,7 @@ import pytest -from stix2 import Malware, exceptions, markings +from stix2 import TLP_AMBER, Malware, exceptions, markings from .constants import FAKE_TIME, MALWARE_ID, MARKING_IDS from .constants import MALWARE_KWARGS as MALWARE_KWARGS_CONST @@ -21,18 +21,26 @@ MALWARE_KWARGS.update({ Malware(**MALWARE_KWARGS), Malware(object_marking_refs=[MARKING_IDS[0]], **MALWARE_KWARGS), + MARKING_IDS[0], ), ( MALWARE_KWARGS, dict(object_marking_refs=[MARKING_IDS[0]], **MALWARE_KWARGS), + MARKING_IDS[0], + ), + ( + Malware(**MALWARE_KWARGS), + Malware(object_marking_refs=[TLP_AMBER.id], + **MALWARE_KWARGS), + TLP_AMBER, ), ]) def test_add_markings_one_marking(data): before = data[0] after = data[1] - before = markings.add_markings(before, MARKING_IDS[0], None) + before = markings.add_markings(before, data[2], None) for m in before["object_marking_refs"]: assert m in after["object_marking_refs"] @@ -280,19 +288,28 @@ def test_remove_markings_object_level(data): **MALWARE_KWARGS), Malware(object_marking_refs=[MARKING_IDS[1]], **MALWARE_KWARGS), + [MARKING_IDS[0], MARKING_IDS[2]], ), ( dict(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], **MALWARE_KWARGS), 
dict(object_marking_refs=[MARKING_IDS[1]], **MALWARE_KWARGS), + [MARKING_IDS[0], MARKING_IDS[2]], + ), + ( + Malware(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], TLP_AMBER.id], + **MALWARE_KWARGS), + Malware(object_marking_refs=[MARKING_IDS[1]], + **MALWARE_KWARGS), + [MARKING_IDS[0], TLP_AMBER], ), ]) def test_remove_markings_multiple(data): before = data[0] after = data[1] - before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[2]], None) + before = markings.remove_markings(before, data[2], None) assert before['object_marking_refs'] == after['object_marking_refs'] diff --git a/stix2/utils.py b/stix2/utils.py index ca195f6..94e7f4e 100644 --- a/stix2/utils.py +++ b/stix2/utils.py @@ -33,6 +33,34 @@ class STIXdatetime(dt.datetime): return "'%s'" % format_datetime(self) +def deduplicate(stix_obj_list): + """Deduplicate a list of STIX objects to a unique set + + Reduces a set of STIX objects to unique set by looking + at 'id' and 'modified' fields - as a unique object version + is determined by the combination of those fields + + Note: Be aware, as can be seen in the implementation + of deduplicate(),that if the "stix_obj_list" argument has + multiple STIX objects of the same version, the last object + version found in the list will be the one that is returned. + () + + Args: + stix_obj_list (list): list of STIX objects (dicts) + + Returns: + A list with a unique set of the passed list of STIX objects. + + """ + unique_objs = {} + + for obj in stix_obj_list: + unique_objs[(obj['id'], obj['modified'])] = obj + + return list(unique_objs.values()) + + def get_timestamp(): return STIXdatetime.now(tz=pytz.UTC)