diff --git a/pymispwarninglists/__init__.py b/pymispwarninglists/__init__.py index 1ddde4a..36316a3 100644 --- a/pymispwarninglists/__init__.py +++ b/pymispwarninglists/__init__.py @@ -1,4 +1,4 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from .api import WarningLists, EncodeWarningLists +from .api import WarningLists, EncodeWarningList diff --git a/pymispwarninglists/api.py b/pymispwarninglists/api.py index 172edf6..a065b2d 100644 --- a/pymispwarninglists/api.py +++ b/pymispwarninglists/api.py @@ -15,12 +15,11 @@ except ImportError: HAS_JSONSCHEMA = False -class EncodeWarningLists(JSONEncoder): +class EncodeWarningList(JSONEncoder): def default(self, obj): - try: - return obj._json() - except AttributeError: - return JSONEncoder.default(self, obj) + if isinstance(obj, WarningList): + return obj.to_dict() + return JSONEncoder.default(self, obj) class PyMISPWarningListsError(Exception): @@ -42,16 +41,19 @@ class WarningList(): if self.warninglist.get('matching_attributes'): self.matching_attributes = self.warninglist['matching_attributes'] - def _json(self): - to_return = {'list': self.list, 'name': self.name, 'description': self.description, - 'version': self.version} + def to_dict(self): + to_return = {'list': [str(e) for e in self.list], 'name': self.name, + 'description': self.description, 'version': self.version} if hasattr(self, 'type'): to_return['type'] = self.type if hasattr(self, 'matching_attributes'): to_return['matching_attributes'] = self.matching_attributes return to_return - def has_match(self, value): + def to_json(self): + return json.dumps(self.to_dict(), cls=EncodeWarningList) + + def __contains__(self, value): if value in self.list: return True return False @@ -87,8 +89,8 @@ class WarningLists(collections.Mapping): def search(self, value): matches = [] for name, wl in self.warninglists.items(): - if wl.has_match(value): - matches.append(wl) + if value in wl: + matches.append(wl) return matches def __len__(self): diff --git a/tests/tests.py b/tests/tests.py index 2dec739..19fcdaf 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- import unittest -from pymispwarninglists import WarningLists, EncodeWarningLists +from pymispwarninglists import WarningLists from glob import glob import os import json @@ -20,7 +20,7 @@ class TestPyMISPWarningLists(unittest.TestCase): warninglist = json.load(f) warninglists_from_files[warninglist['name']] = warninglist for name, w in self.warninglists.items(): - out = w._json() + out = w.to_dict() self.assertDictEqual(out, warninglists_from_files[w.name]) def test_validate_schema_warninglists(self): @@ -28,4 +28,8 @@ class TestPyMISPWarningLists(unittest.TestCase): def test_json(self): for w in self.warninglists.values(): - json.dumps(w, cls=EncodeWarningLists) + w.to_json() + + def test_search(self): + results = self.warninglists.search('8.8.8.8') + self.assertEqual(results[0].name, 'List of known IPv4 public DNS resolvers')