From 2460fb75be82e03291f7dd0fd22a31fcffffa12e Mon Sep 17 00:00:00 2001 From: clenk Date: Tue, 16 May 2017 09:25:08 -0400 Subject: [PATCH] Rework select properties to use get_dict(), which automatically coerces values to a dictionary if possible --- stix2/properties.py | 55 +++++++++++++++++++++-------------- stix2/test/test_properties.py | 42 +++++++++++++++++++++++--- stix2/test/test_utils.py | 22 ++++++++++++++ stix2/utils.py | 16 ++++++---- 4 files changed, 104 insertions(+), 31 deletions(-) diff --git a/stix2/properties.py b/stix2/properties.py index ede41db..64c986f 100644 --- a/stix2/properties.py +++ b/stix2/properties.py @@ -12,6 +12,7 @@ from six import text_type from .base import _Observable, _STIXBase from .exceptions import DictionaryKeyError +from .utils import get_dict class Property(object): @@ -181,7 +182,7 @@ class FloatProperty(Property): try: return float(value) except Exception: - raise ValueError("must be an float.") + raise ValueError("must be a float.") class BooleanProperty(Property): @@ -233,7 +234,11 @@ class TimestampProperty(Property): class ObservableProperty(Property): def clean(self, value): - dictified = dict(value) + try: + dictified = get_dict(value) + except ValueError: + raise ValueError("The observable property must contain a dictionary") + from .__init__ import parse_observable # avoid circular import for key, obj in dictified.items(): parsed_obj = parse_observable(obj, dictified.keys()) @@ -248,7 +253,11 @@ class ObservableProperty(Property): class DictionaryProperty(Property): def clean(self, value): - dictified = dict(value) + try: + dictified = get_dict(value) + except ValueError: + raise ValueError("The dictionary property must contain a dictionary") + for k in dictified.keys(): if len(k) < 3: raise DictionaryKeyError(k, "shorter than 3 characters") @@ -392,23 +401,25 @@ class ExtensionsProperty(DictionaryProperty): super(ExtensionsProperty, self).__init__(required) def clean(self, value): - if type(value) is dict: - from .__init__ import EXT_MAP # avoid circular import - if self.enclosing_type in EXT_MAP: - specific_type_map = EXT_MAP[self.enclosing_type] - for key, subvalue in value.items(): - if key in specific_type_map: - cls = specific_type_map[key] - if type(subvalue) is dict: - value[key] = cls(**subvalue) - elif type(subvalue) is cls: - value[key] = subvalue - else: - raise ValueError("Cannot determine extension type.") - else: - raise ValueError("The key used in the extensions dictionary is not an extension type name") - else: - raise ValueError("The enclosing type has no extensions defined") - else: + try: + dictified = get_dict(value) + except ValueError: raise ValueError("The extensions property must contain a dictionary") - return value + + from .__init__ import EXT_MAP # avoid circular import + if self.enclosing_type in EXT_MAP: + specific_type_map = EXT_MAP[self.enclosing_type] + for key, subvalue in dictified.items(): + if key in specific_type_map: + cls = specific_type_map[key] + if type(subvalue) is dict: + dictified[key] = cls(**subvalue) + elif type(subvalue) is cls: + dictified[key] = subvalue + else: + raise ValueError("Cannot determine extension type.") + else: + raise ValueError("The key used in the extensions dictionary is not an extension type name") + else: + raise ValueError("The enclosing type has no extensions defined") + return dictified diff --git a/stix2/test/test_properties.py b/stix2/test/test_properties.py index 4aa4973..5dc0084 100644 --- a/stix2/test/test_properties.py +++ b/stix2/test/test_properties.py @@ -1,13 +1,14 @@ import pytest +from stix2 import TCPExt from stix2.exceptions import DictionaryKeyError from stix2.observables import EmailMIMEComponent from stix2.properties import (BinaryProperty, BooleanProperty, DictionaryProperty, EmbeddedObjectProperty, - EnumProperty, HashesProperty, HexProperty, - IDProperty, IntegerProperty, ListProperty, - Property, ReferenceProperty, StringProperty, - TimestampProperty, TypeProperty) + EnumProperty, ExtensionsProperty, HashesProperty, + HexProperty, IDProperty, IntegerProperty, + ListProperty, Property, ReferenceProperty, + StringProperty, TimestampProperty, TypeProperty) from .constants import FAKE_TIME @@ -255,3 +256,36 @@ def test_enum_property(): with pytest.raises(ValueError): enum_prop.clean('z') + + +def test_extension_property_valid(): + ext_prop = ExtensionsProperty(enclosing_type='file') + assert ext_prop({ + 'windows-pebinary-ext': { + 'pe_type': 'exe' + }, + }) + + +@pytest.mark.parametrize("data", [ + 1, + {'foobar-ext': { + 'pe_type': 'exe' + }}, + {'windows-pebinary-ext': TCPExt()}, +]) +def test_extension_property_invalid(data): + ext_prop = ExtensionsProperty(enclosing_type='file') + with pytest.raises(ValueError): + ext_prop.clean(data) + + +def test_extension_property_invalid_type(): + ext_prop = ExtensionsProperty(enclosing_type='indicator') + with pytest.raises(ValueError) as excinfo: + ext_prop.clean({ + 'windows-pebinary-ext': { + 'pe_type': 'exe' + }} + ) + assert 'no extensions defined' in str(excinfo.value) diff --git a/stix2/test/test_utils.py b/stix2/test/test_utils.py index 3eee491..a70853c 100644 --- a/stix2/test/test_utils.py +++ b/stix2/test/test_utils.py @@ -1,4 +1,5 @@ import datetime as dt +from io import StringIO import pytest import pytz @@ -17,3 +18,24 @@ eastern = pytz.timezone('US/Eastern') ]) def test_timestamp_formatting(dttm, timestamp): assert stix2.utils.format_datetime(dttm) == timestamp + + +@pytest.mark.parametrize('data', [ + {"a": 1}, + '{"a": 1}', + StringIO(u'{"a": 1}'), + [("a", 1,)], +]) +def test_get_dict(data): + assert stix2.utils.get_dict(data) + + +@pytest.mark.parametrize('data', [ + 1, + [1], + ['a', 1], + "foobar", +]) +def test_get_dict_invalid(data): + with pytest.raises(ValueError): + stix2.utils.get_dict(data) diff --git a/stix2/utils.py b/stix2/utils.py index a2493e2..ed12cdf 100644 --- a/stix2/utils.py +++ b/stix2/utils.py @@ -63,11 +63,17 @@ def get_dict(data): """ if type(data) is dict: - obj = data + return data else: try: - obj = json.loads(data) + return json.loads(data) except TypeError: - obj = json.load(data) - - return obj + pass + try: + return json.load(data) + except AttributeError: + pass + try: + return dict(data) + except (ValueError, TypeError): + raise ValueError("Cannot convert '%s' to dictionary." % str(data))