[cidr] Use a decision tree for faster slow search
parent
4d2e31cd4c
commit
a0713f732c
|
@ -2,11 +2,14 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from contextlib import suppress
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_network, IPv6Address, IPv4Address, IPv4Network, IPv6Network, _BaseNetwork, \
|
||||||
|
AddressValueError, NetmaskValueError
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Dict, Any, List, Optional
|
from typing import Union, Dict, Any, List, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
@ -21,11 +24,15 @@ except ImportError:
|
||||||
HAS_JSONSCHEMA = False
|
HAS_JSONSCHEMA = False
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def json_default(obj: 'WarningList') -> Union[Dict, str]:
|
def json_default(obj: 'WarningList') -> Union[Dict, str]:
|
||||||
if isinstance(obj, WarningList):
|
if isinstance(obj, WarningList):
|
||||||
return obj.to_dict()
|
return obj.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WarningList():
|
class WarningList():
|
||||||
|
|
||||||
expected_types = ['string', 'substring', 'hostname', 'cidr', 'regex']
|
expected_types = ['string', 'substring', 'hostname', 'cidr', 'regex']
|
||||||
|
@ -44,13 +51,9 @@ class WarningList():
|
||||||
self.matching_attributes = self.warninglist['matching_attributes']
|
self.matching_attributes = self.warninglist['matching_attributes']
|
||||||
|
|
||||||
self.slow_search = slow_search
|
self.slow_search = slow_search
|
||||||
self._network_objects = []
|
|
||||||
|
|
||||||
if self.slow_search and self.type == 'cidr':
|
if self.slow_search and self.type == 'cidr':
|
||||||
self._network_objects = self._network_index()
|
self._ipv4_filter, self._ipv6_filter = compile_network_filters(self.list)
|
||||||
# If network objects is empty, reverting to default anyway
|
|
||||||
if not self._network_objects:
|
|
||||||
self.slow_search = False
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'<{self.__class__.__name__}(type="{self.name}", version="{self.version}", description="{self.description}")'
|
return f'<{self.__class__.__name__}(type="{self.name}", version="{self.version}", description="{self.description}")'
|
||||||
|
@ -74,16 +77,6 @@ class WarningList():
|
||||||
def _fast_search(self, value) -> bool:
|
def _fast_search(self, value) -> bool:
|
||||||
return value in self.set
|
return value in self.set
|
||||||
|
|
||||||
def _network_index(self) -> List:
|
|
||||||
to_return = []
|
|
||||||
for entry in self.list:
|
|
||||||
try:
|
|
||||||
# Try if the entry is a network bloc or an IP
|
|
||||||
to_return.append(ip_network(entry))
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
return to_return
|
|
||||||
|
|
||||||
def _slow_search(self, value: str) -> bool:
|
def _slow_search(self, value: str) -> bool:
|
||||||
if self.type == 'string':
|
if self.type == 'string':
|
||||||
# Exact match only, using fast search
|
# Exact match only, using fast search
|
||||||
|
@ -101,12 +94,14 @@ class WarningList():
|
||||||
value = parsed_url.hostname
|
value = parsed_url.hostname
|
||||||
return any(value == v or value.endswith("." + v.lstrip(".")) for v in self.list)
|
return any(value == v or value.endswith("." + v.lstrip(".")) for v in self.list)
|
||||||
elif self.type == 'cidr':
|
elif self.type == 'cidr':
|
||||||
try:
|
with suppress(AddressValueError, NetmaskValueError):
|
||||||
ip_value = ip_address(value)
|
ipv4 = IPv4Address(value)
|
||||||
except ValueError:
|
return int(ipv4) in self._ipv4_filter
|
||||||
|
with suppress(AddressValueError, NetmaskValueError):
|
||||||
|
ipv6 = IPv6Address(value)
|
||||||
|
return int(ipv6) in self._ipv6_filter
|
||||||
# The value to search isn't an IP address, falling back to default
|
# The value to search isn't an IP address, falling back to default
|
||||||
return self._fast_search(value)
|
return self._fast_search(value)
|
||||||
return any((ip_value == obj or ip_value in obj) for obj in self._network_objects)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@ -164,3 +159,70 @@ class WarningLists(Mapping):
|
||||||
|
|
||||||
def get_loaded_lists(self):
|
def get_loaded_lists(self):
|
||||||
return self.warninglists
|
return self.warninglists
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkFilter:
|
||||||
|
def __init__(self, digit_position: int, digit2filter: Optional[dict[int, Union[bool, "NetworkFilter"]]] = None):
|
||||||
|
self.digit2filter: dict[int, Union[bool, NetworkFilter]] = digit2filter or {0: False, 1: False}
|
||||||
|
self.digit_position = digit_position
|
||||||
|
|
||||||
|
def __contains__(self, ip: int) -> bool:
|
||||||
|
child = self.digit2filter[self._get_digit(ip)]
|
||||||
|
if isinstance(child, bool):
|
||||||
|
return child
|
||||||
|
|
||||||
|
return ip in child
|
||||||
|
|
||||||
|
def append(self, net: _BaseNetwork) -> None:
|
||||||
|
digit = self._get_digit(int(net.network_address))
|
||||||
|
|
||||||
|
if net.max_prefixlen - net.prefixlen == self.digit_position:
|
||||||
|
self.digit2filter[digit] = True
|
||||||
|
return
|
||||||
|
|
||||||
|
child = self.digit2filter[digit]
|
||||||
|
|
||||||
|
if child is False:
|
||||||
|
child = NetworkFilter(self.digit_position - 1)
|
||||||
|
self.digit2filter[digit] = child
|
||||||
|
|
||||||
|
if child is not True:
|
||||||
|
child.append(net)
|
||||||
|
|
||||||
|
def _get_digit(self, ip:int) -> int:
|
||||||
|
return (ip >> self.digit_position) & 1
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"NetworkFilter(digit_position={self.digit_position}, digit2filter={self.digit2filter})"
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, NetworkFilter) and self.digit_position == other.digit_position and self.digit2filter == other.digit2filter
|
||||||
|
|
||||||
|
|
||||||
|
def compile_network_filters(values: list) -> tuple[NetworkFilter, NetworkFilter]:
|
||||||
|
networks = convert_networks(values)
|
||||||
|
|
||||||
|
ipv4_filter = NetworkFilter(31)
|
||||||
|
ipv6_filter = NetworkFilter(127)
|
||||||
|
|
||||||
|
for net in networks:
|
||||||
|
root = ipv4_filter if isinstance(net, IPv4Network) else ipv6_filter
|
||||||
|
root.append(net)
|
||||||
|
|
||||||
|
return ipv4_filter, ipv6_filter
|
||||||
|
|
||||||
|
|
||||||
|
def convert_networks(values: list) -> list[_BaseNetwork]:
|
||||||
|
valid_ips = []
|
||||||
|
invalid_ips = []
|
||||||
|
|
||||||
|
for value in values:
|
||||||
|
try:
|
||||||
|
valid_ips.append(ip_network(value))
|
||||||
|
except ValueError:
|
||||||
|
invalid_ips.append(value)
|
||||||
|
|
||||||
|
if invalid_ips:
|
||||||
|
logger.warning(f'Invalid IPs found: {invalid_ips}')
|
||||||
|
|
||||||
|
return valid_ips
|
||||||
|
|
114
tests/tests.py
114
tests/tests.py
|
@ -6,8 +6,10 @@ import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
from ipaddress import IPv4Network
|
||||||
|
|
||||||
from pymispwarninglists import WarningLists, tools
|
from pymispwarninglists import WarningLists, tools, WarningList
|
||||||
|
from pymispwarninglists.api import compile_network_filters, NetworkFilter
|
||||||
|
|
||||||
|
|
||||||
class TestPyMISPWarningLists(unittest.TestCase):
|
class TestPyMISPWarningLists(unittest.TestCase):
|
||||||
|
@ -58,3 +60,113 @@ class TestPyMISPWarningLists(unittest.TestCase):
|
||||||
self.assertTrue(tools.get_xdg_home_dir().exists())
|
self.assertTrue(tools.get_xdg_home_dir().exists())
|
||||||
warninglists = WarningLists(from_xdg_home=True)
|
warninglists = WarningLists(from_xdg_home=True)
|
||||||
self.assertEqual(len(warninglists), len(self.warninglists))
|
self.assertEqual(len(warninglists), len(self.warninglists))
|
||||||
|
|
||||||
|
|
||||||
|
class TestCidrList(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
cls.cidr_list = WarningList(
|
||||||
|
{
|
||||||
|
"list": [
|
||||||
|
"1.1.1.1",
|
||||||
|
"51.8.152.0/21", "51.8.160.128/25",
|
||||||
|
"2a01:4180:4051::400",
|
||||||
|
"2a01:4180:c003:8::/61",
|
||||||
|
],
|
||||||
|
"description": "Test CIDR list",
|
||||||
|
"version": 0,
|
||||||
|
"name": "Test CIDR list",
|
||||||
|
"type": "cidr",
|
||||||
|
},
|
||||||
|
slow_search=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_exact_match(self):
|
||||||
|
assert "1.1.1.1" in self.cidr_list
|
||||||
|
assert "2a01:4180:4051::400" in self.cidr_list
|
||||||
|
|
||||||
|
assert "3.3.3.3" not in self.cidr_list
|
||||||
|
assert "2a01:4180:4051::401" not in self.cidr_list
|
||||||
|
|
||||||
|
def test_ipv4_bloc(self):
|
||||||
|
# 51.8.152.0/21
|
||||||
|
assert "51.8.152.0" in self.cidr_list
|
||||||
|
assert "51.8.152.255" in self.cidr_list
|
||||||
|
assert "51.8.153.0" in self.cidr_list
|
||||||
|
assert "51.8.159.255" in self.cidr_list
|
||||||
|
|
||||||
|
# outside
|
||||||
|
assert "51.8.151.0" not in self.cidr_list
|
||||||
|
assert "51.8.160.0" not in self.cidr_list
|
||||||
|
|
||||||
|
# 51.8.160.128/25
|
||||||
|
assert "51.8.160.128" in self.cidr_list
|
||||||
|
assert "51.8.160.255" in self.cidr_list
|
||||||
|
|
||||||
|
def test_ipv6_bloc(self):
|
||||||
|
assert "2a01:4180:c003:8::" in self.cidr_list
|
||||||
|
assert "2a01:4180:c003:8::1" in self.cidr_list
|
||||||
|
assert "2a01:4180:c003:8::ffff" in self.cidr_list
|
||||||
|
assert "2a01:4180:c003:9::1000" in self.cidr_list
|
||||||
|
|
||||||
|
assert "2a01:4180:c003:7::ffff" not in self.cidr_list
|
||||||
|
assert "2a01:4180:c003:10::" not in self.cidr_list
|
||||||
|
|
||||||
|
def test_search_for_ipv4_bloc(self):
|
||||||
|
assert "51.8.152.0/21" in self.cidr_list
|
||||||
|
|
||||||
|
assert "51.8.152.0/22" not in self.cidr_list
|
||||||
|
|
||||||
|
def test_search_for_ipv4_as_int(self):
|
||||||
|
# 51.8.152.0 as integer is 864710656
|
||||||
|
assert 856201216 in self.cidr_list
|
||||||
|
|
||||||
|
|
||||||
|
class TestNetworkCompilation(unittest.TestCase):
|
||||||
|
def test_simple_case(self):
|
||||||
|
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("160.0.0.0/3"), IPv4Network("192.0.0.0/2")])
|
||||||
|
|
||||||
|
assert ipv6_filter == NetworkFilter(127)
|
||||||
|
assert ipv4_filter == NetworkFilter(
|
||||||
|
digit_position=31,
|
||||||
|
digit2filter={
|
||||||
|
0: False,
|
||||||
|
1: NetworkFilter(
|
||||||
|
digit_position=30,
|
||||||
|
digit2filter={
|
||||||
|
0: NetworkFilter(
|
||||||
|
digit_position=29,
|
||||||
|
digit2filter={
|
||||||
|
0: False,
|
||||||
|
1: True
|
||||||
|
}
|
||||||
|
),
|
||||||
|
1: True
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
), ipv4_filter
|
||||||
|
|
||||||
|
def test_overwrite_with_bigger_network(self):
|
||||||
|
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("192.0.0.0/2"), IPv4Network("128.0.0.0/1")])
|
||||||
|
|
||||||
|
assert ipv6_filter == NetworkFilter(127)
|
||||||
|
assert ipv4_filter == NetworkFilter(
|
||||||
|
digit_position=31,
|
||||||
|
digit2filter={
|
||||||
|
0: False,
|
||||||
|
1: True,
|
||||||
|
},
|
||||||
|
), ipv4_filter
|
||||||
|
|
||||||
|
def test_dont_overwrite_with_smaller_network(self):
|
||||||
|
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("128.0.0.0/1"), IPv4Network("192.0.0.0/2")])
|
||||||
|
|
||||||
|
assert ipv6_filter == NetworkFilter(127)
|
||||||
|
assert ipv4_filter == NetworkFilter(
|
||||||
|
digit_position=31,
|
||||||
|
digit2filter={
|
||||||
|
0: False,
|
||||||
|
1: True,
|
||||||
|
},
|
||||||
|
), ipv4_filter
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pymispwarninglists import WarningLists
|
||||||
|
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
warning_lists = WarningLists(slow_search=True)
|
||||||
|
warning_lists.warninglists = {
|
||||||
|
name: warning_list
|
||||||
|
for name, warning_list in warning_lists.warninglists.items()
|
||||||
|
if warning_list.type == "cidr"
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Loaded {len(warning_lists)} warning lists in {datetime.now() - start_time}")
|
||||||
|
|
||||||
|
random_ip_v4 = [f"{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}" for _ in range(100)]
|
||||||
|
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
for ip in random_ip_v4:
|
||||||
|
warning_lists.search(ip)
|
||||||
|
|
||||||
|
print(f"Searched for {len(random_ip_v4)} IPs in {len(warning_lists)} lists in {datetime.now() - start_time}")
|
Loading…
Reference in New Issue