[cidr] Use a decision tree for faster slow search

pull/22/head
Mathieu Beligon 2023-06-22 17:27:37 +02:00
parent 4d2e31cd4c
commit a0713f732c
3 changed files with 221 additions and 23 deletions

View File

@ -2,11 +2,14 @@
# -*- coding: utf-8 -*-
import json
import logging
import sys
from collections.abc import Mapping
from contextlib import suppress
from glob import glob
from ipaddress import ip_address, ip_network
from ipaddress import ip_network, IPv6Address, IPv4Address, IPv4Network, IPv6Network, _BaseNetwork, \
AddressValueError, NetmaskValueError
from pathlib import Path
from typing import Union, Dict, Any, List, Optional
from urllib.parse import urlparse
@ -21,11 +24,15 @@ except ImportError:
HAS_JSONSCHEMA = False
logger = logging.getLogger(__name__)
def json_default(obj: 'WarningList') -> Union[Dict, str]:
if isinstance(obj, WarningList):
return obj.to_dict()
class WarningList():
expected_types = ['string', 'substring', 'hostname', 'cidr', 'regex']
@ -44,13 +51,9 @@ class WarningList():
self.matching_attributes = self.warninglist['matching_attributes']
self.slow_search = slow_search
self._network_objects = []
if self.slow_search and self.type == 'cidr':
self._network_objects = self._network_index()
# If network objects is empty, reverting to default anyway
if not self._network_objects:
self.slow_search = False
self._ipv4_filter, self._ipv6_filter = compile_network_filters(self.list)
def __repr__(self) -> str:
return f'<{self.__class__.__name__}(type="{self.name}", version="{self.version}", description="{self.description}")'
@ -74,16 +77,6 @@ class WarningList():
def _fast_search(self, value) -> bool:
return value in self.set
def _network_index(self) -> List:
to_return = []
for entry in self.list:
try:
# Try if the entry is a network bloc or an IP
to_return.append(ip_network(entry))
except ValueError:
pass
return to_return
def _slow_search(self, value: str) -> bool:
if self.type == 'string':
# Exact match only, using fast search
@ -101,12 +94,14 @@ class WarningList():
value = parsed_url.hostname
return any(value == v or value.endswith("." + v.lstrip(".")) for v in self.list)
elif self.type == 'cidr':
try:
ip_value = ip_address(value)
except ValueError:
# The value to search isn't an IP address, falling back to default
return self._fast_search(value)
return any((ip_value == obj or ip_value in obj) for obj in self._network_objects)
with suppress(AddressValueError, NetmaskValueError):
ipv4 = IPv4Address(value)
return int(ipv4) in self._ipv4_filter
with suppress(AddressValueError, NetmaskValueError):
ipv6 = IPv6Address(value)
return int(ipv6) in self._ipv6_filter
# The value to search isn't an IP address, falling back to default
return self._fast_search(value)
return False
@ -164,3 +159,70 @@ class WarningLists(Mapping):
def get_loaded_lists(self):
return self.warninglists
class NetworkFilter:
def __init__(self, digit_position: int, digit2filter: Optional[dict[int, Union[bool, "NetworkFilter"]]] = None):
self.digit2filter: dict[int, Union[bool, NetworkFilter]] = digit2filter or {0: False, 1: False}
self.digit_position = digit_position
def __contains__(self, ip: int) -> bool:
child = self.digit2filter[self._get_digit(ip)]
if isinstance(child, bool):
return child
return ip in child
def append(self, net: _BaseNetwork) -> None:
digit = self._get_digit(int(net.network_address))
if net.max_prefixlen - net.prefixlen == self.digit_position:
self.digit2filter[digit] = True
return
child = self.digit2filter[digit]
if child is False:
child = NetworkFilter(self.digit_position - 1)
self.digit2filter[digit] = child
if child is not True:
child.append(net)
def _get_digit(self, ip:int) -> int:
return (ip >> self.digit_position) & 1
def __repr__(self):
return f"NetworkFilter(digit_position={self.digit_position}, digit2filter={self.digit2filter})"
def __eq__(self, other):
return isinstance(other, NetworkFilter) and self.digit_position == other.digit_position and self.digit2filter == other.digit2filter
def compile_network_filters(values: list) -> tuple[NetworkFilter, NetworkFilter]:
networks = convert_networks(values)
ipv4_filter = NetworkFilter(31)
ipv6_filter = NetworkFilter(127)
for net in networks:
root = ipv4_filter if isinstance(net, IPv4Network) else ipv6_filter
root.append(net)
return ipv4_filter, ipv6_filter
def convert_networks(values: list) -> list[_BaseNetwork]:
valid_ips = []
invalid_ips = []
for value in values:
try:
valid_ips.append(ip_network(value))
except ValueError:
invalid_ips.append(value)
if invalid_ips:
logger.warning(f'Invalid IPs found: {invalid_ips}')
return valid_ips

View File

@ -6,8 +6,10 @@ import os
import unittest
from glob import glob
from ipaddress import IPv4Network
from pymispwarninglists import WarningLists, tools
from pymispwarninglists import WarningLists, tools, WarningList
from pymispwarninglists.api import compile_network_filters, NetworkFilter
class TestPyMISPWarningLists(unittest.TestCase):
@ -58,3 +60,113 @@ class TestPyMISPWarningLists(unittest.TestCase):
self.assertTrue(tools.get_xdg_home_dir().exists())
warninglists = WarningLists(from_xdg_home=True)
self.assertEqual(len(warninglists), len(self.warninglists))
class TestCidrList(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.cidr_list = WarningList(
{
"list": [
"1.1.1.1",
"51.8.152.0/21", "51.8.160.128/25",
"2a01:4180:4051::400",
"2a01:4180:c003:8::/61",
],
"description": "Test CIDR list",
"version": 0,
"name": "Test CIDR list",
"type": "cidr",
},
slow_search=True
)
def test_exact_match(self):
assert "1.1.1.1" in self.cidr_list
assert "2a01:4180:4051::400" in self.cidr_list
assert "3.3.3.3" not in self.cidr_list
assert "2a01:4180:4051::401" not in self.cidr_list
def test_ipv4_bloc(self):
# 51.8.152.0/21
assert "51.8.152.0" in self.cidr_list
assert "51.8.152.255" in self.cidr_list
assert "51.8.153.0" in self.cidr_list
assert "51.8.159.255" in self.cidr_list
# outside
assert "51.8.151.0" not in self.cidr_list
assert "51.8.160.0" not in self.cidr_list
# 51.8.160.128/25
assert "51.8.160.128" in self.cidr_list
assert "51.8.160.255" in self.cidr_list
def test_ipv6_bloc(self):
assert "2a01:4180:c003:8::" in self.cidr_list
assert "2a01:4180:c003:8::1" in self.cidr_list
assert "2a01:4180:c003:8::ffff" in self.cidr_list
assert "2a01:4180:c003:9::1000" in self.cidr_list
assert "2a01:4180:c003:7::ffff" not in self.cidr_list
assert "2a01:4180:c003:10::" not in self.cidr_list
def test_search_for_ipv4_bloc(self):
assert "51.8.152.0/21" in self.cidr_list
assert "51.8.152.0/22" not in self.cidr_list
def test_search_for_ipv4_as_int(self):
# 51.8.152.0 as integer is 864710656
assert 856201216 in self.cidr_list
class TestNetworkCompilation(unittest.TestCase):
def test_simple_case(self):
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("160.0.0.0/3"), IPv4Network("192.0.0.0/2")])
assert ipv6_filter == NetworkFilter(127)
assert ipv4_filter == NetworkFilter(
digit_position=31,
digit2filter={
0: False,
1: NetworkFilter(
digit_position=30,
digit2filter={
0: NetworkFilter(
digit_position=29,
digit2filter={
0: False,
1: True
}
),
1: True
},
),
},
), ipv4_filter
def test_overwrite_with_bigger_network(self):
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("192.0.0.0/2"), IPv4Network("128.0.0.0/1")])
assert ipv6_filter == NetworkFilter(127)
assert ipv4_filter == NetworkFilter(
digit_position=31,
digit2filter={
0: False,
1: True,
},
), ipv4_filter
def test_dont_overwrite_with_smaller_network(self):
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("128.0.0.0/1"), IPv4Network("192.0.0.0/2")])
assert ipv6_filter == NetworkFilter(127)
assert ipv4_filter == NetworkFilter(
digit_position=31,
digit2filter={
0: False,
1: True,
},
), ipv4_filter

24
tests/time_cidr_search.py Normal file
View File

@ -0,0 +1,24 @@
import random
from datetime import datetime
from pymispwarninglists import WarningLists
start_time = datetime.now()
warning_lists = WarningLists(slow_search=True)
warning_lists.warninglists = {
name: warning_list
for name, warning_list in warning_lists.warninglists.items()
if warning_list.type == "cidr"
}
print(f"Loaded {len(warning_lists)} warning lists in {datetime.now() - start_time}")
random_ip_v4 = [f"{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}" for _ in range(100)]
start_time = datetime.now()
for ip in random_ip_v4:
warning_lists.search(ip)
print(f"Searched for {len(random_ip_v4)} IPs in {len(warning_lists)} lists in {datetime.now() - start_time}")