diff --git a/stix2/__init__.py b/stix2/__init__.py index 9b7e1c3..ce974b6 100644 --- a/stix2/__init__.py +++ b/stix2/__init__.py @@ -6,6 +6,18 @@ import uuid import pytz +COMMON_PROPERTIES = { + 'type': { + 'default': (lambda x: x._type), + 'validate': (lambda x, val: val == x._type), + 'error_msg': "{type} must have {field}='{expected}'.", + }, + 'id': {}, + 'created': {}, + 'modified': {}, +} + + def format_datetime(dt): # TODO: how to handle naive datetime @@ -29,13 +41,18 @@ class _STIXBase(collections.Mapping): def _check_kwargs(cls, **kwargs): class_name = cls.__name__ - # Ensure that, if provided, the 'type' kwarg is correct. - required_type = cls._type - if not kwargs.get('type'): - kwargs['type'] = required_type - if kwargs['type'] != required_type: - msg = "{0} must have type='{1}'." - raise ValueError(msg.format(class_name, required_type)) + for prop_name, prop_metadata in cls._properties.items(): + if prop_metadata.get('default') and prop_name not in kwargs: + kwargs[prop_name] = prop_metadata['default'](cls) + + if prop_metadata.get('validate'): + if not prop_metadata['validate'](cls, kwargs[prop_name]): + msg = prop_metadata['error_msg'].format( + type=class_name, + field=prop_name, + expected=prop_metadata.get('default')(cls), + ) + raise ValueError(msg) id_prefix = cls._type + "--" if not kwargs.get('id'): @@ -85,12 +102,16 @@ class _STIXBase(collections.Mapping): class Bundle(_STIXBase): _type = 'bundle' - _properties = [ - 'type', - 'id', - 'spec_version', - 'objects', - ] + _properties = { + 'type': { + 'default': (lambda x: x._type), + 'validate': (lambda x, val: val == x._type), + 'error_msg': "{type} must have {field}='{expected}'.", + }, + 'id': {}, + 'spec_version': {}, + 'objects': {}, + } def __init__(self, **kwargs): # TODO: remove once we check all the fields in the right order @@ -119,15 +140,12 @@ class Bundle(_STIXBase): class Indicator(_STIXBase): _type = 'indicator' - _properties = [ - 'type', - 'id', - 'created', - 'modified', - 'labels', - 'pattern', - 'valid_from', - ] + _properties = COMMON_PROPERTIES.copy() + _properties.update({ + 'labels': {}, + 'pattern': {}, + 'valid_from': {}, + }) def __init__(self, **kwargs): # TODO: @@ -177,14 +195,11 @@ class Indicator(_STIXBase): class Malware(_STIXBase): _type = 'malware' - _properties = [ - 'type', - 'id', - 'created', - 'modified', - 'labels', - 'name', - ] + _properties = COMMON_PROPERTIES.copy() + _properties.update({ + 'labels': {}, + 'name': {}, + }) def __init__(self, **kwargs): # TODO: @@ -230,15 +245,12 @@ class Malware(_STIXBase): class Relationship(_STIXBase): _type = 'relationship' - _properties = [ - 'type', - 'id', - 'created', - 'modified', - 'relationship_type', - 'source_ref', - 'target_ref', - ] + _properties = COMMON_PROPERTIES.copy() + _properties.update({ + 'relationship_type': {}, + 'source_ref': {}, + 'target_ref': {}, + }) # Explicitly define the first three kwargs to make readable Relationship declarations. def __init__(self, source_ref=None, relationship_type=None, target_ref=None,