Move query_by_type() to DataStoreMixin

stix2.0
Chris Lenk 2018-03-16 15:41:08 -04:00
parent eeb94562f9
commit fd6d9f74e9
4 changed files with 78 additions and 12 deletions

View File

@ -99,6 +99,25 @@ class DataStoreMixin(object):
except AttributeError:
raise AttributeError('%s has no data source to query' % self.__class__.__name__)
def query_by_type(self, *args, **kwargs):
"""Retrieve all objects of the given STIX object type.
Translate query_by_type() call to the appropriate DataSource call.
Args:
obj_type (str): The STIX object type to retrieve.
filters (list, optional): A list of additional filters to apply to
the query.
Returns:
list: The STIX objects that matched the query.
"""
try:
return self.source.query_by_type(*args, **kwargs)
except AttributeError:
raise AttributeError('%s has no data source to query' % self.__class__.__name__)
def creator_of(self, *args, **kwargs):
"""Retrieve the Identity refered to by the object's `created_by_ref`.
@ -277,6 +296,29 @@ class DataSource(with_metaclass(ABCMeta)):
"""
def query_by_type(self, obj_type='indicator', filters=None):
"""Retrieve all objects of the given STIX object type.
This helper function is a shortcut that calls query() under the hood.
Args:
obj_type (str): The STIX object type to retrieve.
filters (list, optional): A list of additional filters to apply to
the query.
Returns:
list: The STIX objects that matched the query.
"""
filter_list = [Filter('type', '=', obj_type)]
if filters:
if isinstance(filters, list):
filter_list += filters
else:
filter_list.append(filters)
return self.query(filter_list)
def creator_of(self, obj):
"""Retrieve the Identity refered to by the object's `created_by_ref`.
@ -542,6 +584,35 @@ class CompositeDataSource(DataSource):
return all_data
def query_by_type(self, *args, **kwargs):
"""Retrieve all objects of the given STIX object type.
Federate the query to all DataSources attached to the
Composite Data Source.
Args:
obj_type (str): The STIX object type to retrieve.
filters (list, optional): A list of additional filters to apply to
the query.
Returns:
list: The STIX objects that matched the query.
"""
if not self.has_data_sources():
raise AttributeError('CompositeDataSource has no data sources')
results = []
for ds in self.data_sources:
results.extend(ds.query_by_type(*args, **kwargs))
# remove exact duplicates (where duplicates are STIX 2.0
# objects with the same 'id' and 'modified' values)
if len(results) > 0:
results = deduplicate(results)
return results
def relationships(self, *args, **kwargs):
"""Retrieve Relationships involving the given STIX object.

View File

@ -90,10 +90,12 @@ class Environment(DataStoreMixin):
.. automethod:: get
.. automethod:: all_versions
.. automethod:: query
.. automethod:: query_by_type
.. automethod:: creator_of
.. automethod:: relationships
.. automethod:: related_to
.. automethod:: add
"""
def __init__(self, factory=ObjectFactory(), store=None, source=None, sink=None):

View File

@ -164,6 +164,10 @@ def test_environment_no_datastore():
env.query(INDICATOR_ID)
assert 'Environment has no data source' in str(excinfo.value)
with pytest.raises(AttributeError) as excinfo:
env.query_by_type('indicator')
assert 'Environment has no data source' in str(excinfo.value)
with pytest.raises(AttributeError) as excinfo:
env.relationships(INDICATOR_ID)
assert 'Environment has no data source' in str(excinfo.value)

View File

@ -13,7 +13,6 @@ from . import Report as _Report
from . import ThreatActor as _ThreatActor
from . import Tool as _Tool
from . import Vulnerability as _Vulnerability
from .datastore.filters import Filter
from .datastore.memory import MemoryStore
from .environment import Environment
@ -24,6 +23,7 @@ create = _environ.create
get = _environ.get
all_versions = _environ.all_versions
query = _environ.query
query_by_type = _environ.query_by_type
creator_of = _environ.creator_of
relationships = _environ.relationships
related_to = _environ.related_to
@ -80,17 +80,6 @@ for obj_type in STIX_OBJS:
# Functions to get all objects of a specific type
def query_by_type(obj_type='indicator', filters=None):
filter_list = [Filter('type', '=', obj_type)]
if filters:
if isinstance(filters, list):
filter_list += filters
else:
filter_list.append(filters)
return query(filter_list)
def attack_patterns(filters=None):
return query_by_type('attack-pattern', filters)