From 3e0e80141b9460c018274764617082dd7766660b Mon Sep 17 00:00:00 2001 From: clenk Date: Wed, 17 May 2017 15:21:02 -0400 Subject: [PATCH] For object reference properties, check the type of the object referenced, not only that it is included in the local scope. --- stix2/base.py | 36 ++++++++++++++---- stix2/observables.py | 62 +++++++++++++++---------------- stix2/properties.py | 11 +++++- stix2/test/test_observed_data.py | 64 ++++++++++++++++++++++---------- 4 files changed, 113 insertions(+), 60 deletions(-) diff --git a/stix2/base.py b/stix2/base.py index ba10d04..9917b4c 100644 --- a/stix2/base.py +++ b/stix2/base.py @@ -184,14 +184,36 @@ class _Observable(_STIXBase): self._STIXBase__valid_refs = [] super(_Observable, self).__init__(**kwargs) + def _check_ref(self, ref, prop, prop_name): + if ref not in self._STIXBase__valid_refs: + raise InvalidObjRefError(self.__class__, prop_name, "'%s' is not a valid object in local scope" % ref) + + try: + allowed_types = prop.contained.valid_types + except AttributeError: + try: + allowed_types = prop.valid_types + except AttributeError: + raise ValueError("'%s' is named like an object reference property but " + "is not an ObjectReferenceProperty or a ListProperty " + "containing ObjectReferenceProperty." % prop_name) + + if allowed_types: + try: + ref_type = self._STIXBase__valid_refs[ref] + except TypeError: + raise ValueError("'%s' must be created with _valid_refs as a dict, not a list." % self.__class__.__name__) + if ref_type not in allowed_types: + raise InvalidObjRefError(self.__class__, prop_name, "object reference '%s' is of an invalid type '%s'" % (ref, ref_type)) + def _check_property(self, prop_name, prop, kwargs): super(_Observable, self)._check_property(prop_name, prop, kwargs) - if prop_name.endswith('_ref') and prop_name in kwargs: + if prop_name not in kwargs: + return + + if prop_name.endswith('_ref'): ref = kwargs[prop_name] - if ref not in self._STIXBase__valid_refs: - raise InvalidObjRefError(self.__class__, prop_name, "'%s' is not a valid object in local scope" % ref) - elif prop_name.endswith('_refs') and prop_name in kwargs: + self._check_ref(ref, prop, prop_name) + elif prop_name.endswith('_refs'): for ref in kwargs[prop_name]: - if ref not in self._STIXBase__valid_refs: - raise InvalidObjRefError(self.__class__, prop_name, "'%s' is not a valid object in local scope" % ref) - # TODO also check the type of the object referenced, not just that the key exists + self._check_ref(ref, prop, prop_name) diff --git a/stix2/observables.py b/stix2/observables.py index 597508b..704c45e 100644 --- a/stix2/observables.py +++ b/stix2/observables.py @@ -49,7 +49,7 @@ class Directory(_Observable): 'created': TimestampProperty(), 'modified': TimestampProperty(), 'accessed': TimestampProperty(), - 'contains_refs': ListProperty(ObjectReferenceProperty), + 'contains_refs': ListProperty(ObjectReferenceProperty(valid_types=['file', 'directory'])), } @@ -58,7 +58,7 @@ class DomainName(_Observable): _properties = { 'type': TypeProperty(_type), 'value': StringProperty(required=True), - 'resolves_to_refs': ListProperty(ObjectReferenceProperty), + 'resolves_to_refs': ListProperty(ObjectReferenceProperty(valid_types=['ipv4-addr', 'ipv6-addr', 'domain-name'])), } @@ -68,14 +68,14 @@ class EmailAddress(_Observable): 'type': TypeProperty(_type), 'value': StringProperty(required=True), 'display_name': StringProperty(), - 'belongs_to_ref': ObjectReferenceProperty(), + 'belongs_to_ref': ObjectReferenceProperty(valid_types='user-account'), } class EmailMIMEComponent(_STIXBase): _properties = { 'body': StringProperty(), - 'body_raw_ref': ObjectReferenceProperty(), + 'body_raw_ref': ObjectReferenceProperty(valid_types=['artifact', 'file']), 'content_type': StringProperty(), 'content_disposition': StringProperty(), } @@ -92,17 +92,17 @@ class EmailMessage(_Observable): 'is_multipart': BooleanProperty(required=True), 'date': TimestampProperty(), 'content_type': StringProperty(), - 'from_ref': ObjectReferenceProperty(), - 'sender_ref': ObjectReferenceProperty(), - 'to_refs': ListProperty(ObjectReferenceProperty), - 'cc_refs': ListProperty(ObjectReferenceProperty), - 'bcc_refs': ListProperty(ObjectReferenceProperty), + 'from_ref': ObjectReferenceProperty(valid_types='email-addr'), + 'sender_ref': ObjectReferenceProperty(valid_types='email-addr'), + 'to_refs': ListProperty(ObjectReferenceProperty(valid_types='email-addr')), + 'cc_refs': ListProperty(ObjectReferenceProperty(valid_types='email-addr')), + 'bcc_refs': ListProperty(ObjectReferenceProperty(valid_types='email-addr')), 'subject': StringProperty(), 'received_lines': ListProperty(StringProperty), 'additional_header_fields': DictionaryProperty(), 'body': StringProperty(), 'body_multipart': ListProperty(EmbeddedObjectProperty(type=EmailMIMEComponent)), - 'raw_email_ref': ObjectReferenceProperty(), + 'raw_email_ref': ObjectReferenceProperty(valid_types='artifact'), } def _check_object_constaints(self): @@ -113,7 +113,7 @@ class EmailMessage(_Observable): class ArchiveExt(_STIXBase): _properties = { - 'contains_refs': ListProperty(ObjectReferenceProperty, required=True), + 'contains_refs': ListProperty(ObjectReferenceProperty(valid_types='file'), required=True), 'version': StringProperty(), 'comment': StringProperty(), } @@ -231,12 +231,12 @@ class File(_Observable): 'created': TimestampProperty(), 'modified': TimestampProperty(), 'accessed': TimestampProperty(), - 'parent_directory_ref': ObjectReferenceProperty(), + 'parent_directory_ref': ObjectReferenceProperty(valid_types='directory'), 'is_encrypted': BooleanProperty(), 'encryption_algorithm': StringProperty(), 'decryption_key': StringProperty(), 'contains_refs': ListProperty(ObjectReferenceProperty), - 'content_ref': ObjectReferenceProperty(), + 'content_ref': ObjectReferenceProperty(valid_types='artifact'), } def _check_object_constaints(self): @@ -250,8 +250,8 @@ class IPv4Address(_Observable): _properties = { 'type': TypeProperty(_type), 'value': StringProperty(required=True), - 'resolves_to_refs': ListProperty(ObjectReferenceProperty), - 'belongs_to_refs': ListProperty(ObjectReferenceProperty), + 'resolves_to_refs': ListProperty(ObjectReferenceProperty(valid_types='mac-addr')), + 'belongs_to_refs': ListProperty(ObjectReferenceProperty(valid_types='autonomous-system')), } @@ -260,8 +260,8 @@ class IPv6Address(_Observable): _properties = { 'type': TypeProperty(_type), 'value': StringProperty(required=True), - 'resolves_to_refs': ListProperty(ObjectReferenceProperty), - 'belongs_to_refs': ListProperty(ObjectReferenceProperty), + 'resolves_to_refs': ListProperty(ObjectReferenceProperty(valid_types='mac-addr')), + 'belongs_to_refs': ListProperty(ObjectReferenceProperty(valid_types='autonomous-system')), } @@ -288,7 +288,7 @@ class HTTPRequestExt(_STIXBase): 'request_version': StringProperty(), 'request_header': DictionaryProperty(), 'message_body_length': IntegerProperty(), - 'message_body_data_ref': ObjectReferenceProperty(), + 'message_body_data_ref': ObjectReferenceProperty(valid_types='artifact'), } @@ -347,8 +347,8 @@ class NetworkTraffic(_Observable): 'start': TimestampProperty(), 'end': TimestampProperty(), 'is_active': BooleanProperty(), - 'src_ref': ObjectReferenceProperty(), - 'dst_ref': ObjectReferenceProperty(), + 'src_ref': ObjectReferenceProperty(valid_types=['ipv4-addr', 'ipv6-addr', 'mac-addr', 'domain-name']), + 'dst_ref': ObjectReferenceProperty(valid_types=['ipv4-addr', 'ipv6-addr', 'mac-addr', 'domain-name']), 'src_port': IntegerProperty(), 'dst_port': IntegerProperty(), 'protocols': ListProperty(StringProperty, required=True), @@ -357,10 +357,10 @@ class NetworkTraffic(_Observable): 'src_packets': IntegerProperty(), 'dst_packets': IntegerProperty(), 'ipfix': DictionaryProperty(), - 'src_payload_ref': ObjectReferenceProperty(), - 'dst_payload_ref': ObjectReferenceProperty(), - 'encapsulates_refs': ListProperty(ObjectReferenceProperty), - 'encapsulates_by_ref': ObjectReferenceProperty(), + 'src_payload_ref': ObjectReferenceProperty(valid_types='artifact'), + 'dst_payload_ref': ObjectReferenceProperty(valid_types='artifact'), + 'encapsulates_refs': ListProperty(ObjectReferenceProperty(valid_types='network-traffic')), + 'encapsulates_by_ref': ObjectReferenceProperty(valid_types='network-traffic'), } def _check_object_constaints(self): @@ -392,7 +392,7 @@ class WindowsServiceExt(_STIXBase): "SERVICE_DISABLED", "SERVICE_SYSTEM_ALERT", ]), - 'service_dll_refs': ListProperty(ObjectReferenceProperty), + 'service_dll_refs': ListProperty(ObjectReferenceProperty(valid_types='file')), 'service_type': EnumProperty([ "SERVICE_KERNEL_DRIVER", "SERVICE_FILE_SYSTEM_DRIVER", @@ -425,11 +425,11 @@ class Process(_Observable): 'arguments': ListProperty(StringProperty), 'command_line': StringProperty(), 'environment_variables': DictionaryProperty(), - 'opened_connection_refs': ListProperty(ObjectReferenceProperty), - 'creator_user_ref': ObjectReferenceProperty(), - 'binary_ref': ObjectReferenceProperty(), - 'parent_ref': ObjectReferenceProperty(), - 'child_refs': ListProperty(ObjectReferenceProperty), + 'opened_connection_refs': ListProperty(ObjectReferenceProperty(valid_types='network-traffic')), + 'creator_user_ref': ObjectReferenceProperty(valid_types='user-account'), + 'binary_ref': ObjectReferenceProperty(valid_types='file'), + 'parent_ref': ObjectReferenceProperty(valid_types='process'), + 'child_refs': ListProperty(ObjectReferenceProperty('process')), } @@ -514,7 +514,7 @@ class WindowsRegistryKey(_Observable): 'values': ListProperty(EmbeddedObjectProperty(type=WindowsRegistryValueType)), # this is not the modified timestamps of the object itself 'modified': TimestampProperty(), - 'creator_user_ref': ObjectReferenceProperty(), + 'creator_user_ref': ObjectReferenceProperty(valid_types='user-account'), 'number_of_subkeys': IntegerProperty(), } diff --git a/stix2/properties.py b/stix2/properties.py index 6d217f4..bd1e3a2 100644 --- a/stix2/properties.py +++ b/stix2/properties.py @@ -245,9 +245,11 @@ class ObservableProperty(Property): except ValueError: raise ValueError("The observable property must contain a dictionary") + valid_refs = dict((k, v['type']) for (k, v) in dictified.items()) + from .__init__ import parse_observable # avoid circular import for key, obj in dictified.items(): - parsed_obj = parse_observable(obj, dictified.keys()) + parsed_obj = parse_observable(obj, valid_refs) if not issubclass(type(parsed_obj), _Observable): raise ValueError("Objects in an observable property must be " "Cyber Observable Objects") @@ -369,7 +371,12 @@ class SelectorProperty(Property): class ObjectReferenceProperty(StringProperty): - pass + + def __init__(self, valid_types=None, **kwargs): + if valid_types and type(valid_types) is not list: + valid_types = [valid_types] + self.valid_types = valid_types + super(ObjectReferenceProperty, self).__init__(**kwargs) class EmbeddedObjectProperty(Property): diff --git a/stix2/test/test_observed_data.py b/stix2/test/test_observed_data.py index 5c69fb9..5216892 100644 --- a/stix2/test/test_observed_data.py +++ b/stix2/test/test_observed_data.py @@ -6,7 +6,6 @@ import pytz import stix2 -from ..exceptions import InvalidValueError from .constants import OBSERVED_DATA_ID @@ -218,7 +217,7 @@ def test_parse_autonomous_system_valid(data): @pytest.mark.parametrize("data", [ - """"1": { + """{ "type": "email-address", "value": "john@example.com", "display_name": "John Doe", @@ -226,13 +225,12 @@ def test_parse_autonomous_system_valid(data): }""", ]) def test_parse_email_address(data): - odata_str = re.compile('\}.+\},', re.DOTALL).sub('}, %s},' % data, EXPECTED) - odata = stix2.parse(odata_str) - assert odata.objects["1"].type == "email-address" + odata = stix2.parse_observable(data, {"0": "user-account"}) + assert odata.type == "email-address" - odata_str = re.compile('"belongs_to_ref": "0"', re.DOTALL).sub('"belongs_to_ref": "3"', odata_str) - with pytest.raises(InvalidValueError): - stix2.parse(odata_str) + odata_str = re.compile('"belongs_to_ref": "0"', re.DOTALL).sub('"belongs_to_ref": "3"', data) + with pytest.raises(stix2.exceptions.InvalidObjRefError): + stix2.parse_observable(odata_str, {"0": "user-account"}) @pytest.mark.parametrize("data", [ @@ -276,7 +274,15 @@ def test_parse_email_address(data): """ ]) def test_parse_email_message(data): - odata = stix2.parse_observable(data, [str(i) for i in range(1, 6)]) + valid_refs = { + "0": "email-message", + "1": "email-addr", + "2": "email-addr", + "3": "email-addr", + "4": "artifact", + "5": "file", + } + odata = stix2.parse_observable(data, valid_refs) assert odata.type == "email-message" assert odata.body_multipart[0].content_disposition == "inline" @@ -365,8 +371,16 @@ def test_parse_file_archive(data): """ ]) def test_parse_email_message_with_at_least_one_error(data): + valid_refs = { + "0": "email-message", + "1": "email-addr", + "2": "email-addr", + "3": "email-addr", + "4": "artifact", + "5": "file", + } with pytest.raises(stix2.exceptions.AtLeastOnePropertyError) as excinfo: - stix2.parse_observable(data, [str(i) for i in range(1, 6)]) + stix2.parse_observable(data, valid_refs) assert excinfo.value.cls == stix2.EmailMIMEComponent assert excinfo.value.properties == ["body", "body_raw_ref"] @@ -385,7 +399,7 @@ def test_parse_email_message_with_at_least_one_error(data): """ ]) def test_parse_basic_tcp_traffic(data): - odata = stix2.parse_observable(data, ["0", "1"]) + odata = stix2.parse_observable(data, {"0": "ipv4-addr", "1": "ipv4-addr"}) assert odata.type == "network-traffic" assert odata.src_ref == "0" @@ -413,7 +427,7 @@ def test_parse_basic_tcp_traffic(data): ]) def test_parse_basic_tcp_traffic_with_error(data): with pytest.raises(stix2.exceptions.AtLeastOnePropertyError) as excinfo: - stix2.parse_observable(data, ["4"]) + stix2.parse_observable(data, {"4": "network-traffic"}) assert excinfo.value.cls == stix2.NetworkTraffic assert excinfo.value.properties == ["dst_ref", "src_ref"] @@ -512,7 +526,7 @@ def test_artifact_mutual_exclusion_error(): def test_directory_example(): - dir = stix2.Directory(_valid_refs=["1"], + dir = stix2.Directory(_valid_refs={"1": "file"}, path='/usr/lib', created="2015-12-21T19:00:00Z", modified="2015-12-24T19:00:00Z", @@ -540,7 +554,7 @@ def test_directory_example_ref_error(): def test_domain_name_example(): - dn = stix2.DomainName(_valid_refs=["1"], + dn = stix2.DomainName(_valid_refs={"1": 'domain-name'}, value="example.com", resolves_to_refs=["1"]) @@ -548,6 +562,16 @@ def test_domain_name_example(): assert dn.resolves_to_refs == ["1"] +def test_domain_name_example_invalid_ref_type(): + with pytest.raises(stix2.exceptions.InvalidObjRefError) as excinfo: + stix2.DomainName(_valid_refs={"1": "file"}, + value="example.com", + resolves_to_refs=["1"]) + + assert excinfo.value.cls == stix2.DomainName + assert excinfo.value.prop_name == "resolves_to_refs" + + def test_file_example(): f = stix2.File(name="qwerty.dll", hashes={ @@ -610,7 +634,7 @@ def test_file_example_encryption_error(): def test_ip4_address_example(): - ip4 = stix2.IPv4Address(_valid_refs=["1", "4", "5"], + ip4 = stix2.IPv4Address(_valid_refs={"4": "mac-addr", "5": "mac-addr"}, value="198.51.100.3", resolves_to_refs=["4", "5"]) @@ -637,7 +661,7 @@ def test_mac_address_example(): def test_network_traffic_example(): - nt = stix2.NetworkTraffic(_valid_refs=["0", "1"], + nt = stix2.NetworkTraffic(_valid_refs={"0": "ipv4-addr", "1": "ipv4-addr"}, protocols="tcp", src_ref="0", dst_ref="1") @@ -655,7 +679,7 @@ def test_network_traffic_http_request_example(): "User-Agent": "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.6) Gecko/20040113", "Host": "www.example.com" }) - nt = stix2.NetworkTraffic(_valid_refs=["0", "1"], + nt = stix2.NetworkTraffic(_valid_refs={"0": "ipv4-addr"}, protocols="tcp", src_ref="0", extensions={'http-request-ext': h}) @@ -670,7 +694,7 @@ def test_network_traffic_http_request_example(): def test_network_traffic_icmp_example(): h = stix2.ICMPExt(icmp_type_hex="08", icmp_code_hex="00") - nt = stix2.NetworkTraffic(_valid_refs=["0", "1"], + nt = stix2.NetworkTraffic(_valid_refs={"0": "ipv4-addr"}, protocols="tcp", src_ref="0", extensions={'icmp-ext': h}) @@ -683,7 +707,7 @@ def test_network_traffic_socket_example(): address_family="AF_INET", protocol_family="PF_INET", socket_type="SOCK_STREAM") - nt = stix2.NetworkTraffic(_valid_refs=["0", "1"], + nt = stix2.NetworkTraffic(_valid_refs={"0": "ipv4-addr"}, protocols="tcp", src_ref="0", extensions={'socket-ext': h}) @@ -695,7 +719,7 @@ def test_network_traffic_socket_example(): def test_network_traffic_tcp_example(): h = stix2.TCPExt(src_flags_hex="00000002") - nt = stix2.NetworkTraffic(_valid_refs=["0", "1"], + nt = stix2.NetworkTraffic(_valid_refs={"0": "ipv4-addr"}, protocols="tcp", src_ref="0", extensions={'tcp-ext': h})