couple of changes after merging against master

pull/1/head
Emmanuelle Vargas-Gonzalez 2021-02-19 10:05:56 -05:00
parent a8b6fa2100
commit 5067a3ff76
5 changed files with 28 additions and 53 deletions

View File

@ -4,8 +4,8 @@ import six
from .base import _cls_init
from .registration import (
_register_marking, _register_object, _register_observable,
_register_observable_extension,
_get_extension_class, _register_extension, _register_marking,
_register_object, _register_observable,
)

View File

@ -1,9 +1,6 @@
"""STIX2 Core parsing methods."""
import copy
import importlib
import pkgutil
import re
from . import registry
from .exceptions import ParseError

View File

@ -1,14 +1,13 @@
import re
from . import registry
from .base import _DomainObject, _Observable
from . import registry, version
from .base import _DomainObject
from .exceptions import DuplicateRegistrationError
from .properties import _validate_type
from .utils import PREFIX_21_REGEX, get_class_hierarchy_names
from .version import DEFAULT_VERSION
def _register_object(new_type, version=DEFAULT_VERSION):
def _register_object(new_type, version=version.DEFAULT_VERSION):
"""Register a custom STIX Object type.
Args:
@ -32,7 +31,7 @@ def _register_object(new_type, version=DEFAULT_VERSION):
properties = new_type._properties
if not version:
version = DEFAULT_VERSION
version = version.DEFAULT_VERSION
if version == "2.1":
for prop_name, prop in properties.items():
@ -45,7 +44,7 @@ def _register_object(new_type, version=DEFAULT_VERSION):
OBJ_MAP[new_type._type] = new_type
def _register_marking(new_marking, version=DEFAULT_VERSION):
def _register_marking(new_marking, version=version.DEFAULT_VERSION):
"""Register a custom STIX Marking Definition type.
Args:
@ -59,7 +58,7 @@ def _register_marking(new_marking, version=DEFAULT_VERSION):
properties = new_marking._properties
if not version:
version = DEFAULT_VERSION
version = version.DEFAULT_VERSION
_validate_type(mark_type, version)
@ -74,7 +73,7 @@ def _register_marking(new_marking, version=DEFAULT_VERSION):
OBJ_MAP_MARKING[mark_type] = new_marking
def _register_observable(new_observable, version=DEFAULT_VERSION):
def _register_observable(new_observable, version=version.DEFAULT_VERSION):
"""Register a custom STIX Cyber Observable type.
Args:
@ -86,7 +85,7 @@ def _register_observable(new_observable, version=DEFAULT_VERSION):
properties = new_observable._properties
if not version:
version = DEFAULT_VERSION
version = version.DEFAULT_VERSION
if version == "2.0":
# If using STIX2.0, check properties ending in "_ref/s" are ObjectReferenceProperties
@ -133,27 +132,25 @@ def _register_observable(new_observable, version=DEFAULT_VERSION):
OBJ_MAP_OBSERVABLE[new_observable._type] = new_observable
def _register_observable_extension(
observable, new_extension, version=DEFAULT_VERSION,
def _get_extension_class(extension_uuid, version):
"""Retrieve a registered class Extension"""
return registry.STIX2_OBJ_MAPS[version]['extensions'].get(extension_uuid)
def _register_extension(
new_extension, version=version.DEFAULT_VERSION,
):
"""Register a custom extension to a STIX Cyber Observable type.
"""Register a custom extension to any STIX Object type.
Args:
observable: An observable class or instance
new_extension (class): A class to register in the Observables
Extensions map.
new_extension (class): A class to register in the Extensions map.
version (str): Which STIX2 version to use. (e.g. "2.0", "2.1").
Defaults to the latest supported version.
"""
obs_class = observable if isinstance(observable, type) else \
type(observable)
ext_type = new_extension._type
properties = new_extension._properties
if not issubclass(obs_class, _Observable):
raise ValueError("'observable' must be a valid Observable class!")
_validate_type(ext_type, version)
if not new_extension._properties:
@ -163,37 +160,18 @@ def _register_observable_extension(
)
if version == "2.1":
if not ext_type.endswith('-ext'):
if not (ext_type.endswith('-ext') or ext_type.startswith('extension-definition--')):
raise ValueError(
"Invalid extension type name '%s': must end with '-ext'." %
"Invalid extension type name '%s': must end with '-ext' or start with 'extension-definition--<UUID>'." %
ext_type,
)
for prop_name, prop_value in properties.items():
for prop_name in properties.keys():
if not re.match(PREFIX_21_REGEX, prop_name):
raise ValueError("Property name '%s' must begin with an alpha character." % prop_name)
try:
observable_type = observable._type
except AttributeError:
raise ValueError(
"Unknown observable type. Custom observables must be "
"created with the @CustomObservable decorator.",
)
EXT_MAP = registry.STIX2_OBJ_MAPS[version]['extensions']
OBJ_MAP_OBSERVABLE = registry.STIX2_OBJ_MAPS[version]['observables']
EXT_MAP = registry.STIX2_OBJ_MAPS[version]['observable-extensions']
try:
if ext_type in EXT_MAP[observable_type].keys():
raise DuplicateRegistrationError("Observable Extension", ext_type)
EXT_MAP[observable_type][ext_type] = new_extension
except KeyError:
if observable_type not in OBJ_MAP_OBSERVABLE:
raise ValueError(
"Unknown observable type '%s'. Custom observables "
"must be created with the @CustomObservable decorator."
% observable_type,
)
else:
EXT_MAP[observable_type] = {ext_type: new_extension}
if ext_type in EXT_MAP:
raise DuplicateRegistrationError("Extension", ext_type)
EXT_MAP[ext_type] = new_extension

View File

@ -37,7 +37,7 @@ def _collect_stix2_mappings():
STIX2_OBJ_MAPS[ver] = {}
STIX2_OBJ_MAPS[ver]['objects'] = mod.OBJ_MAP
STIX2_OBJ_MAPS[ver]['observables'] = mod.OBJ_MAP_OBSERVABLE
STIX2_OBJ_MAPS[ver]['observable-extensions'] = mod.EXT_MAP
STIX2_OBJ_MAPS[ver]['extensions'] = mod.EXT_MAP
elif re.match(r'^stix2\.v2[0-9]\.common$', name) and is_pkg is False:
ver = _stix_vid_to_version(stix_vid)
mod = importlib.import_module(name, str(top_level_module.__name__))

View File

@ -40,7 +40,7 @@ class Bundle(_STIXBase21):
def get_obj(self, obj_uuid):
if "objects" in self._inner:
found_objs = [elem for elem in self.objects if elem['id'] == obj_uuid]
if found_objs == []:
if not found_objs:
raise KeyError("'%s' does not match the id property of any of the bundle's objects" % obj_uuid)
return found_objs
else: