[cidr] Use a decision tree for faster slow search

pull/22/head
Mathieu Beligon 2023-06-22 17:27:37 +02:00
parent 4d2e31cd4c
commit a0713f732c
3 changed files with 221 additions and 23 deletions

View File

@ -2,11 +2,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
import logging
import sys import sys
from collections.abc import Mapping from collections.abc import Mapping
from contextlib import suppress
from glob import glob from glob import glob
from ipaddress import ip_address, ip_network from ipaddress import ip_network, IPv6Address, IPv4Address, IPv4Network, IPv6Network, _BaseNetwork, \
AddressValueError, NetmaskValueError
from pathlib import Path from pathlib import Path
from typing import Union, Dict, Any, List, Optional from typing import Union, Dict, Any, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
@ -21,11 +24,15 @@ except ImportError:
HAS_JSONSCHEMA = False HAS_JSONSCHEMA = False
logger = logging.getLogger(__name__)
def json_default(obj: 'WarningList') -> Union[Dict, str]: def json_default(obj: 'WarningList') -> Union[Dict, str]:
if isinstance(obj, WarningList): if isinstance(obj, WarningList):
return obj.to_dict() return obj.to_dict()
class WarningList(): class WarningList():
expected_types = ['string', 'substring', 'hostname', 'cidr', 'regex'] expected_types = ['string', 'substring', 'hostname', 'cidr', 'regex']
@ -44,13 +51,9 @@ class WarningList():
self.matching_attributes = self.warninglist['matching_attributes'] self.matching_attributes = self.warninglist['matching_attributes']
self.slow_search = slow_search self.slow_search = slow_search
self._network_objects = []
if self.slow_search and self.type == 'cidr': if self.slow_search and self.type == 'cidr':
self._network_objects = self._network_index() self._ipv4_filter, self._ipv6_filter = compile_network_filters(self.list)
# If network objects is empty, reverting to default anyway
if not self._network_objects:
self.slow_search = False
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__}(type="{self.name}", version="{self.version}", description="{self.description}")' return f'<{self.__class__.__name__}(type="{self.name}", version="{self.version}", description="{self.description}")'
@ -74,16 +77,6 @@ class WarningList():
def _fast_search(self, value) -> bool: def _fast_search(self, value) -> bool:
return value in self.set return value in self.set
def _network_index(self) -> List:
to_return = []
for entry in self.list:
try:
# Try if the entry is a network bloc or an IP
to_return.append(ip_network(entry))
except ValueError:
pass
return to_return
def _slow_search(self, value: str) -> bool: def _slow_search(self, value: str) -> bool:
if self.type == 'string': if self.type == 'string':
# Exact match only, using fast search # Exact match only, using fast search
@ -101,12 +94,14 @@ class WarningList():
value = parsed_url.hostname value = parsed_url.hostname
return any(value == v or value.endswith("." + v.lstrip(".")) for v in self.list) return any(value == v or value.endswith("." + v.lstrip(".")) for v in self.list)
elif self.type == 'cidr': elif self.type == 'cidr':
try: with suppress(AddressValueError, NetmaskValueError):
ip_value = ip_address(value) ipv4 = IPv4Address(value)
except ValueError: return int(ipv4) in self._ipv4_filter
with suppress(AddressValueError, NetmaskValueError):
ipv6 = IPv6Address(value)
return int(ipv6) in self._ipv6_filter
# The value to search isn't an IP address, falling back to default # The value to search isn't an IP address, falling back to default
return self._fast_search(value) return self._fast_search(value)
return any((ip_value == obj or ip_value in obj) for obj in self._network_objects)
return False return False
@ -164,3 +159,70 @@ class WarningLists(Mapping):
def get_loaded_lists(self): def get_loaded_lists(self):
return self.warninglists return self.warninglists
class NetworkFilter:
def __init__(self, digit_position: int, digit2filter: Optional[dict[int, Union[bool, "NetworkFilter"]]] = None):
self.digit2filter: dict[int, Union[bool, NetworkFilter]] = digit2filter or {0: False, 1: False}
self.digit_position = digit_position
def __contains__(self, ip: int) -> bool:
child = self.digit2filter[self._get_digit(ip)]
if isinstance(child, bool):
return child
return ip in child
def append(self, net: _BaseNetwork) -> None:
digit = self._get_digit(int(net.network_address))
if net.max_prefixlen - net.prefixlen == self.digit_position:
self.digit2filter[digit] = True
return
child = self.digit2filter[digit]
if child is False:
child = NetworkFilter(self.digit_position - 1)
self.digit2filter[digit] = child
if child is not True:
child.append(net)
def _get_digit(self, ip:int) -> int:
return (ip >> self.digit_position) & 1
def __repr__(self):
return f"NetworkFilter(digit_position={self.digit_position}, digit2filter={self.digit2filter})"
def __eq__(self, other):
return isinstance(other, NetworkFilter) and self.digit_position == other.digit_position and self.digit2filter == other.digit2filter
def compile_network_filters(values: list) -> tuple[NetworkFilter, NetworkFilter]:
networks = convert_networks(values)
ipv4_filter = NetworkFilter(31)
ipv6_filter = NetworkFilter(127)
for net in networks:
root = ipv4_filter if isinstance(net, IPv4Network) else ipv6_filter
root.append(net)
return ipv4_filter, ipv6_filter
def convert_networks(values: list) -> list[_BaseNetwork]:
valid_ips = []
invalid_ips = []
for value in values:
try:
valid_ips.append(ip_network(value))
except ValueError:
invalid_ips.append(value)
if invalid_ips:
logger.warning(f'Invalid IPs found: {invalid_ips}')
return valid_ips

View File

@ -6,8 +6,10 @@ import os
import unittest import unittest
from glob import glob from glob import glob
from ipaddress import IPv4Network
from pymispwarninglists import WarningLists, tools from pymispwarninglists import WarningLists, tools, WarningList
from pymispwarninglists.api import compile_network_filters, NetworkFilter
class TestPyMISPWarningLists(unittest.TestCase): class TestPyMISPWarningLists(unittest.TestCase):
@ -58,3 +60,113 @@ class TestPyMISPWarningLists(unittest.TestCase):
self.assertTrue(tools.get_xdg_home_dir().exists()) self.assertTrue(tools.get_xdg_home_dir().exists())
warninglists = WarningLists(from_xdg_home=True) warninglists = WarningLists(from_xdg_home=True)
self.assertEqual(len(warninglists), len(self.warninglists)) self.assertEqual(len(warninglists), len(self.warninglists))
class TestCidrList(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.cidr_list = WarningList(
{
"list": [
"1.1.1.1",
"51.8.152.0/21", "51.8.160.128/25",
"2a01:4180:4051::400",
"2a01:4180:c003:8::/61",
],
"description": "Test CIDR list",
"version": 0,
"name": "Test CIDR list",
"type": "cidr",
},
slow_search=True
)
def test_exact_match(self):
assert "1.1.1.1" in self.cidr_list
assert "2a01:4180:4051::400" in self.cidr_list
assert "3.3.3.3" not in self.cidr_list
assert "2a01:4180:4051::401" not in self.cidr_list
def test_ipv4_bloc(self):
# 51.8.152.0/21
assert "51.8.152.0" in self.cidr_list
assert "51.8.152.255" in self.cidr_list
assert "51.8.153.0" in self.cidr_list
assert "51.8.159.255" in self.cidr_list
# outside
assert "51.8.151.0" not in self.cidr_list
assert "51.8.160.0" not in self.cidr_list
# 51.8.160.128/25
assert "51.8.160.128" in self.cidr_list
assert "51.8.160.255" in self.cidr_list
def test_ipv6_bloc(self):
assert "2a01:4180:c003:8::" in self.cidr_list
assert "2a01:4180:c003:8::1" in self.cidr_list
assert "2a01:4180:c003:8::ffff" in self.cidr_list
assert "2a01:4180:c003:9::1000" in self.cidr_list
assert "2a01:4180:c003:7::ffff" not in self.cidr_list
assert "2a01:4180:c003:10::" not in self.cidr_list
def test_search_for_ipv4_bloc(self):
assert "51.8.152.0/21" in self.cidr_list
assert "51.8.152.0/22" not in self.cidr_list
def test_search_for_ipv4_as_int(self):
# 51.8.152.0 as integer is 864710656
assert 856201216 in self.cidr_list
class TestNetworkCompilation(unittest.TestCase):
def test_simple_case(self):
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("160.0.0.0/3"), IPv4Network("192.0.0.0/2")])
assert ipv6_filter == NetworkFilter(127)
assert ipv4_filter == NetworkFilter(
digit_position=31,
digit2filter={
0: False,
1: NetworkFilter(
digit_position=30,
digit2filter={
0: NetworkFilter(
digit_position=29,
digit2filter={
0: False,
1: True
}
),
1: True
},
),
},
), ipv4_filter
def test_overwrite_with_bigger_network(self):
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("192.0.0.0/2"), IPv4Network("128.0.0.0/1")])
assert ipv6_filter == NetworkFilter(127)
assert ipv4_filter == NetworkFilter(
digit_position=31,
digit2filter={
0: False,
1: True,
},
), ipv4_filter
def test_dont_overwrite_with_smaller_network(self):
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("128.0.0.0/1"), IPv4Network("192.0.0.0/2")])
assert ipv6_filter == NetworkFilter(127)
assert ipv4_filter == NetworkFilter(
digit_position=31,
digit2filter={
0: False,
1: True,
},
), ipv4_filter

24
tests/time_cidr_search.py Normal file
View File

@ -0,0 +1,24 @@
import random
from datetime import datetime
from pymispwarninglists import WarningLists
start_time = datetime.now()
warning_lists = WarningLists(slow_search=True)
warning_lists.warninglists = {
name: warning_list
for name, warning_list in warning_lists.warninglists.items()
if warning_list.type == "cidr"
}
print(f"Loaded {len(warning_lists)} warning lists in {datetime.now() - start_time}")
random_ip_v4 = [f"{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}" for _ in range(100)]
start_time = datetime.now()
for ip in random_ip_v4:
warning_lists.search(ip)
print(f"Searched for {len(random_ip_v4)} IPs in {len(warning_lists)} lists in {datetime.now() - start_time}")