Refactor common ID check.

stix2.1
Greg Back 2017-02-01 13:44:57 -06:00
parent b5ab54b6a9
commit b4eb6c1fd1
1 changed files with 16 additions and 28 deletions

View File

@ -21,17 +21,29 @@ def format_datetime(dt):
class _STIXBase(collections.Mapping):
"""Base class for STIX object types"""
def _check_kwargs(self, **kwargs):
class_name = self.__class__.__name__
@classmethod
def _make_id(cls):
return cls._type + "--" + str(uuid.uuid4())
@classmethod
def _check_kwargs(cls, **kwargs):
class_name = cls.__name__
# Ensure that, if provided, the 'type' kwarg is correct.
required_type = self.__class__._type
required_type = cls._type
if not kwargs.get('type'):
kwargs['type'] = required_type
if kwargs['type'] != required_type:
msg = "{0} must have type='{1}'."
raise ValueError(msg.format(class_name, required_type))
id_prefix = cls._type + "--"
if not kwargs.get('id'):
kwargs['id'] = cls._make_id()
if not kwargs['id'].startswith(id_prefix):
msg = "{0} id values must begin with '{1}'."
raise ValueError(msg.format(class_name, id_prefix))
return kwargs
def __init__(self, **kwargs):
@ -84,11 +96,6 @@ class Bundle(_STIXBase):
# TODO: remove once we check all the fields in the right order
kwargs = self._check_kwargs(**kwargs)
if not kwargs.get('id'):
kwargs['id'] = 'bundle--' + str(uuid.uuid4())
if not kwargs['id'].startswith('bundle--'):
raise ValueError("Bundle id values must begin with 'bundle--'.")
if not kwargs.get('spec_version'):
kwargs['spec_version'] = '2.0'
if kwargs['spec_version'] != '2.0':
@ -142,11 +149,6 @@ class Indicator(_STIXBase):
# TODO: remove once we check all the fields in the right order
kwargs = self._check_kwargs(**kwargs)
if not kwargs.get('id'):
kwargs['id'] = 'indicator--' + str(uuid.uuid4())
if not kwargs['id'].startswith('indicator--'):
raise ValueError("Indicator id values must begin with 'indicator--'.")
if not kwargs.get('labels'):
raise ValueError("Missing required field for Indicator: 'labels'.")
@ -202,16 +204,6 @@ class Malware(_STIXBase):
# TODO: remove once we check all the fields in the right order
kwargs = self._check_kwargs(**kwargs)
if not kwargs.get('type'):
kwargs['type'] = 'malware'
if kwargs['type'] != 'malware':
raise ValueError("Malware must have type='malware'.")
if not kwargs.get('id'):
kwargs['id'] = 'malware--' + str(uuid.uuid4())
if not kwargs['id'].startswith('malware--'):
raise ValueError("Malware id values must begin with 'malware--'.")
if not kwargs.get('labels'):
raise ValueError("Missing required field for Malware: 'labels'.")
@ -274,11 +266,6 @@ class Relationship(_STIXBase):
# if we won't need it?
now = datetime.datetime.now(tz=pytz.UTC)
if not kwargs.get('type'):
kwargs['type'] = 'relationship'
if kwargs['type'] != 'relationship':
raise ValueError("Relationship must have type='relationship'.")
if not kwargs.get('id'):
kwargs['id'] = 'relationship--' + str(uuid.uuid4())
if not kwargs['id'].startswith('relationship--'):
@ -301,6 +288,7 @@ class Relationship(_STIXBase):
'created': kwargs.get('created', now),
'modified': kwargs.get('modified', now),
'relationship_type': kwargs['relationship_type'],
'target_ref': kwargs['target_ref'],
})
super(Relationship, self).__init__(**kwargs)