revamp code in MockTAXIICollectionEndpoint, add more tests

master
Emmanuelle Vargas-Gonzalez 2018-11-29 18:36:37 -05:00
parent 06716e3cfd
commit c62b9e92e7
3 changed files with 80 additions and 37 deletions

View File

@ -32,36 +32,38 @@ class MockTAXIICollectionEndpoint(Collection):
def get_objects(self, **filter_kwargs): def get_objects(self, **filter_kwargs):
self._verify_can_read() self._verify_can_read()
query_params = _filter_kwargs_to_query_params(filter_kwargs) query_params = _filter_kwargs_to_query_params(filter_kwargs)
if not isinstance(query_params, dict): assert isinstance(query_params, dict)
query_params = json.loads(query_params, encoding='utf-8') full_filter = BasicFilter(query_params)
full_filter = BasicFilter(query_params or {})
objs = full_filter.process_filter( objs = full_filter.process_filter(
self.objects, self.objects,
("id", "type", "version"), ("id", "type", "version"),
[], [],
) )
if objs: if objs:
return stix2.v20.Bundle(objects=objs) return stix2.v21.Bundle(objects=objs)
else: else:
resp = Response() resp = Response()
resp.status_code = 404 resp.status_code = 404
resp.raise_for_status() resp.raise_for_status()
def get_object(self, id, version=None, accept=''): def get_object(self, id, **filter_kwargs):
self._verify_can_read() self._verify_can_read()
query_params = None query_params = _filter_kwargs_to_query_params(filter_kwargs)
if version: assert isinstance(query_params, dict)
query_params = _filter_kwargs_to_query_params({"version": version}) full_filter = BasicFilter(query_params)
if query_params:
query_params = json.loads(query_params, encoding='utf-8') # In this endpoint we must first filter objects by id beforehand.
full_filter = BasicFilter(query_params or {}) objects = [x for x in self.objects if x["id"] == id]
objs = full_filter.process_filter( if objects:
self.objects, filtered_objects = full_filter.process_filter(
objects,
("version",), ("version",),
[], [],
) )
if objs: else:
return stix2.v20.Bundle(objects=objs) filtered_objects = []
if filtered_objects:
return stix2.v21.Bundle(objects=filtered_objects)
else: else:
resp = Response() resp = Response()
resp.status_code = 404 resp.status_code = 404
@ -167,6 +169,20 @@ def test_add_list_object(collection, indicator):
tc_sink.add([ta, indicator]) tc_sink.add([ta, indicator])
def test_get_object_found(collection):
tc_source = stix2.TAXIICollectionSource(collection)
result = tc_source.query([
stix2.Filter("id", "=", "indicator--00000000-0000-4000-8000-000000000001"),
])
assert result
def test_get_object_not_found(collection):
tc_source = stix2.TAXIICollectionSource(collection)
result = tc_source.get("indicator--00000000-0000-4000-8000-000000000005")
assert result is None
def test_add_stix2_bundle_object(collection): def test_add_stix2_bundle_object(collection):
tc_sink = stix2.TAXIICollectionSink(collection) tc_sink = stix2.TAXIICollectionSink(collection)

View File

@ -8,8 +8,8 @@ from stix2.properties import (
ERROR_INVALID_ID, BinaryProperty, BooleanProperty, DictionaryProperty, ERROR_INVALID_ID, BinaryProperty, BooleanProperty, DictionaryProperty,
EmbeddedObjectProperty, EnumProperty, ExtensionsProperty, FloatProperty, EmbeddedObjectProperty, EnumProperty, ExtensionsProperty, FloatProperty,
HashesProperty, HexProperty, IDProperty, IntegerProperty, ListProperty, HashesProperty, HexProperty, IDProperty, IntegerProperty, ListProperty,
Property, ReferenceProperty, StringProperty, TimestampProperty, Property, ReferenceProperty, STIXObjectProperty, StringProperty,
TypeProperty, TimestampProperty, TypeProperty,
) )
from stix2.v20.common import MarkingProperty from stix2.v20.common import MarkingProperty
@ -496,3 +496,14 @@ def test_marking_property_error():
mark_prop.clean('my-marking') mark_prop.clean('my-marking')
assert str(excinfo.value) == "must be a Statement, TLP Marking or a registered marking." assert str(excinfo.value) == "must be a Statement, TLP Marking or a registered marking."
def test_stix_property_not_compliant_spec():
# This is a 2.0 test only...
indicator = stix2.v20.Indicator(spec_version="2.0", allow_custom=True, **constants.INDICATOR_KWARGS)
stix_prop = STIXObjectProperty(spec_version="2.0")
with pytest.raises(ValueError) as excinfo:
stix_prop.clean(indicator)
assert "Spec version 2.0 bundles don't yet support containing objects of a different spec version." in str(excinfo.value)

View File

@ -32,9 +32,8 @@ class MockTAXIICollectionEndpoint(Collection):
def get_objects(self, **filter_kwargs): def get_objects(self, **filter_kwargs):
self._verify_can_read() self._verify_can_read()
query_params = _filter_kwargs_to_query_params(filter_kwargs) query_params = _filter_kwargs_to_query_params(filter_kwargs)
if not isinstance(query_params, dict): assert isinstance(query_params, dict)
query_params = json.loads(query_params, encoding='utf-8') full_filter = BasicFilter(query_params)
full_filter = BasicFilter(query_params or {})
objs = full_filter.process_filter( objs = full_filter.process_filter(
self.objects, self.objects,
("id", "type", "version"), ("id", "type", "version"),
@ -47,21 +46,24 @@ class MockTAXIICollectionEndpoint(Collection):
resp.status_code = 404 resp.status_code = 404
resp.raise_for_status() resp.raise_for_status()
def get_object(self, id, version=None, accept=''): def get_object(self, id, **filter_kwargs):
self._verify_can_read() self._verify_can_read()
query_params = None query_params = _filter_kwargs_to_query_params(filter_kwargs)
if version: assert isinstance(query_params, dict)
query_params = _filter_kwargs_to_query_params({"version": version}) full_filter = BasicFilter(query_params)
if query_params:
query_params = json.loads(query_params, encoding='utf-8') # In this endpoint we must first filter objects by id beforehand.
full_filter = BasicFilter(query_params or {}) objects = [x for x in self.objects if x["id"] == id]
objs = full_filter.process_filter( if objects:
self.objects, filtered_objects = full_filter.process_filter(
objects,
("version",), ("version",),
[], [],
) )
if objs: else:
return stix2.v21.Bundle(objects=objs) filtered_objects = []
if filtered_objects:
return stix2.v21.Bundle(objects=filtered_objects)
else: else:
resp = Response() resp = Response()
resp.status_code = 404 resp.status_code = 404
@ -167,6 +169,20 @@ def test_add_list_object(collection, indicator):
tc_sink.add([ta, indicator]) tc_sink.add([ta, indicator])
def test_get_object_found(collection):
tc_source = stix2.TAXIICollectionSource(collection)
result = tc_source.query([
stix2.Filter("id", "=", "indicator--00000000-0000-4000-8000-000000000001"),
])
assert result
def test_get_object_not_found(collection):
tc_source = stix2.TAXIICollectionSource(collection)
result = tc_source.get("indicator--00000000-0000-4000-8000-000000000012")
assert result is None
def test_add_stix2_bundle_object(collection): def test_add_stix2_bundle_object(collection):
tc_sink = stix2.TAXIICollectionSink(collection) tc_sink = stix2.TAXIICollectionSink(collection)