diff --git a/stix2/base.py b/stix2/base.py index aed1dd2..ca44daf 100644 --- a/stix2/base.py +++ b/stix2/base.py @@ -4,7 +4,7 @@ import collections import datetime as dt import json -from .exceptions import STIXValueError +from .exceptions import STIXValueError, MissingFieldsError from .utils import format_datetime, get_timestamp, NOW __all__ = ['STIXJSONEncoder', '_STIXBase'] @@ -46,7 +46,6 @@ class _STIXBase(collections.Mapping): def __init__(self, **kwargs): cls = self.__class__ - class_name = cls.__name__ # Use the same timestamp for any auto-generated datetimes self.__now = get_timestamp() @@ -59,9 +58,7 @@ class _STIXBase(collections.Mapping): required_fields = get_required_properties(cls._properties) missing_kwargs = set(required_fields) - set(kwargs) if missing_kwargs: - msg = "Missing required field(s) for {type}: ({fields})." - field_list = ", ".join(x for x in sorted(list(missing_kwargs))) - raise ValueError(msg.format(type=class_name, fields=field_list)) + raise MissingFieldsError(cls, missing_kwargs) for prop_name, prop_metadata in cls._properties.items(): self._check_property(prop_name, prop_metadata, kwargs) diff --git a/stix2/exceptions.py b/stix2/exceptions.py index 0c70dc8..e64a3de 100644 --- a/stix2/exceptions.py +++ b/stix2/exceptions.py @@ -6,6 +6,7 @@ class STIXValueError(STIXError, ValueError): """An invalid value was provided to a STIX object's __init__.""" def __init__(self, cls, prop_name, reason): + super(STIXValueError, self).__init__() self.cls = cls self.prop_name = prop_name self.reason = reason @@ -13,3 +14,17 @@ class STIXValueError(STIXError, ValueError): def __str__(self): msg = "Invalid value for {0.cls.__name__} '{0.prop_name}': {0.reason}" return msg.format(self) + + +class MissingFieldsError(STIXError, ValueError): + """Missing required field(s) when construting STIX object.""" + + def __init__(self, cls, fields): + super(MissingFieldsError, self).__init__() + self.cls = cls + self.fields = sorted(list(fields)) + + def __str__(self): + msg = "Missing required field(s) for {0}: ({1})." + return msg.format(self.cls.__name__, + ", ".join(x for x in self.fields)) diff --git a/stix2/test/test_external_reference.py b/stix2/test/test_external_reference.py index 5c39852..52e21cc 100644 --- a/stix2/test/test_external_reference.py +++ b/stix2/test/test_external_reference.py @@ -107,6 +107,9 @@ def test_external_reference_offline(): def test_external_reference_source_required(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.ExternalReference() + + assert excinfo.value.cls == stix2.ExternalReference + assert excinfo.value.fields == ["source_name"] assert str(excinfo.value) == "Missing required field(s) for ExternalReference: (source_name)." diff --git a/stix2/test/test_indicator.py b/stix2/test/test_indicator.py index 00ea1df..b532c3a 100644 --- a/stix2/test/test_indicator.py +++ b/stix2/test/test_indicator.py @@ -87,14 +87,20 @@ def test_indicator_id_must_start_with_indicator(): def test_indicator_required_fields(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.Indicator() + + assert excinfo.value.cls == stix2.Indicator + assert excinfo.value.fields == ["labels", "pattern"] assert str(excinfo.value) == "Missing required field(s) for Indicator: (labels, pattern)." def test_indicator_required_field_pattern(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.Indicator(labels=['malicious-activity']) + + assert excinfo.value.cls == stix2.Indicator + assert excinfo.value.fields == ["pattern"] assert str(excinfo.value) == "Missing required field(s) for Indicator: (pattern)." diff --git a/stix2/test/test_kill_chain_phases.py b/stix2/test/test_kill_chain_phases.py index f646f0a..d2ecc1f 100644 --- a/stix2/test/test_kill_chain_phases.py +++ b/stix2/test/test_kill_chain_phases.py @@ -36,23 +36,29 @@ def test_kill_chain_example(): def test_kill_chain_required_fields(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.KillChainPhase() + assert excinfo.value.cls == stix2.KillChainPhase + assert excinfo.value.fields == ["kill_chain_name", "phase_name"] assert str(excinfo.value) == "Missing required field(s) for KillChainPhase: (kill_chain_name, phase_name)." def test_kill_chain_required_field_chain_name(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.KillChainPhase(phase_name="weaponization") + assert excinfo.value.cls == stix2.KillChainPhase + assert excinfo.value.fields == ["kill_chain_name"] assert str(excinfo.value) == "Missing required field(s) for KillChainPhase: (kill_chain_name)." def test_kill_chain_required_field_phase_name(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.KillChainPhase(kill_chain_name="lockheed-martin-cyber-kill-chain") + assert excinfo.value.cls == stix2.KillChainPhase + assert excinfo.value.fields == ["phase_name"] assert str(excinfo.value) == "Missing required field(s) for KillChainPhase: (phase_name)." diff --git a/stix2/test/test_malware.py b/stix2/test/test_malware.py index d4ea9ec..ca83f33 100644 --- a/stix2/test/test_malware.py +++ b/stix2/test/test_malware.py @@ -71,14 +71,20 @@ def test_malware_id_must_start_with_malware(): def test_malware_required_fields(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.Malware() + + assert excinfo.value.cls == stix2.Malware + assert excinfo.value.fields == ["labels", "name"] assert str(excinfo.value) == "Missing required field(s) for Malware: (labels, name)." def test_malware_required_field_name(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.Malware(labels=['ransomware']) + + assert excinfo.value.cls == stix2.Malware + assert excinfo.value.fields == ["name"] assert str(excinfo.value) == "Missing required field(s) for Malware: (name)." diff --git a/stix2/test/test_relationship.py b/stix2/test/test_relationship.py index edee9fe..527dbee 100644 --- a/stix2/test/test_relationship.py +++ b/stix2/test/test_relationship.py @@ -74,23 +74,26 @@ def test_relationship_id_must_start_with_relationship(): def test_relationship_required_field_relationship_type(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.Relationship() assert str(excinfo.value) == "Missing required field(s) for Relationship: (relationship_type, source_ref, target_ref)." def test_relationship_missing_some_required_fields(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.Relationship(relationship_type='indicates') assert str(excinfo.value) == "Missing required field(s) for Relationship: (source_ref, target_ref)." def test_relationship_required_field_target_ref(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(stix2.exceptions.MissingFieldsError) as excinfo: stix2.Relationship( relationship_type='indicates', source_ref=INDICATOR_ID ) + + assert excinfo.value.cls == stix2.Relationship + assert excinfo.value.fields == ["target_ref"] assert str(excinfo.value) == "Missing required field(s) for Relationship: (target_ref)."