diff --git a/stix2/properties.py b/stix2/properties.py index 55c5625..ba308c8 100644 --- a/stix2/properties.py +++ b/stix2/properties.py @@ -14,7 +14,10 @@ from .base import _STIXBase from .exceptions import CustomContentError, DictionaryKeyError, STIXError from .parsing import parse, parse_observable from .registry import STIX2_OBJ_MAPS -from .utils import _get_dict, get_class_hierarchy_names, parse_into_datetime +from .utils import ( + STIXTypeClass, _get_dict, get_class_hierarchy_names, get_type_from_id, + is_object, is_stix_type, parse_into_datetime, to_enum, +) from .version import DEFAULT_VERSION try: @@ -501,7 +504,6 @@ class HexProperty(Property): class ReferenceProperty(Property): - _OBJECT_CATEGORIES = {"SDO", "SCO", "SRO"} _WHITELIST, _BLACKLIST = range(2) def __init__(self, valid_types=None, invalid_types=None, spec_version=DEFAULT_VERSION, **kwargs): @@ -525,9 +527,22 @@ class ReferenceProperty(Property): if valid_types is not None and len(valid_types) == 0: raise ValueError("Impossible type constraint: empty whitelist") - self.types = set(valid_types or invalid_types) self.auth_type = self._WHITELIST if valid_types else self._BLACKLIST + # Divide type requirements into generic type classes and specific + # types. With respect to strings, values recognized as STIXTypeClass + # enum names are generic; all else are specifics. + self.generics = set() + self.specifics = set() + types = valid_types or invalid_types + for type_ in types: + try: + enum_value = to_enum(type_, STIXTypeClass) + except KeyError: + self.specifics.add(type_) + else: + self.generics.add(enum_value) + super(ReferenceProperty, self).__init__(**kwargs) def clean(self, value, allow_custom): @@ -537,7 +552,7 @@ class ReferenceProperty(Property): _validate_id(value, self.spec_version, None) - obj_type = value[:value.index('--')] + obj_type = get_type_from_id(value) # Only comes into play when inverting a hybrid whitelist. # E.g. if the possible generic categories are A, B, C, then the @@ -548,8 +563,8 @@ class ReferenceProperty(Property): # blacklist. blacklist_exceptions = set() - generics = self.types & self._OBJECT_CATEGORIES - specifics = self.types - generics + generics = self.generics + specifics = self.specifics auth_type = self.auth_type if allow_custom and auth_type == self._WHITELIST and generics: # If allowing customization and using a whitelist, and if generic @@ -560,20 +575,19 @@ class ReferenceProperty(Property): # in the wrong category. I.e. flip the whitelist set to a # blacklist of a complementary set. auth_type = self._BLACKLIST - generics = self._OBJECT_CATEGORIES - generics + generics = set(STIXTypeClass) - generics blacklist_exceptions, specifics = specifics, blacklist_exceptions if auth_type == self._WHITELIST: - type_ok = _type_in_generic_set( - obj_type, generics, self.spec_version + type_ok = is_stix_type( + obj_type, self.spec_version, *generics ) or obj_type in specifics else: type_ok = ( - not _type_in_generic_set( - obj_type, generics, self.spec_version, - ) - and obj_type not in specifics + not is_stix_type( + obj_type, self.spec_version, *generics + ) and obj_type not in specifics ) or obj_type in blacklist_exceptions if not type_ok: @@ -585,9 +599,8 @@ class ReferenceProperty(Property): # We need to figure out whether the referenced object is custom or # not. No good way to do that at present... just check if # unregistered and for the "x-" type prefix, for now? - has_custom = not _type_in_generic_set( - obj_type, self._OBJECT_CATEGORIES, self.spec_version, - ) or obj_type.startswith("x-") + has_custom = not is_object(obj_type, self.spec_version) \ + or obj_type.startswith("x-") if not allow_custom and has_custom: raise CustomContentError( @@ -597,34 +610,6 @@ class ReferenceProperty(Property): return value, has_custom -def _type_in_generic_set(type_, type_set, spec_version): - """ - Determine if type_ is in the given set, with respect to the given STIX - version. This handles special generic category values "SDO", "SCO", - "SRO", so it's not a simple set containment check. The type_set is - implicitly "OR"d. - """ - type_maps = STIX2_OBJ_MAPS[spec_version] - - result = False - for type_id in type_set: - if type_id == "SDO": - result = type_ in type_maps["objects"] and type_ not in [ - "relationship", "sighting", - ] # sigh - elif type_id == "SCO": - result = type_ in type_maps["observables"] - elif type_id == "SRO": - result = type_ in ["relationship", "sighting"] - else: - raise ValueError("Unrecognized generic type category: " + type_id) - - if result: - break - - return result - - SELECTOR_REGEX = re.compile(r"^([a-z0-9_-]{3,250}(\.(\[\d+\]|[a-z0-9_-]{1,250}))*|id)$") diff --git a/stix2/test/v21/test_location.py b/stix2/test/v21/test_location.py index e912fba..9f42255 100644 --- a/stix2/test/v21/test_location.py +++ b/stix2/test/v21/test_location.py @@ -47,7 +47,8 @@ EXPECTED_LOCATION_2_REPR = "Location(" + " ".join( id='location--a6e9345f-5a15-4c29-8bb3-7dcc5d168d64', created='2016-04-06T20:03:00.000Z', modified='2016-04-06T20:03:00.000Z', - region='northern-america'""".split()) + ")" + region='northern-america'""".split() +) + ")" def test_location_with_some_required_properties(): diff --git a/stix2/test/v21/test_timestamp_precision.py b/stix2/test/v21/test_timestamp_precision.py index 8cb9735..831bd7a 100644 --- a/stix2/test/v21/test_timestamp_precision.py +++ b/stix2/test/v21/test_timestamp_precision.py @@ -5,8 +5,8 @@ import pytest import stix2 from stix2.utils import ( - Precision, PrecisionConstraint, STIXdatetime, _to_enum, format_datetime, - parse_into_datetime, + Precision, PrecisionConstraint, STIXdatetime, format_datetime, + parse_into_datetime, to_enum, ) _DT = datetime.datetime.utcnow() @@ -27,7 +27,7 @@ _DT_STR = _DT.strftime("%Y-%m-%dT%H:%M:%S") ], ) def test_to_enum(value, enum_type, enum_default, enum_expected): - result = _to_enum(value, enum_type, enum_default) + result = to_enum(value, enum_type, enum_default) assert result == enum_expected @@ -41,7 +41,7 @@ def test_to_enum(value, enum_type, enum_default, enum_expected): ) def test_to_enum_errors(value, err_type): with pytest.raises(err_type): - _to_enum(value, Precision) + to_enum(value, Precision) @pytest.mark.xfail( diff --git a/stix2/utils.py b/stix2/utils.py index 08e272d..647a89f 100644 --- a/stix2/utils.py +++ b/stix2/utils.py @@ -45,7 +45,7 @@ class PrecisionConstraint(enum.Enum): # no need for a MAX constraint yet -def _to_enum(value, enum_type, enum_default=None): +def to_enum(value, enum_type, enum_default=None): """ Detect and convert strings to enums and None to a default enum. This allows use of strings and None in APIs, while enforcing the enum type: if @@ -88,11 +88,11 @@ class STIXdatetime(dt.datetime): """ def __new__(cls, *args, **kwargs): - precision = _to_enum( + precision = to_enum( kwargs.pop("precision", Precision.ANY), Precision, ) - precision_constraint = _to_enum( + precision_constraint = to_enum( kwargs.pop("precision_constraint", PrecisionConstraint.EXACT), PrecisionConstraint, ) @@ -233,8 +233,8 @@ def parse_into_datetime( :return: A STIXdatetime instance, which is a datetime but also carries the precision info necessary to properly JSON-serialize it. """ - precision = _to_enum(precision, Precision) - precision_constraint = _to_enum(precision_constraint, PrecisionConstraint) + precision = to_enum(precision, Precision) + precision_constraint = to_enum(precision_constraint, PrecisionConstraint) if isinstance(value, dt.date): if hasattr(value, 'hour'):