diff --git a/stix2/__init__.py b/stix2/__init__.py index 763c4d4..4a9ec75 100644 --- a/stix2/__init__.py +++ b/stix2/__init__.py @@ -3,7 +3,7 @@ # flake8: noqa from .bundle import Bundle -from .observables import Artifact, AutonomousSystem, File +from .observables import Artifact, AutonomousSystem, EmailAddress, File from .other import ExternalReference, KillChainPhase, MarkingDefinition, \ GranularMarking, StatementMarking, TLPMarking from .sdo import AttackPattern, Campaign, CourseOfAction, Identity, Indicator, \ @@ -35,11 +35,12 @@ OBJ_MAP = { OBJ_MAP_OBSERVABLE = { 'artifact': Artifact, 'autonomous-system': AutonomousSystem, + 'email-address': EmailAddress, 'file': File, } -def parse(data, observable=False): +def parse(data): """Deserialize a string or file-like object into a STIX object""" obj = get_dict(data) @@ -49,13 +50,28 @@ def parse(data, observable=False): pass else: try: - if observable: - obj_class = OBJ_MAP_OBSERVABLE[obj['type']] - else: - obj_class = OBJ_MAP[obj['type']] + obj_class = OBJ_MAP[obj['type']] except KeyError: # TODO handle custom objects raise ValueError("Can't parse unknown object type '%s'!" % obj['type']) return obj_class(**obj) return obj + + +def parse_observable(data, _valid_refs): + """Deserialize a string or file-like object into a STIX Cyber Observable + object. + """ + + obj = get_dict(data) + obj['_valid_refs'] = _valid_refs + + if 'type' not in obj: + raise ValueError("'type' is a required field!") + try: + obj_class = OBJ_MAP_OBSERVABLE[obj['type']] + except KeyError: + # TODO handle custom objects + raise ValueError("Can't parse unknown object type '%s'!" % obj['type']) + return obj_class(**obj) diff --git a/stix2/base.py b/stix2/base.py index 5b58f06..eaf317c 100644 --- a/stix2/base.py +++ b/stix2/base.py @@ -5,7 +5,7 @@ import datetime as dt import json from .exceptions import ExtraFieldsError, ImmutableError, InvalidValueError, \ - MissingFieldsError + InvalidObjRefError, MissingFieldsError from .utils import format_datetime, get_timestamp, NOW __all__ = ['STIXJSONEncoder', '_STIXBase'] @@ -102,4 +102,20 @@ class _STIXBase(collections.Mapping): class Observable(_STIXBase): - pass + + def __init__(self, **kwargs): + self._STIXBase__valid_refs = kwargs.pop('_valid_refs') + super(Observable, self).__init__(**kwargs) + + def _check_property(self, prop_name, prop, kwargs): + super(Observable, self)._check_property(prop_name, prop, kwargs) + if prop_name.endswith('_ref'): + ref = kwargs[prop_name].split('--', 1)[0] + if ref not in self._STIXBase__valid_refs: + raise InvalidObjRefError(self.__class__, prop_name, "'%s' is not a valid object in local scope" % ref) + if prop_name.endswith('_refs'): + for r in kwargs[prop_name]: + ref = r.split('--', 1)[0] + if ref not in self._STIXBase__valid_refs: + raise InvalidObjRefError(self.__class__, prop_name, "'%s' is not a valid object in local scope" % ref) + # TODO also check the type of the object referenced, not just that the key exists diff --git a/stix2/exceptions.py b/stix2/exceptions.py index 61cec79..c23f20d 100644 --- a/stix2/exceptions.py +++ b/stix2/exceptions.py @@ -62,3 +62,17 @@ class DictionaryKeyError(STIXError, ValueError): def __str__(self): msg = "Invliad dictionary key {0.key}: ({0.reason})." return msg.format(self) + + +class InvalidObjRefError(STIXError, ValueError): + """A STIX Cyber Observable Object contains an invalid object reference.""" + + def __init__(self, cls, prop_name, reason): + super(InvalidObjRefError, self).__init__() + self.cls = cls + self.prop_name = prop_name + self.reason = reason + + def __str__(self): + msg = "Invalid object reference for '{0.cls.__name__}:{0.prop_name}': {0.reason}" + return msg.format(self) diff --git a/stix2/observables.py b/stix2/observables.py index 4e72cae..7cad32a 100644 --- a/stix2/observables.py +++ b/stix2/observables.py @@ -5,7 +5,7 @@ from .base import Observable # HashesProperty, HexProperty, IDProperty, # IntegerProperty, ListProperty, ReferenceProperty, # StringProperty, TimestampProperty, TypeProperty) -from .properties import BinaryProperty, HashesProperty, IntegerProperty, StringProperty, TypeProperty +from .properties import BinaryProperty, HashesProperty, IntegerProperty, ObjectReferenceProperty, StringProperty, TypeProperty class Artifact(Observable): @@ -29,6 +29,16 @@ class AutonomousSystem(Observable): } +class EmailAddress(Observable): + _type = 'email-address' + _properties = { + 'type': TypeProperty(_type), + 'value': StringProperty(required=True), + 'display_name': StringProperty(), + 'belongs_to_ref': ObjectReferenceProperty(), + } + + class File(Observable): _type = 'file' _properties = { diff --git a/stix2/properties.py b/stix2/properties.py index 5aae467..57ebeca 100644 --- a/stix2/properties.py +++ b/stix2/properties.py @@ -220,9 +220,9 @@ class ObservableProperty(Property): def clean(self, value): dictified = dict(value) - from .__init__ import parse # avoid circular import + from .__init__ import parse_observable # avoid circular import for key, obj in dictified.items(): - parsed_obj = parse(obj, observable=True) + parsed_obj = parse_observable(obj, dictified.keys()) if not issubclass(type(parsed_obj), Observable): raise ValueError("Objects in an observable property must be " "Cyber Observable Objects") @@ -337,12 +337,5 @@ class SelectorProperty(Property): return value -class ObjectReferenceProperty(Property): - def _init(self, valid_refs=None): - self.valid_refs = valid_refs - super(ObjectReferenceProperty, self).__init__() - - def clean(self, value): - if value not in self.valid_refs: - raise ValueError("must refer to observable objects in the same " - "Observable Objects container.") +class ObjectReferenceProperty(StringProperty): + pass diff --git a/stix2/test/test_observed_data.py b/stix2/test/test_observed_data.py index 5472597..28388c8 100644 --- a/stix2/test/test_observed_data.py +++ b/stix2/test/test_observed_data.py @@ -6,6 +6,7 @@ import pytz import stix2 from .constants import OBSERVED_DATA_ID +from ..exceptions import InvalidValueError EXPECTED = """{ "created": "2016-04-06T19:58:16Z", @@ -133,4 +134,22 @@ def test_parse_autonomous_system_valid(data): assert odata.objects["0"].rir == "ARIN" +@pytest.mark.parametrize("data", [ + """"1": { + "type": "email-address", + "value": "john@example.com", + "display_name": "John Doe", + "belongs_to_ref": "0" + }""", +]) +def test_parse_email_address(data): + odata_str = re.compile('\}.+\},', re.DOTALL).sub('}, %s},' % data, EXPECTED) + odata = stix2.parse(odata_str) + assert odata.objects["1"].type == "email-address" + + odata_str = re.compile('"belongs_to_ref": "0"', re.DOTALL).sub('"belongs_to_ref": "3"', odata_str) + with pytest.raises(InvalidValueError): + stix2.parse(odata_str) + + # TODO: Add other examples