Fix errors when instantiating custom classes

Defining a custom object/observable/extension class with no custom
__init__() function would result in an `AttributeError` or `TypeError`,
depending on if the class sub-classed `object` or not.
stix2.1
Chris Lenk 2017-09-20 17:13:51 -04:00
parent c3a1c3dc3a
commit 21d978acc8
4 changed files with 86 additions and 5 deletions

View File

@ -9,8 +9,8 @@ from .common import (TLP_AMBER, TLP_GREEN, TLP_RED, TLP_WHITE, CustomMarking,
from .core import Bundle, _register_type, parse
from .environment import Environment, ObjectFactory
from .observables import (URL, AlternateDataStream, ArchiveExt, Artifact,
AutonomousSystem, CustomObservable, Directory,
DomainName, EmailAddress, EmailMessage,
AutonomousSystem, CustomExtension, CustomObservable,
Directory, DomainName, EmailAddress, EmailMessage,
EmailMIMEComponent, File, HTTPRequestExt, ICMPExt,
IPv4Address, IPv6Address, MACAddress, Mutex,
NetworkTraffic, NTFSExt, PDFExt, Process,

View File

@ -836,7 +836,14 @@ def CustomObservable(type='x-custom-observable', properties=None):
def __init__(self, **kwargs):
_Observable.__init__(self, **kwargs)
cls.__init__(self, **kwargs)
try:
cls.__init__(self, **kwargs)
except (AttributeError, TypeError) as e:
# Don't accidentally catch errors raised in a custom __init__()
if ("has no attribute '__init__'" in str(e) or
str(e) == "object.__init__() takes no parameters"):
return
raise e
_register_observable(_Custom)
return _Custom
@ -883,7 +890,14 @@ def CustomExtension(observable=None, type='x-custom-observable', properties={}):
def __init__(self, **kwargs):
_Extension.__init__(self, **kwargs)
cls.__init__(self, **kwargs)
try:
cls.__init__(self, **kwargs)
except (AttributeError, TypeError) as e:
# Don't accidentally catch errors raised in a custom __init__()
if ("has no attribute '__init__'" in str(e) or
str(e) == "object.__init__() takes no parameters"):
return
raise e
_register_extension(observable, _Custom)
return _Custom

View File

@ -346,7 +346,14 @@ def CustomObject(type='x-custom-type', properties=None):
def __init__(self, **kwargs):
_STIXBase.__init__(self, **kwargs)
cls.__init__(self, **kwargs)
try:
cls.__init__(self, **kwargs)
except (AttributeError, TypeError) as e:
# Don't accidentally catch errors raised in a custom __init__()
if ("has no attribute '__init__'" in str(e) or
str(e) == "object.__init__() takes no parameters"):
return
raise e
stix2._register_type(_Custom)
return _Custom

View File

@ -104,6 +104,26 @@ def test_custom_object_type():
assert "'property2' is too small." in str(excinfo.value)
def test_custom_object_no_init():
@stix2.sdo.CustomObject('x-new-obj', [
('property1', stix2.properties.StringProperty(required=True)),
])
class NewObj():
pass
no = NewObj(property1='something')
assert no.property1 == 'something'
@stix2.sdo.CustomObject('x-new-obj2', [
('property1', stix2.properties.StringProperty(required=True)),
])
class NewObj2(object):
pass
no2 = NewObj2(property1='something')
assert no2.property1 == 'something'
def test_parse_custom_object_type():
nt_string = """{
"type": "x-new-type",
@ -153,6 +173,26 @@ def test_custom_observable_object():
assert "'property2' is too small." in str(excinfo.value)
def test_custom_observable_object_no_init():
@stix2.observables.CustomObservable('x-new-observable', [
('property1', stix2.properties.StringProperty()),
])
class NewObs():
pass
no = NewObs(property1='something')
assert no.property1 == 'something'
@stix2.observables.CustomObservable('x-new-obs2', [
('property1', stix2.properties.StringProperty()),
])
class NewObs2(object):
pass
no2 = NewObs2(property1='something')
assert no2.property1 == 'something'
def test_custom_observable_object_invalid_ref_property():
with pytest.raises(ValueError) as excinfo:
@stix2.observables.CustomObservable('x-new-obs', [
@ -364,6 +404,26 @@ def test_custom_extension_invalid_observable():
assert "Custom observables must be created with the @CustomObservable decorator." in str(excinfo.value)
def test_custom_extension_no_init():
@stix2.observables.CustomExtension(stix2.DomainName, 'x-new-extension', {
'property1': stix2.properties.StringProperty(required=True),
})
class NewExt():
pass
ne = NewExt(property1="foobar")
assert ne.property1 == "foobar"
@stix2.observables.CustomExtension(stix2.DomainName, 'x-new-ext2', {
'property1': stix2.properties.StringProperty(required=True),
})
class NewExt2(object):
pass
ne2 = NewExt2(property1="foobar")
assert ne2.property1 == "foobar"
def test_parse_observable_with_custom_extension():
input_str = """{
"type": "domain-name",