From 8843e9b19039bf525f5b5d4e3c9b6d5205b7286c Mon Sep 17 00:00:00 2001 From: Greg Back Date: Wed, 1 Feb 2017 13:27:24 -0600 Subject: [PATCH] WIP: refactor common fields. --- stix2/__init__.py | 84 +++++++++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/stix2/__init__.py b/stix2/__init__.py index da775a9..26e3c1f 100644 --- a/stix2/__init__.py +++ b/stix2/__init__.py @@ -21,6 +21,31 @@ def format_datetime(dt): class _STIXBase(collections.Mapping): """Base class for STIX object types""" + def _check_kwargs(self, **kwargs): + class_name = self.__class__.__name__ + + # Ensure that, if provided, the 'type' kwarg is correct. + required_type = self.__class__._type + if not kwargs.get('type'): + kwargs['type'] = required_type + if kwargs['type'] != required_type: + msg = "{0} must have type='{1}'." + raise ValueError(msg.format(class_name, required_type)) + + return kwargs + + def __init__(self, **kwargs): + # Detect any keyword arguments not allowed for a specific type + extra_kwargs = list(set(kwargs) - set(self.__class__._properties)) + if extra_kwargs: + raise TypeError("unexpected keyword arguments: " + str(extra_kwargs)) + + # TODO: move all of this back into init, once we check the right things + # in the right order. + self._check_kwargs(**kwargs) + + self._inner = kwargs + def __getitem__(self, key): return self._inner[key] @@ -47,11 +72,15 @@ class _STIXBase(collections.Mapping): class Bundle(_STIXBase): + _type = 'bundle' + _properties = [ + 'type', + 'id', + 'spec_version', + 'objects', + ] + def __init__(self, type="bundle", id=None, spec_version="2.0", objects=None): - - if type != 'bundle': - raise ValueError("Bundle must have type='bundle'.") - id = id or 'bundle--' + str(uuid.uuid4()) if not id.startswith('bundle--'): raise ValueError("Bundle id values must begin with 'bundle--'.") @@ -61,12 +90,13 @@ class Bundle(_STIXBase): objects = objects or [] - self._inner = { + kwargs = { 'type': type, 'id': id, 'spec_version': spec_version, 'objects': objects, } + super(Bundle, self).__init__(**kwargs) def _dict(self): bundle = { @@ -83,6 +113,7 @@ class Bundle(_STIXBase): class Indicator(_STIXBase): + _type = 'indicator' _properties = [ 'type', 'id', @@ -110,10 +141,8 @@ class Indicator(_STIXBase): # if we won't need it? now = datetime.datetime.now(tz=pytz.UTC) - if not kwargs.get('type'): - kwargs['type'] = 'indicator' - if kwargs['type'] != 'indicator': - raise ValueError("Indicator must have type='indicator'.") + # TODO: remove once we check all the fields in the right order + kwargs = self._check_kwargs(**kwargs) if not kwargs.get('id'): kwargs['id'] = 'indicator--' + str(uuid.uuid4()) @@ -126,19 +155,16 @@ class Indicator(_STIXBase): if not kwargs.get('pattern'): raise ValueError("Missing required field for Indicator: 'pattern'.") - extra_kwargs = list(set(kwargs.keys()) - set(self._properties)) - if extra_kwargs: - raise TypeError("unexpected keyword arguments: " + str(extra_kwargs)) - - self._inner = { - 'type': kwargs['type'], + kwargs.update({ + # 'type': kwargs['type'], 'id': kwargs['id'], 'created': kwargs.get('created', now), 'modified': kwargs.get('modified', now), 'labels': kwargs['labels'], 'pattern': kwargs['pattern'], 'valid_from': kwargs.get('valid_from', now), - } + }) + super(Indicator, self).__init__(**kwargs) def _dict(self): return { @@ -154,6 +180,7 @@ class Indicator(_STIXBase): class Malware(_STIXBase): + _type = 'malware' _properties = [ 'type', 'id', @@ -178,6 +205,9 @@ class Malware(_STIXBase): # if we won't need it? now = datetime.datetime.now(tz=pytz.UTC) + # TODO: remove once we check all the fields in the right order + kwargs = self._check_kwargs(**kwargs) + if not kwargs.get('type'): kwargs['type'] = 'malware' if kwargs['type'] != 'malware': @@ -194,18 +224,15 @@ class Malware(_STIXBase): if not kwargs.get('name'): raise ValueError("Missing required field for Malware: 'name'.") - extra_kwargs = list(set(kwargs.keys()) - set(self._properties)) - if extra_kwargs: - raise TypeError("unexpected keyword arguments: " + str(extra_kwargs)) - - self._inner = { + kwargs.update({ 'type': kwargs['type'], 'id': kwargs['id'], 'created': kwargs.get('created', now), 'modified': kwargs.get('modified', now), 'labels': kwargs['labels'], 'name': kwargs['name'], - } + }) + super(Malware, self).__init__(**kwargs) def _dict(self): return { @@ -220,6 +247,7 @@ class Malware(_STIXBase): class Relationship(_STIXBase): + _type = 'relationship' _properties = [ 'type', 'id', @@ -242,6 +270,9 @@ class Relationship(_STIXBase): # - description + # TODO: remove once we check all the fields in the right order + kwargs = self._check_kwargs(**kwargs) + if source_ref and not kwargs.get('source_ref'): kwargs['source_ref'] = source_ref if relationship_type and not kwargs.get('relationship_type'): @@ -276,11 +307,7 @@ class Relationship(_STIXBase): elif isinstance(kwargs['target_ref'], _STIXBase): kwargs['target_ref'] = kwargs['target_ref'].id - extra_kwargs = list(set(kwargs.keys()) - set(self._properties)) - if extra_kwargs: - raise TypeError("unexpected keyword arguments: " + str(extra_kwargs)) - - self._inner = { + kwargs.update({ 'type': kwargs['type'], 'id': kwargs['id'], 'created': kwargs.get('created', now), @@ -288,7 +315,8 @@ class Relationship(_STIXBase): 'relationship_type': kwargs['relationship_type'], 'source_ref': kwargs['source_ref'], 'target_ref': kwargs['target_ref'], - } + }) + super(Relationship, self).__init__(**kwargs) def _dict(self): return {