Simplify code

regex-type
Raphaël Vinot 2017-10-31 16:06:50 -07:00
parent 20dbcf3d5f
commit 0590f926b7
3 changed files with 21 additions and 15 deletions

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from .api import WarningLists, EncodeWarningLists from .api import WarningLists, EncodeWarningList

View File

@ -15,12 +15,11 @@ except ImportError:
HAS_JSONSCHEMA = False HAS_JSONSCHEMA = False
class EncodeWarningLists(JSONEncoder): class EncodeWarningList(JSONEncoder):
def default(self, obj): def default(self, obj):
try: if isinstance(obj, WarningList):
return obj._json() return obj.to_dict()
except AttributeError: return JSONEncoder.default(self, obj)
return JSONEncoder.default(self, obj)
class PyMISPWarningListsError(Exception): class PyMISPWarningListsError(Exception):
@ -42,16 +41,19 @@ class WarningList():
if self.warninglist.get('matching_attributes'): if self.warninglist.get('matching_attributes'):
self.matching_attributes = self.warninglist['matching_attributes'] self.matching_attributes = self.warninglist['matching_attributes']
def _json(self): def to_dict(self):
to_return = {'list': self.list, 'name': self.name, 'description': self.description, to_return = {'list': [str(e) for e in self.list], 'name': self.name,
'version': self.version} 'description': self.description, 'version': self.version}
if hasattr(self, 'type'): if hasattr(self, 'type'):
to_return['type'] = self.type to_return['type'] = self.type
if hasattr(self, 'matching_attributes'): if hasattr(self, 'matching_attributes'):
to_return['matching_attributes'] = self.matching_attributes to_return['matching_attributes'] = self.matching_attributes
return to_return return to_return
def has_match(self, value): def to_json(self):
return json.dumps(self.to_dict(), cls=EncodeWarningList)
def __contains__(self, value):
if value in self.list: if value in self.list:
return True return True
return False return False
@ -87,8 +89,8 @@ class WarningLists(collections.Mapping):
def search(self, value): def search(self, value):
matches = [] matches = []
for name, wl in self.warninglists.items(): for name, wl in self.warninglists.items():
if wl.has_match(value): if value in wl:
matches.append(wl) matches.append(wl)
return matches return matches
def __len__(self): def __len__(self):

View File

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import unittest import unittest
from pymispwarninglists import WarningLists, EncodeWarningLists from pymispwarninglists import WarningLists
from glob import glob from glob import glob
import os import os
import json import json
@ -20,7 +20,7 @@ class TestPyMISPWarningLists(unittest.TestCase):
warninglist = json.load(f) warninglist = json.load(f)
warninglists_from_files[warninglist['name']] = warninglist warninglists_from_files[warninglist['name']] = warninglist
for name, w in self.warninglists.items(): for name, w in self.warninglists.items():
out = w._json() out = w.to_dict()
self.assertDictEqual(out, warninglists_from_files[w.name]) self.assertDictEqual(out, warninglists_from_files[w.name])
def test_validate_schema_warninglists(self): def test_validate_schema_warninglists(self):
@ -28,4 +28,8 @@ class TestPyMISPWarningLists(unittest.TestCase):
def test_json(self): def test_json(self):
for w in self.warninglists.values(): for w in self.warninglists.values():
json.dumps(w, cls=EncodeWarningLists) w.to_json()
def test_search(self):
results = self.warninglists.search('8.8.8.8')
self.assertEqual(results[0].name, 'List of known IPv4 public DNS resolvers')