diff --git a/stix2/test/test_datastore_taxii.py b/stix2/test/test_datastore_taxii.py index 0a293b7..c20fbb6 100644 --- a/stix2/test/test_datastore_taxii.py +++ b/stix2/test/test_datastore_taxii.py @@ -1,24 +1,101 @@ +from stix2 import Bundle, ThreatActor, TAXIICollectionSource, TAXIICollectionSink +from stix2.datastore.filters import Filter + +import json +import pytest + +from taxii2client import Collection, _filter_kwargs_to_query_params +from medallion.filters.basic_filter import BasicFilter + COLLECTION_URL = 'https://example.com/api1/collections/91a7b528-80eb-42ed-a74d-c6fbd5a26116/' -class MockTAXIIClient(object): +class MockTAXIICollectionEndpoint(Collection): """Mock for taxii2_client.TAXIIClient""" - pass + + def __init__(self, url, **kwargs): + super(MockTAXIICollectionEndpoint, self).__init__(url, **kwargs) + self.objects = [] + + def add_objects(self, bundle): + self._verify_can_write() + if isinstance(bundle, str): + bundle = json.loads(bundle) + for object in bundle.get("objects", []): + self.objects.append(object) + + def get_objects(self, **filter_kwargs): + self._verify_can_read() + query_params = _filter_kwargs_to_query_params(filter_kwargs) + query_params = json.loads(query_params) + full_filter = BasicFilter(query_params or {}) + objs = full_filter.process_filter( + self.objects, + ("id", "type", "version"), + [] + ) + return Bundle(objects=objs) + + def get_object(self, id, version=None): + self._verify_can_read() + query_params = None + if version: + query_params = _filter_kwargs_to_query_params({"version": version}) + if query_params: + query_params = json.loads(query_params) + full_filter = BasicFilter(query_params or {}) + objs = full_filter.process_filter( + self.objects, + ("version",), + [] + ) + return Bundle(objects=objs) @pytest.fixture -def collection(): - return Collection(COLLECTION_URL, MockTAXIIClient()) +def collection(stix_objs1): + mock = MockTAXIICollectionEndpoint(COLLECTION_URL, **{ + "id": "91a7b528-80eb-42ed-a74d-c6fbd5a26116", + "title": "Writable Collection", + "description": "This collection is a dropbox for submitting indicators", + "can_read": True, + "can_write": True, + "media_types": [ + "application/vnd.oasis.stix+json; version=2.0" + ] + }) + + mock.objects.extend(stix_objs1) + return mock def test_ds_taxii(collection): - ds = taxii.TAXIICollectionSource(collection) + ds = TAXIICollectionSource(collection) assert ds.collection is not None -def test_ds_taxii_name(collection): - ds = taxii.TAXIICollectionSource(collection) - assert ds.collection is not None +def test_add_stix2_object(collection): + tc_sink = TAXIICollectionSink(collection) + + # create new STIX threat-actor + ta = ThreatActor(name="Teddy Bear", + labels=["nation-state"], + sophistication="innovator", + resource_level="government", + goals=[ + "compromising environment NGOs", + "water-hole attacks geared towards energy sector", + ]) + + tc_sink.add(ta) + + +def test_get_stix2_object(collection): + tc_sink = TAXIICollectionSource(collection) + + objects = tc_sink.get("indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f") + + assert objects def test_parse_taxii_filters(): @@ -37,7 +114,7 @@ def test_parse_taxii_filters(): Filter("version", "=", "first") ]) - ds = taxii.TAXIICollectionSource(collection) + ds = TAXIICollectionSource(collection) taxii_filters = ds._parse_taxii_filters(query) @@ -45,7 +122,7 @@ def test_parse_taxii_filters(): def test_add_get_remove_filter(): - ds = taxii.TAXIICollectionSource(collection) + ds = TAXIICollectionSource(collection) # First 3 filters are valid, remaining properties are erroneous in some way valid_filters = [