Rework select properties to use get_dict(),

which automatically coerces values to a dictionary if possible
stix2.1
clenk 2017-05-16 09:25:08 -04:00
parent aa69c38444
commit 2460fb75be
4 changed files with 104 additions and 31 deletions

View File

@ -12,6 +12,7 @@ from six import text_type
from .base import _Observable, _STIXBase from .base import _Observable, _STIXBase
from .exceptions import DictionaryKeyError from .exceptions import DictionaryKeyError
from .utils import get_dict
class Property(object): class Property(object):
@ -181,7 +182,7 @@ class FloatProperty(Property):
try: try:
return float(value) return float(value)
except Exception: except Exception:
raise ValueError("must be an float.") raise ValueError("must be a float.")
class BooleanProperty(Property): class BooleanProperty(Property):
@ -233,7 +234,11 @@ class TimestampProperty(Property):
class ObservableProperty(Property): class ObservableProperty(Property):
def clean(self, value): def clean(self, value):
dictified = dict(value) try:
dictified = get_dict(value)
except ValueError:
raise ValueError("The observable property must contain a dictionary")
from .__init__ import parse_observable # avoid circular import from .__init__ import parse_observable # avoid circular import
for key, obj in dictified.items(): for key, obj in dictified.items():
parsed_obj = parse_observable(obj, dictified.keys()) parsed_obj = parse_observable(obj, dictified.keys())
@ -248,7 +253,11 @@ class ObservableProperty(Property):
class DictionaryProperty(Property): class DictionaryProperty(Property):
def clean(self, value): def clean(self, value):
dictified = dict(value) try:
dictified = get_dict(value)
except ValueError:
raise ValueError("The dictionary property must contain a dictionary")
for k in dictified.keys(): for k in dictified.keys():
if len(k) < 3: if len(k) < 3:
raise DictionaryKeyError(k, "shorter than 3 characters") raise DictionaryKeyError(k, "shorter than 3 characters")
@ -392,23 +401,25 @@ class ExtensionsProperty(DictionaryProperty):
super(ExtensionsProperty, self).__init__(required) super(ExtensionsProperty, self).__init__(required)
def clean(self, value): def clean(self, value):
if type(value) is dict: try:
from .__init__ import EXT_MAP # avoid circular import dictified = get_dict(value)
if self.enclosing_type in EXT_MAP: except ValueError:
specific_type_map = EXT_MAP[self.enclosing_type]
for key, subvalue in value.items():
if key in specific_type_map:
cls = specific_type_map[key]
if type(subvalue) is dict:
value[key] = cls(**subvalue)
elif type(subvalue) is cls:
value[key] = subvalue
else:
raise ValueError("Cannot determine extension type.")
else:
raise ValueError("The key used in the extensions dictionary is not an extension type name")
else:
raise ValueError("The enclosing type has no extensions defined")
else:
raise ValueError("The extensions property must contain a dictionary") raise ValueError("The extensions property must contain a dictionary")
return value
from .__init__ import EXT_MAP # avoid circular import
if self.enclosing_type in EXT_MAP:
specific_type_map = EXT_MAP[self.enclosing_type]
for key, subvalue in dictified.items():
if key in specific_type_map:
cls = specific_type_map[key]
if type(subvalue) is dict:
dictified[key] = cls(**subvalue)
elif type(subvalue) is cls:
dictified[key] = subvalue
else:
raise ValueError("Cannot determine extension type.")
else:
raise ValueError("The key used in the extensions dictionary is not an extension type name")
else:
raise ValueError("The enclosing type has no extensions defined")
return dictified

View File

@ -1,13 +1,14 @@
import pytest import pytest
from stix2 import TCPExt
from stix2.exceptions import DictionaryKeyError from stix2.exceptions import DictionaryKeyError
from stix2.observables import EmailMIMEComponent from stix2.observables import EmailMIMEComponent
from stix2.properties import (BinaryProperty, BooleanProperty, from stix2.properties import (BinaryProperty, BooleanProperty,
DictionaryProperty, EmbeddedObjectProperty, DictionaryProperty, EmbeddedObjectProperty,
EnumProperty, HashesProperty, HexProperty, EnumProperty, ExtensionsProperty, HashesProperty,
IDProperty, IntegerProperty, ListProperty, HexProperty, IDProperty, IntegerProperty,
Property, ReferenceProperty, StringProperty, ListProperty, Property, ReferenceProperty,
TimestampProperty, TypeProperty) StringProperty, TimestampProperty, TypeProperty)
from .constants import FAKE_TIME from .constants import FAKE_TIME
@ -255,3 +256,36 @@ def test_enum_property():
with pytest.raises(ValueError): with pytest.raises(ValueError):
enum_prop.clean('z') enum_prop.clean('z')
def test_extension_property_valid():
ext_prop = ExtensionsProperty(enclosing_type='file')
assert ext_prop({
'windows-pebinary-ext': {
'pe_type': 'exe'
},
})
@pytest.mark.parametrize("data", [
1,
{'foobar-ext': {
'pe_type': 'exe'
}},
{'windows-pebinary-ext': TCPExt()},
])
def test_extension_property_invalid(data):
ext_prop = ExtensionsProperty(enclosing_type='file')
with pytest.raises(ValueError):
ext_prop.clean(data)
def test_extension_property_invalid_type():
ext_prop = ExtensionsProperty(enclosing_type='indicator')
with pytest.raises(ValueError) as excinfo:
ext_prop.clean({
'windows-pebinary-ext': {
'pe_type': 'exe'
}}
)
assert 'no extensions defined' in str(excinfo.value)

View File

@ -1,4 +1,5 @@
import datetime as dt import datetime as dt
from io import StringIO
import pytest import pytest
import pytz import pytz
@ -17,3 +18,24 @@ eastern = pytz.timezone('US/Eastern')
]) ])
def test_timestamp_formatting(dttm, timestamp): def test_timestamp_formatting(dttm, timestamp):
assert stix2.utils.format_datetime(dttm) == timestamp assert stix2.utils.format_datetime(dttm) == timestamp
@pytest.mark.parametrize('data', [
{"a": 1},
'{"a": 1}',
StringIO(u'{"a": 1}'),
[("a", 1,)],
])
def test_get_dict(data):
assert stix2.utils.get_dict(data)
@pytest.mark.parametrize('data', [
1,
[1],
['a', 1],
"foobar",
])
def test_get_dict_invalid(data):
with pytest.raises(ValueError):
stix2.utils.get_dict(data)

View File

@ -63,11 +63,17 @@ def get_dict(data):
""" """
if type(data) is dict: if type(data) is dict:
obj = data return data
else: else:
try: try:
obj = json.loads(data) return json.loads(data)
except TypeError: except TypeError:
obj = json.load(data) pass
try:
return obj return json.load(data)
except AttributeError:
pass
try:
return dict(data)
except (ValueError, TypeError):
raise ValueError("Cannot convert '%s' to dictionary." % str(data))