diff --git a/pymispwarninglists/api.py b/pymispwarninglists/api.py index c80ea3c..36c8cd6 100644 --- a/pymispwarninglists/api.py +++ b/pymispwarninglists/api.py @@ -2,11 +2,14 @@ # -*- coding: utf-8 -*- import json +import logging import sys from collections.abc import Mapping +from contextlib import suppress from glob import glob -from ipaddress import ip_address, ip_network +from ipaddress import ip_network, IPv6Address, IPv4Address, IPv4Network, IPv6Network, _BaseNetwork, \ + AddressValueError, NetmaskValueError from pathlib import Path from typing import Union, Dict, Any, List, Optional from urllib.parse import urlparse @@ -21,11 +24,15 @@ except ImportError: HAS_JSONSCHEMA = False +logger = logging.getLogger(__name__) + + def json_default(obj: 'WarningList') -> Union[Dict, str]: if isinstance(obj, WarningList): return obj.to_dict() + class WarningList(): expected_types = ['string', 'substring', 'hostname', 'cidr', 'regex'] @@ -44,13 +51,9 @@ class WarningList(): self.matching_attributes = self.warninglist['matching_attributes'] self.slow_search = slow_search - self._network_objects = [] if self.slow_search and self.type == 'cidr': - self._network_objects = self._network_index() - # If network objects is empty, reverting to default anyway - if not self._network_objects: - self.slow_search = False + self._ipv4_filter, self._ipv6_filter = compile_network_filters(self.list) def __repr__(self) -> str: return f'<{self.__class__.__name__}(type="{self.name}", version="{self.version}", description="{self.description}")' @@ -74,16 +77,6 @@ class WarningList(): def _fast_search(self, value) -> bool: return value in self.set - def _network_index(self) -> List: - to_return = [] - for entry in self.list: - try: - # Try if the entry is a network bloc or an IP - to_return.append(ip_network(entry)) - except ValueError: - pass - return to_return - def _slow_search(self, value: str) -> bool: if self.type == 'string': # Exact match only, using fast search @@ -101,12 +94,14 @@ 
class NetworkFilter:
    """Bit-wise prefix tree used to test whether an IP (as an integer) is in a set of networks.

    Each node looks at the single bit found at ``digit_position`` (0 being the
    least significant bit) and maps both possible values of that bit to either:

    * ``True``: every address reaching this branch matches,
    * ``False``: no address reaching this branch matches,
    * another ``NetworkFilter`` that inspects the next lower bit.
    """

    def __init__(self, digit_position: int, digit2filter: Optional[dict[int, Union[bool, "NetworkFilter"]]] = None):
        """Create a node inspecting bit ``digit_position``.

        :param digit_position: index of the bit this node decides on.
        :param digit2filter: optional pre-built branches; when missing or empty,
            both branches start as ``False`` (nothing matches).
        """
        self.digit2filter: dict[int, Union[bool, NetworkFilter]] = digit2filter or {0: False, 1: False}
        self.digit_position = digit_position

    def __contains__(self, ip: int) -> bool:
        """Return True if the integer address ``ip`` falls in one of the appended networks."""
        child = self.digit2filter[self._get_digit(ip)]
        if isinstance(child, bool):
            return child

        return ip in child

    def append(self, net: _BaseNetwork) -> None:
        """Register ``net`` so that all of its addresses are reported as contained.

        A network already covered by a wider one is ignored; a wider network
        replaces any narrower sub-tree previously stored on the same branch.
        """
        host_bits = net.max_prefixlen - net.prefixlen

        if host_bits > self.digit_position:
            # The network does not constrain this node's bit at all (for the
            # root nodes this is the /0 case, e.g. 0.0.0.0/0 or ::/0), so both
            # branches match. Without this guard the recursion below would
            # walk past bit 0 and fail on a negative shift.
            self.digit2filter[0] = True
            self.digit2filter[1] = True
            return

        digit = self._get_digit(int(net.network_address))

        if host_bits == self.digit_position:
            # This bit is the last one fixed by the prefix: the whole branch matches.
            self.digit2filter[digit] = True
            return

        child = self.digit2filter[digit]

        if child is True:
            # Already covered by a wider network, nothing to refine.
            return

        if child is False:
            child = NetworkFilter(self.digit_position - 1)
            self.digit2filter[digit] = child

        child.append(net)

    def _get_digit(self, ip: int) -> int:
        """Return the bit (0 or 1) of ``ip`` located at ``digit_position``."""
        return (ip >> self.digit_position) & 1

    def __repr__(self):
        return f"NetworkFilter(digit_position={self.digit_position}, digit2filter={self.digit2filter})"

    def __eq__(self, other):
        return (
            isinstance(other, NetworkFilter)
            and self.digit_position == other.digit_position
            and self.digit2filter == other.digit2filter
        )


def compile_network_filters(values: list) -> tuple[NetworkFilter, NetworkFilter]:
    """Build one filter for IPv4 and one for IPv6 out of ``values``.

    :param values: entries accepted by ``ipaddress.ip_network``; entries that
        cannot be parsed are skipped (see ``convert_networks``).
    :return: ``(ipv4_filter, ipv6_filter)``; a filter without any network
        matches nothing.
    """
    networks = convert_networks(values)

    # Root nodes start on the most significant bit of each address family.
    ipv4_filter = NetworkFilter(31)
    ipv6_filter = NetworkFilter(127)

    for net in networks:
        root = ipv4_filter if isinstance(net, IPv4Network) else ipv6_filter
        root.append(net)

    return ipv4_filter, ipv6_filter


def convert_networks(values: list) -> list[_BaseNetwork]:
    """Parse ``values`` with ``ip_network`` and return the ones that are valid.

    Entries rejected by ``ip_network`` (it raises ``ValueError``) are collected
    and reported in a single warning instead of aborting the conversion.
    """
    valid_ips = []
    invalid_ips = []

    for value in values:
        try:
            valid_ips.append(ip_network(value))
        except ValueError:
            invalid_ips.append(value)

    if invalid_ips:
        logger.warning('Invalid IPs found: %s', invalid_ips)

    return valid_ips
in self.cidr_list + + def test_ipv4_bloc(self): + # 51.8.152.0/21 + assert "51.8.152.0" in self.cidr_list + assert "51.8.152.255" in self.cidr_list + assert "51.8.153.0" in self.cidr_list + assert "51.8.159.255" in self.cidr_list + + # outside + assert "51.8.151.0" not in self.cidr_list + assert "51.8.160.0" not in self.cidr_list + + # 51.8.160.128/25 + assert "51.8.160.128" in self.cidr_list + assert "51.8.160.255" in self.cidr_list + + def test_ipv6_bloc(self): + assert "2a01:4180:c003:8::" in self.cidr_list + assert "2a01:4180:c003:8::1" in self.cidr_list + assert "2a01:4180:c003:8::ffff" in self.cidr_list + assert "2a01:4180:c003:9::1000" in self.cidr_list + + assert "2a01:4180:c003:7::ffff" not in self.cidr_list + assert "2a01:4180:c003:10::" not in self.cidr_list + + def test_search_for_ipv4_bloc(self): + assert "51.8.152.0/21" in self.cidr_list + + assert "51.8.152.0/22" not in self.cidr_list + + def test_search_for_ipv4_as_int(self): + # 51.8.152.0 as integer is 856201216 + assert 856201216 in self.cidr_list + + +class TestNetworkCompilation(unittest.TestCase): + def test_simple_case(self): + ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("160.0.0.0/3"), IPv4Network("192.0.0.0/2")]) + + assert ipv6_filter == NetworkFilter(127) + assert ipv4_filter == NetworkFilter( + digit_position=31, + digit2filter={ + 0: False, + 1: NetworkFilter( + digit_position=30, + digit2filter={ + 0: NetworkFilter( + digit_position=29, + digit2filter={ + 0: False, + 1: True + } + ), + 1: True + }, + ), + }, + ), ipv4_filter + + def test_overwrite_with_bigger_network(self): + ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("192.0.0.0/2"), IPv4Network("128.0.0.0/1")]) + + assert ipv6_filter == NetworkFilter(127) + assert ipv4_filter == NetworkFilter( + digit_position=31, + digit2filter={ + 0: False, + 1: True, + }, + ), ipv4_filter + + def test_dont_overwrite_with_smaller_network(self): + ipv4_filter, ipv6_filter = 
compile_network_filters([IPv4Network("128.0.0.0/1"), IPv4Network("192.0.0.0/2")]) + + assert ipv6_filter == NetworkFilter(127) + assert ipv4_filter == NetworkFilter( + digit_position=31, + digit2filter={ + 0: False, + 1: True, + }, + ), ipv4_filter diff --git a/tests/time_cidr_search.py b/tests/time_cidr_search.py new file mode 100644 index 0000000..c65f1cf --- /dev/null +++ b/tests/time_cidr_search.py @@ -0,0 +1,24 @@ +import random +from datetime import datetime + +from pymispwarninglists import WarningLists + +start_time = datetime.now() + +warning_lists = WarningLists(slow_search=True) +warning_lists.warninglists = { + name: warning_list + for name, warning_list in warning_lists.warninglists.items() + if warning_list.type == "cidr" +} + +print(f"Loaded {len(warning_lists)} warning lists in {datetime.now() - start_time}") + +random_ip_v4 = [f"{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}" for _ in range(100)] + +start_time = datetime.now() + +for ip in random_ip_v4: + warning_lists.search(ip) + +print(f"Searched for {len(random_ip_v4)} IPs in {len(warning_lists)} lists in {datetime.now() - start_time}")