diff --git a/stix2/__init__.py b/stix2/__init__.py index f881f96..573e9f8 100644 --- a/stix2/__init__.py +++ b/stix2/__init__.py @@ -223,9 +223,13 @@ class Relationship(_STIXBase): if not kwargs.get('source_ref'): raise ValueError("Missing required field for Relationship: 'source_ref'.") + elif isinstance(kwargs['source_ref'], _STIXBase): + kwargs['source_ref'] = kwargs['source_ref'].id if not kwargs.get('target_ref'): raise ValueError("Missing required field for Relationship: 'target_ref'.") + elif isinstance(kwargs['target_ref'], _STIXBase): + kwargs['target_ref'] = kwargs['target_ref'].id extra_kwargs = list(set(kwargs.keys()) - set(self._properties)) if extra_kwargs: diff --git a/stix2/test/test_stix2.py b/stix2/test/test_stix2.py index 23d9dd4..e8c1e21 100644 --- a/stix2/test/test_stix2.py +++ b/stix2/test/test_stix2.py @@ -324,3 +324,15 @@ def test_invalid_kwarg_to_relationship(): with pytest.raises(TypeError) as excinfo: relationship = stix2.Relationship(my_custom_property="foo", **RELATIONSHIP_KWARGS) assert "unexpected keyword arguments: ['my_custom_property']" in str(excinfo) + + +def test_create_relationship_from_objects_rather_than_ids(indicator, malware): + relationship = stix2.Relationship( + relationship_type="indicates", + source_ref=indicator, + target_ref=malware, + ) + + assert relationship.relationship_type == 'indicates' + assert relationship.source_ref == INDICATOR_ID + assert relationship.target_ref == MALWARE_ID