diff --git a/stix2/custom.py b/stix2/custom.py index 5912dbf..0d83742 100644 --- a/stix2/custom.py +++ b/stix2/custom.py @@ -4,8 +4,8 @@ import six from .base import _cls_init from .registration import ( - _register_marking, _register_object, _register_observable, - _register_observable_extension, + _get_extension_class, _register_extension, _register_marking, + _register_object, _register_observable, ) diff --git a/stix2/parsing.py b/stix2/parsing.py index ea6ffd1..001f523 100644 --- a/stix2/parsing.py +++ b/stix2/parsing.py @@ -1,9 +1,6 @@ """STIX2 Core parsing methods.""" import copy -import importlib -import pkgutil -import re from . import registry from .exceptions import ParseError diff --git a/stix2/registration.py b/stix2/registration.py index 4ec019a..28d43ba 100644 --- a/stix2/registration.py +++ b/stix2/registration.py @@ -1,14 +1,13 @@ import re -from . import registry -from .base import _DomainObject, _Observable +from . import registry, version +from .base import _DomainObject from .exceptions import DuplicateRegistrationError from .properties import _validate_type from .utils import PREFIX_21_REGEX, get_class_hierarchy_names -from .version import DEFAULT_VERSION -def _register_object(new_type, version=DEFAULT_VERSION): +def _register_object(new_type, version=version.DEFAULT_VERSION): """Register a custom STIX Object type. Args: @@ -32,7 +31,7 @@ def _register_object(new_type, version=DEFAULT_VERSION): properties = new_type._properties if not version: - version = DEFAULT_VERSION + version = version.DEFAULT_VERSION if version == "2.1": for prop_name, prop in properties.items(): @@ -45,7 +44,7 @@ def _register_object(new_type, version=DEFAULT_VERSION): OBJ_MAP[new_type._type] = new_type -def _register_marking(new_marking, version=DEFAULT_VERSION): +def _register_marking(new_marking, version=version.DEFAULT_VERSION): """Register a custom STIX Marking Definition type. Args: @@ -59,7 +58,7 @@ def _register_marking(new_marking, version=DEFAULT_VERSION): properties = new_marking._properties if not version: - version = DEFAULT_VERSION + version = version.DEFAULT_VERSION _validate_type(mark_type, version) @@ -74,7 +73,7 @@ def _register_marking(new_marking, version=DEFAULT_VERSION): OBJ_MAP_MARKING[mark_type] = new_marking -def _register_observable(new_observable, version=DEFAULT_VERSION): +def _register_observable(new_observable, version=version.DEFAULT_VERSION): """Register a custom STIX Cyber Observable type. Args: @@ -86,7 +85,7 @@ def _register_observable(new_observable, version=DEFAULT_VERSION): properties = new_observable._properties if not version: - version = DEFAULT_VERSION + version = version.DEFAULT_VERSION if version == "2.0": # If using STIX2.0, check properties ending in "_ref/s" are ObjectReferenceProperties @@ -133,27 +132,25 @@ def _register_observable(new_observable, version=DEFAULT_VERSION): OBJ_MAP_OBSERVABLE[new_observable._type] = new_observable -def _register_observable_extension( - observable, new_extension, version=DEFAULT_VERSION, +def _get_extension_class(extension_uuid, version): + """Retrieve a registered class Extension""" + return registry.STIX2_OBJ_MAPS[version]['extensions'].get(extension_uuid) + + +def _register_extension( + new_extension, version=version.DEFAULT_VERSION, ): - """Register a custom extension to a STIX Cyber Observable type. + """Register a custom extension to any STIX Object type. Args: - observable: An observable class or instance - new_extension (class): A class to register in the Observables - Extensions map. + new_extension (class): A class to register in the Extensions map. version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). Defaults to the latest supported version. """ - obs_class = observable if isinstance(observable, type) else \ - type(observable) ext_type = new_extension._type properties = new_extension._properties - if not issubclass(obs_class, _Observable): - raise ValueError("'observable' must be a valid Observable class!") - _validate_type(ext_type, version) if not new_extension._properties: @@ -163,37 +160,18 @@ def _register_observable_extension( ) if version == "2.1": - if not ext_type.endswith('-ext'): + if not (ext_type.endswith('-ext') or ext_type.startswith('extension-definition--')): raise ValueError( - "Invalid extension type name '%s': must end with '-ext'." % + "Invalid extension type name '%s': must end with '-ext' or start with 'extension-definition--'." % ext_type, ) - for prop_name, prop_value in properties.items(): + for prop_name in properties.keys(): if not re.match(PREFIX_21_REGEX, prop_name): raise ValueError("Property name '%s' must begin with an alpha character." % prop_name) - try: - observable_type = observable._type - except AttributeError: - raise ValueError( - "Unknown observable type. Custom observables must be " - "created with the @CustomObservable decorator.", - ) + EXT_MAP = registry.STIX2_OBJ_MAPS[version]['extensions'] - OBJ_MAP_OBSERVABLE = registry.STIX2_OBJ_MAPS[version]['observables'] - EXT_MAP = registry.STIX2_OBJ_MAPS[version]['observable-extensions'] - - try: - if ext_type in EXT_MAP[observable_type].keys(): - raise DuplicateRegistrationError("Observable Extension", ext_type) - EXT_MAP[observable_type][ext_type] = new_extension - except KeyError: - if observable_type not in OBJ_MAP_OBSERVABLE: - raise ValueError( - "Unknown observable type '%s'. Custom observables " - "must be created with the @CustomObservable decorator." - % observable_type, - ) - else: - EXT_MAP[observable_type] = {ext_type: new_extension} + if ext_type in EXT_MAP: + raise DuplicateRegistrationError("Extension", ext_type) + EXT_MAP[ext_type] = new_extension diff --git a/stix2/registry.py b/stix2/registry.py index 3dcc3a5..90e2826 100644 --- a/stix2/registry.py +++ b/stix2/registry.py @@ -37,7 +37,7 @@ def _collect_stix2_mappings(): STIX2_OBJ_MAPS[ver] = {} STIX2_OBJ_MAPS[ver]['objects'] = mod.OBJ_MAP STIX2_OBJ_MAPS[ver]['observables'] = mod.OBJ_MAP_OBSERVABLE - STIX2_OBJ_MAPS[ver]['observable-extensions'] = mod.EXT_MAP + STIX2_OBJ_MAPS[ver]['extensions'] = mod.EXT_MAP elif re.match(r'^stix2\.v2[0-9]\.common$', name) and is_pkg is False: ver = _stix_vid_to_version(stix_vid) mod = importlib.import_module(name, str(top_level_module.__name__)) diff --git a/stix2/v21/bundle.py b/stix2/v21/bundle.py index 5497da5..990dfc1 100644 --- a/stix2/v21/bundle.py +++ b/stix2/v21/bundle.py @@ -40,7 +40,7 @@ class Bundle(_STIXBase21): def get_obj(self, obj_uuid): if "objects" in self._inner: found_objs = [elem for elem in self.objects if elem['id'] == obj_uuid] - if found_objs == []: + if not found_objs: raise KeyError("'%s' does not match the id property of any of the bundle's objects" % obj_uuid) return found_objs else: