revamp code in MockTAXIICollectionEndpoint, add more tests

master
Emmanuelle Vargas-Gonzalez 2018-11-29 18:36:37 -05:00
parent 06716e3cfd
commit c62b9e92e7
3 changed files with 80 additions and 37 deletions

View File

@ -32,36 +32,38 @@ class MockTAXIICollectionEndpoint(Collection):
def get_objects(self, **filter_kwargs):
self._verify_can_read()
query_params = _filter_kwargs_to_query_params(filter_kwargs)
if not isinstance(query_params, dict):
query_params = json.loads(query_params, encoding='utf-8')
full_filter = BasicFilter(query_params or {})
assert isinstance(query_params, dict)
full_filter = BasicFilter(query_params)
objs = full_filter.process_filter(
self.objects,
("id", "type", "version"),
[],
)
if objs:
return stix2.v20.Bundle(objects=objs)
return stix2.v21.Bundle(objects=objs)
else:
resp = Response()
resp.status_code = 404
resp.raise_for_status()
def get_object(self, id, version=None, accept=''):
def get_object(self, id, **filter_kwargs):
self._verify_can_read()
query_params = None
if version:
query_params = _filter_kwargs_to_query_params({"version": version})
if query_params:
query_params = json.loads(query_params, encoding='utf-8')
full_filter = BasicFilter(query_params or {})
objs = full_filter.process_filter(
self.objects,
("version",),
[],
)
if objs:
return stix2.v20.Bundle(objects=objs)
query_params = _filter_kwargs_to_query_params(filter_kwargs)
assert isinstance(query_params, dict)
full_filter = BasicFilter(query_params)
# In this endpoint we must first filter objects by id beforehand.
objects = [x for x in self.objects if x["id"] == id]
if objects:
filtered_objects = full_filter.process_filter(
objects,
("version",),
[],
)
else:
filtered_objects = []
if filtered_objects:
return stix2.v21.Bundle(objects=filtered_objects)
else:
resp = Response()
resp.status_code = 404
@ -167,6 +169,20 @@ def test_add_list_object(collection, indicator):
tc_sink.add([ta, indicator])
def test_get_object_found(collection):
tc_source = stix2.TAXIICollectionSource(collection)
result = tc_source.query([
stix2.Filter("id", "=", "indicator--00000000-0000-4000-8000-000000000001"),
])
assert result
def test_get_object_not_found(collection):
tc_source = stix2.TAXIICollectionSource(collection)
result = tc_source.get("indicator--00000000-0000-4000-8000-000000000005")
assert result is None
def test_add_stix2_bundle_object(collection):
tc_sink = stix2.TAXIICollectionSink(collection)

View File

@ -8,8 +8,8 @@ from stix2.properties import (
ERROR_INVALID_ID, BinaryProperty, BooleanProperty, DictionaryProperty,
EmbeddedObjectProperty, EnumProperty, ExtensionsProperty, FloatProperty,
HashesProperty, HexProperty, IDProperty, IntegerProperty, ListProperty,
Property, ReferenceProperty, StringProperty, TimestampProperty,
TypeProperty,
Property, ReferenceProperty, STIXObjectProperty, StringProperty,
TimestampProperty, TypeProperty,
)
from stix2.v20.common import MarkingProperty
@ -496,3 +496,14 @@ def test_marking_property_error():
mark_prop.clean('my-marking')
assert str(excinfo.value) == "must be a Statement, TLP Marking or a registered marking."
def test_stix_property_not_compliant_spec():
# This is a 2.0 test only...
indicator = stix2.v20.Indicator(spec_version="2.0", allow_custom=True, **constants.INDICATOR_KWARGS)
stix_prop = STIXObjectProperty(spec_version="2.0")
with pytest.raises(ValueError) as excinfo:
stix_prop.clean(indicator)
assert "Spec version 2.0 bundles don't yet support containing objects of a different spec version." in str(excinfo.value)

View File

@ -32,9 +32,8 @@ class MockTAXIICollectionEndpoint(Collection):
def get_objects(self, **filter_kwargs):
self._verify_can_read()
query_params = _filter_kwargs_to_query_params(filter_kwargs)
if not isinstance(query_params, dict):
query_params = json.loads(query_params, encoding='utf-8')
full_filter = BasicFilter(query_params or {})
assert isinstance(query_params, dict)
full_filter = BasicFilter(query_params)
objs = full_filter.process_filter(
self.objects,
("id", "type", "version"),
@ -47,21 +46,24 @@ class MockTAXIICollectionEndpoint(Collection):
resp.status_code = 404
resp.raise_for_status()
def get_object(self, id, version=None, accept=''):
def get_object(self, id, **filter_kwargs):
self._verify_can_read()
query_params = None
if version:
query_params = _filter_kwargs_to_query_params({"version": version})
if query_params:
query_params = json.loads(query_params, encoding='utf-8')
full_filter = BasicFilter(query_params or {})
objs = full_filter.process_filter(
self.objects,
("version",),
[],
)
if objs:
return stix2.v21.Bundle(objects=objs)
query_params = _filter_kwargs_to_query_params(filter_kwargs)
assert isinstance(query_params, dict)
full_filter = BasicFilter(query_params)
# In this endpoint we must first filter objects by id beforehand.
objects = [x for x in self.objects if x["id"] == id]
if objects:
filtered_objects = full_filter.process_filter(
objects,
("version",),
[],
)
else:
filtered_objects = []
if filtered_objects:
return stix2.v21.Bundle(objects=filtered_objects)
else:
resp = Response()
resp.status_code = 404
@ -167,6 +169,20 @@ def test_add_list_object(collection, indicator):
tc_sink.add([ta, indicator])
def test_get_object_found(collection):
tc_source = stix2.TAXIICollectionSource(collection)
result = tc_source.query([
stix2.Filter("id", "=", "indicator--00000000-0000-4000-8000-000000000001"),
])
assert result
def test_get_object_not_found(collection):
tc_source = stix2.TAXIICollectionSource(collection)
result = tc_source.get("indicator--00000000-0000-4000-8000-000000000012")
assert result is None
def test_add_stix2_bundle_object(collection):
tc_sink = stix2.TAXIICollectionSink(collection)