From 17071506abf94f8e0bce9ec0808e6f622b68647f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vinot?= Date: Tue, 31 Oct 2017 17:06:50 -0700 Subject: [PATCH] Add support for CIDR lookups --- pymispwarninglists/api.py | 58 ++++++++++++++++++++++++++++++++++----- tests/tests.py | 9 ++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/pymispwarninglists/api.py b/pymispwarninglists/api.py index a065b2d..92d59e4 100644 --- a/pymispwarninglists/api.py +++ b/pymispwarninglists/api.py @@ -7,6 +7,8 @@ import os import sys import collections from glob import glob +from ipaddress import ip_address, ip_network + try: import jsonschema @@ -30,7 +32,7 @@ class PyMISPWarningListsError(Exception): class WarningList(): - def __init__(self, warninglist): + def __init__(self, warninglist, slow_search=False): self.warninglist = warninglist self.list = self.warninglist['list'] self.description = self.warninglist['description'] @@ -41,6 +43,20 @@ class WarningList(): if self.warninglist.get('matching_attributes'): self.matching_attributes = self.warninglist['matching_attributes'] + self.slow_search = slow_search + self._network_objects = [] + + if self.slow_search: + self._network_objects = self._slow_index() + # If network objects is empty, reverting to default anyway + if not self._network_objects: + self.slow_search = False + + def __contains__(self, value): + if self.slow_search: + return self._slow_search(value) + return self._fast_search(value) + def to_dict(self): to_return = {'list': [str(e) for e in self.list], 'name': self.name, 'description': self.description, 'version': self.version} @@ -53,22 +69,50 @@ class WarningList(): def to_json(self): return json.dumps(self.to_dict(), cls=EncodeWarningList) - def __contains__(self, value): - if value in self.list: - return True - return False + def _fast_search(self, value): + return value in self.list + + def _slow_index(self): + to_return = [] + for entry in self.list: + try: + # Try if the entry is a network bloc + to_return.append(ip_network(entry)) + continue + except ValueError: + pass + try: + # Try if the entry is an IP + to_return.append(ip_address(entry)) + continue + except ValueError: + pass + # Not an IP nor a network + return to_return + + def _slow_search(self, value): + try: + value = ip_address(value) + except ValueError: + # The value to search isn't an IP address, falling back to default + return self._fast_search(value) + for obj in self._network_objects: + if value == obj or value in obj: + return True + # If nothing has been found yet, fallback to default + return self._fast_search(value) class WarningLists(collections.Mapping): - def __init__(self): + def __init__(self, slow_search=False): self.root_dir_warninglists = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispwarninglists'].__file__)), 'data', 'misp-warninglists', 'lists') self.warninglists = {} for warninglist_file in glob(os.path.join(self.root_dir_warninglists, '*', 'list.json')): with open(warninglist_file, 'r') as f: warninglist = json.load(f) - self.warninglists[warninglist['name']] = WarningList(warninglist) + self.warninglists[warninglist['name']] = WarningList(warninglist, slow_search) def validate_with_schema(self): if not HAS_JSONSCHEMA: diff --git a/tests/tests.py b/tests/tests.py index 19fcdaf..993f221 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -33,3 +33,12 @@ class TestPyMISPWarningLists(unittest.TestCase): def test_search(self): results = self.warninglists.search('8.8.8.8') self.assertEqual(results[0].name, 'List of known IPv4 public DNS resolvers') + + def test_slow_search(self): + self.warninglists = WarningLists(True) + results = self.warninglists.search('8.8.8.8') + self.assertEqual(results[0].name, 'List of known IPv4 public DNS resolvers') + results = self.warninglists.search('100.64.1.56') + self.assertEqual(results[0].name, 'List of RFC 6598 CIDR blocks') + results = self.warninglists.search('2001:DB8::34:1') + self.assertEqual(results[0].name, 'List of RFC 3849 CIDR blocks')