Rework select properties to use get_dict(),
which automatically coerces values to a dictionary if possiblestix2.1
parent
aa69c38444
commit
2460fb75be
|
@ -12,6 +12,7 @@ from six import text_type
|
|||
|
||||
from .base import _Observable, _STIXBase
|
||||
from .exceptions import DictionaryKeyError
|
||||
from .utils import get_dict
|
||||
|
||||
|
||||
class Property(object):
|
||||
|
@ -181,7 +182,7 @@ class FloatProperty(Property):
|
|||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
raise ValueError("must be an float.")
|
||||
raise ValueError("must be a float.")
|
||||
|
||||
|
||||
class BooleanProperty(Property):
|
||||
|
@ -233,7 +234,11 @@ class TimestampProperty(Property):
|
|||
class ObservableProperty(Property):
|
||||
|
||||
def clean(self, value):
|
||||
dictified = dict(value)
|
||||
try:
|
||||
dictified = get_dict(value)
|
||||
except ValueError:
|
||||
raise ValueError("The observable property must contain a dictionary")
|
||||
|
||||
from .__init__ import parse_observable # avoid circular import
|
||||
for key, obj in dictified.items():
|
||||
parsed_obj = parse_observable(obj, dictified.keys())
|
||||
|
@ -248,7 +253,11 @@ class ObservableProperty(Property):
|
|||
class DictionaryProperty(Property):
|
||||
|
||||
def clean(self, value):
|
||||
dictified = dict(value)
|
||||
try:
|
||||
dictified = get_dict(value)
|
||||
except ValueError:
|
||||
raise ValueError("The dictionary property must contain a dictionary")
|
||||
|
||||
for k in dictified.keys():
|
||||
if len(k) < 3:
|
||||
raise DictionaryKeyError(k, "shorter than 3 characters")
|
||||
|
@ -392,23 +401,25 @@ class ExtensionsProperty(DictionaryProperty):
|
|||
super(ExtensionsProperty, self).__init__(required)
|
||||
|
||||
def clean(self, value):
|
||||
if type(value) is dict:
|
||||
from .__init__ import EXT_MAP # avoid circular import
|
||||
if self.enclosing_type in EXT_MAP:
|
||||
specific_type_map = EXT_MAP[self.enclosing_type]
|
||||
for key, subvalue in value.items():
|
||||
if key in specific_type_map:
|
||||
cls = specific_type_map[key]
|
||||
if type(subvalue) is dict:
|
||||
value[key] = cls(**subvalue)
|
||||
elif type(subvalue) is cls:
|
||||
value[key] = subvalue
|
||||
else:
|
||||
raise ValueError("Cannot determine extension type.")
|
||||
else:
|
||||
raise ValueError("The key used in the extensions dictionary is not an extension type name")
|
||||
else:
|
||||
raise ValueError("The enclosing type has no extensions defined")
|
||||
else:
|
||||
try:
|
||||
dictified = get_dict(value)
|
||||
except ValueError:
|
||||
raise ValueError("The extensions property must contain a dictionary")
|
||||
return value
|
||||
|
||||
from .__init__ import EXT_MAP # avoid circular import
|
||||
if self.enclosing_type in EXT_MAP:
|
||||
specific_type_map = EXT_MAP[self.enclosing_type]
|
||||
for key, subvalue in dictified.items():
|
||||
if key in specific_type_map:
|
||||
cls = specific_type_map[key]
|
||||
if type(subvalue) is dict:
|
||||
dictified[key] = cls(**subvalue)
|
||||
elif type(subvalue) is cls:
|
||||
dictified[key] = subvalue
|
||||
else:
|
||||
raise ValueError("Cannot determine extension type.")
|
||||
else:
|
||||
raise ValueError("The key used in the extensions dictionary is not an extension type name")
|
||||
else:
|
||||
raise ValueError("The enclosing type has no extensions defined")
|
||||
return dictified
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import pytest
|
||||
|
||||
from stix2 import TCPExt
|
||||
from stix2.exceptions import DictionaryKeyError
|
||||
from stix2.observables import EmailMIMEComponent
|
||||
from stix2.properties import (BinaryProperty, BooleanProperty,
|
||||
DictionaryProperty, EmbeddedObjectProperty,
|
||||
EnumProperty, HashesProperty, HexProperty,
|
||||
IDProperty, IntegerProperty, ListProperty,
|
||||
Property, ReferenceProperty, StringProperty,
|
||||
TimestampProperty, TypeProperty)
|
||||
EnumProperty, ExtensionsProperty, HashesProperty,
|
||||
HexProperty, IDProperty, IntegerProperty,
|
||||
ListProperty, Property, ReferenceProperty,
|
||||
StringProperty, TimestampProperty, TypeProperty)
|
||||
|
||||
from .constants import FAKE_TIME
|
||||
|
||||
|
@ -255,3 +256,36 @@ def test_enum_property():
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
enum_prop.clean('z')
|
||||
|
||||
|
||||
def test_extension_property_valid():
|
||||
ext_prop = ExtensionsProperty(enclosing_type='file')
|
||||
assert ext_prop({
|
||||
'windows-pebinary-ext': {
|
||||
'pe_type': 'exe'
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("data", [
|
||||
1,
|
||||
{'foobar-ext': {
|
||||
'pe_type': 'exe'
|
||||
}},
|
||||
{'windows-pebinary-ext': TCPExt()},
|
||||
])
|
||||
def test_extension_property_invalid(data):
|
||||
ext_prop = ExtensionsProperty(enclosing_type='file')
|
||||
with pytest.raises(ValueError):
|
||||
ext_prop.clean(data)
|
||||
|
||||
|
||||
def test_extension_property_invalid_type():
|
||||
ext_prop = ExtensionsProperty(enclosing_type='indicator')
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
ext_prop.clean({
|
||||
'windows-pebinary-ext': {
|
||||
'pe_type': 'exe'
|
||||
}}
|
||||
)
|
||||
assert 'no extensions defined' in str(excinfo.value)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import datetime as dt
|
||||
from io import StringIO
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
|
@ -17,3 +18,24 @@ eastern = pytz.timezone('US/Eastern')
|
|||
])
|
||||
def test_timestamp_formatting(dttm, timestamp):
|
||||
assert stix2.utils.format_datetime(dttm) == timestamp
|
||||
|
||||
|
||||
@pytest.mark.parametrize('data', [
|
||||
{"a": 1},
|
||||
'{"a": 1}',
|
||||
StringIO(u'{"a": 1}'),
|
||||
[("a", 1,)],
|
||||
])
|
||||
def test_get_dict(data):
|
||||
assert stix2.utils.get_dict(data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('data', [
|
||||
1,
|
||||
[1],
|
||||
['a', 1],
|
||||
"foobar",
|
||||
])
|
||||
def test_get_dict_invalid(data):
|
||||
with pytest.raises(ValueError):
|
||||
stix2.utils.get_dict(data)
|
||||
|
|
|
@ -63,11 +63,17 @@ def get_dict(data):
|
|||
"""
|
||||
|
||||
if type(data) is dict:
|
||||
obj = data
|
||||
return data
|
||||
else:
|
||||
try:
|
||||
obj = json.loads(data)
|
||||
return json.loads(data)
|
||||
except TypeError:
|
||||
obj = json.load(data)
|
||||
|
||||
return obj
|
||||
pass
|
||||
try:
|
||||
return json.load(data)
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
return dict(data)
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError("Cannot convert '%s' to dictionary." % str(data))
|
||||
|
|
Loading…
Reference in New Issue