Merge branch 'Mathieu4141-cidr/faster-search'
commit
ad92de239c
|
@ -2,13 +2,16 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from collections.abc import Mapping
|
||||
from contextlib import suppress
|
||||
from glob import glob
|
||||
from ipaddress import ip_address, ip_network
|
||||
from ipaddress import ip_network, IPv6Address, IPv4Address, IPv4Network, _BaseNetwork, \
|
||||
AddressValueError, NetmaskValueError
|
||||
from pathlib import Path
|
||||
from typing import Union, Dict, Any, List, Optional
|
||||
from typing import Union, Dict, Any, List, Optional, Tuple, Sequence
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import tools
|
||||
|
@ -21,6 +24,9 @@ except ImportError:
|
|||
HAS_JSONSCHEMA = False
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def json_default(obj: 'WarningList') -> Union[Dict, str]:
|
||||
if isinstance(obj, WarningList):
|
||||
return obj.to_dict()
|
||||
|
@ -44,13 +50,9 @@ class WarningList():
|
|||
self.matching_attributes = self.warninglist['matching_attributes']
|
||||
|
||||
self.slow_search = slow_search
|
||||
self._network_objects = []
|
||||
|
||||
if self.slow_search and self.type == 'cidr':
|
||||
self._network_objects = self._network_index()
|
||||
# If network objects is empty, reverting to default anyway
|
||||
if not self._network_objects:
|
||||
self.slow_search = False
|
||||
self._ipv4_filter, self._ipv6_filter = compile_network_filters(self.list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__}(type="{self.name}", version="{self.version}", description="{self.description}")'
|
||||
|
@ -74,16 +76,6 @@ class WarningList():
|
|||
def _fast_search(self, value) -> bool:
|
||||
return value in self.set
|
||||
|
||||
def _network_index(self) -> List:
|
||||
to_return = []
|
||||
for entry in self.list:
|
||||
try:
|
||||
# Try if the entry is a network bloc or an IP
|
||||
to_return.append(ip_network(entry))
|
||||
except ValueError:
|
||||
pass
|
||||
return to_return
|
||||
|
||||
def _slow_search(self, value: str) -> bool:
|
||||
if self.type == 'string':
|
||||
# Exact match only, using fast search
|
||||
|
@ -101,12 +93,14 @@ class WarningList():
|
|||
value = parsed_url.hostname
|
||||
return any(value == v or value.endswith("." + v.lstrip(".")) for v in self.list)
|
||||
elif self.type == 'cidr':
|
||||
try:
|
||||
ip_value = ip_address(value)
|
||||
except ValueError:
|
||||
# The value to search isn't an IP address, falling back to default
|
||||
return self._fast_search(value)
|
||||
return any((ip_value == obj or ip_value in obj) for obj in self._network_objects)
|
||||
with suppress(AddressValueError, NetmaskValueError):
|
||||
ipv4 = IPv4Address(value)
|
||||
return int(ipv4) in self._ipv4_filter
|
||||
with suppress(AddressValueError, NetmaskValueError):
|
||||
ipv6 = IPv6Address(value)
|
||||
return int(ipv6) in self._ipv6_filter
|
||||
# The value to search isn't an IP address, falling back to default
|
||||
return self._fast_search(value)
|
||||
return False
|
||||
|
||||
|
||||
|
@ -164,3 +158,70 @@ class WarningLists(Mapping):
|
|||
|
||||
def get_loaded_lists(self):
|
||||
return self.warninglists
|
||||
|
||||
|
||||
class NetworkFilter:
|
||||
def __init__(self, digit_position: int, digit2filter: Optional[Dict[int, Union[bool, "NetworkFilter"]]] = None):
|
||||
self.digit2filter: Dict[int, Union[bool, NetworkFilter]] = digit2filter or {0: False, 1: False}
|
||||
self.digit_position = digit_position
|
||||
|
||||
def __contains__(self, ip: int) -> bool:
|
||||
child = self.digit2filter[self._get_digit(ip)]
|
||||
if isinstance(child, bool):
|
||||
return child
|
||||
|
||||
return ip in child
|
||||
|
||||
def append(self, net: _BaseNetwork) -> None:
|
||||
digit = self._get_digit(int(net.network_address))
|
||||
|
||||
if net.max_prefixlen - net.prefixlen == self.digit_position:
|
||||
self.digit2filter[digit] = True
|
||||
return
|
||||
|
||||
child = self.digit2filter[digit]
|
||||
|
||||
if child is False:
|
||||
child = NetworkFilter(self.digit_position - 1)
|
||||
self.digit2filter[digit] = child
|
||||
|
||||
if child is not True:
|
||||
child.append(net)
|
||||
|
||||
def _get_digit(self, ip: int) -> int:
|
||||
return (ip >> self.digit_position) & 1
|
||||
|
||||
def __repr__(self):
|
||||
return f"NetworkFilter(digit_position={self.digit_position}, digit2filter={self.digit2filter})"
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, NetworkFilter) and self.digit_position == other.digit_position and self.digit2filter == other.digit2filter
|
||||
|
||||
|
||||
def compile_network_filters(values: list) -> Tuple[NetworkFilter, NetworkFilter]:
|
||||
networks = convert_networks(values)
|
||||
|
||||
ipv4_filter = NetworkFilter(31)
|
||||
ipv6_filter = NetworkFilter(127)
|
||||
|
||||
for net in networks:
|
||||
root = ipv4_filter if isinstance(net, IPv4Network) else ipv6_filter
|
||||
root.append(net)
|
||||
|
||||
return ipv4_filter, ipv6_filter
|
||||
|
||||
|
||||
def convert_networks(values: list) -> Sequence[_BaseNetwork]:
|
||||
valid_ips = []
|
||||
invalid_ips = []
|
||||
|
||||
for value in values:
|
||||
try:
|
||||
valid_ips.append(ip_network(value))
|
||||
except ValueError:
|
||||
invalid_ips.append(value)
|
||||
|
||||
if invalid_ips:
|
||||
logger.warning(f'Invalid IPs found: {invalid_ips}')
|
||||
|
||||
return valid_ips
|
||||
|
|
117
tests/tests.py
117
tests/tests.py
|
@ -6,8 +6,10 @@ import os
|
|||
import unittest
|
||||
|
||||
from glob import glob
|
||||
from ipaddress import IPv4Network
|
||||
|
||||
from pymispwarninglists import WarningLists, tools
|
||||
from pymispwarninglists import WarningLists, tools, WarningList
|
||||
from pymispwarninglists.api import compile_network_filters, NetworkFilter
|
||||
|
||||
|
||||
class TestPyMISPWarningLists(unittest.TestCase):
|
||||
|
@ -58,3 +60,116 @@ class TestPyMISPWarningLists(unittest.TestCase):
|
|||
self.assertTrue(tools.get_xdg_home_dir().exists())
|
||||
warninglists = WarningLists(from_xdg_home=True)
|
||||
self.assertEqual(len(warninglists), len(self.warninglists))
|
||||
|
||||
|
||||
class TestCidrList(unittest.TestCase):
|
||||
|
||||
cidr_list: WarningList
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.cidr_list = WarningList(
|
||||
{
|
||||
"list": [
|
||||
"1.1.1.1",
|
||||
"51.8.152.0/21", "51.8.160.128/25",
|
||||
"2a01:4180:4051::400",
|
||||
"2a01:4180:c003:8::/61",
|
||||
],
|
||||
"description": "Test CIDR list",
|
||||
"version": 0,
|
||||
"name": "Test CIDR list",
|
||||
"type": "cidr",
|
||||
},
|
||||
slow_search=True
|
||||
)
|
||||
|
||||
def test_exact_match(self):
|
||||
assert "1.1.1.1" in self.cidr_list
|
||||
assert "2a01:4180:4051::400" in self.cidr_list
|
||||
|
||||
assert "3.3.3.3" not in self.cidr_list
|
||||
assert "2a01:4180:4051::401" not in self.cidr_list
|
||||
|
||||
def test_ipv4_bloc(self):
|
||||
# 51.8.152.0/21
|
||||
assert "51.8.152.0" in self.cidr_list
|
||||
assert "51.8.152.255" in self.cidr_list
|
||||
assert "51.8.153.0" in self.cidr_list
|
||||
assert "51.8.159.255" in self.cidr_list
|
||||
|
||||
# outside
|
||||
assert "51.8.151.0" not in self.cidr_list
|
||||
assert "51.8.160.0" not in self.cidr_list
|
||||
|
||||
# 51.8.160.128/25
|
||||
assert "51.8.160.128" in self.cidr_list
|
||||
assert "51.8.160.255" in self.cidr_list
|
||||
|
||||
def test_ipv6_bloc(self):
|
||||
assert "2a01:4180:c003:8::" in self.cidr_list
|
||||
assert "2a01:4180:c003:8::1" in self.cidr_list
|
||||
assert "2a01:4180:c003:8::ffff" in self.cidr_list
|
||||
assert "2a01:4180:c003:9::1000" in self.cidr_list
|
||||
|
||||
assert "2a01:4180:c003:7::ffff" not in self.cidr_list
|
||||
assert "2a01:4180:c003:10::" not in self.cidr_list
|
||||
|
||||
def test_search_for_ipv4_bloc(self):
|
||||
assert "51.8.152.0/21" in self.cidr_list
|
||||
|
||||
assert "51.8.152.0/22" not in self.cidr_list
|
||||
|
||||
def test_search_for_ipv4_as_int(self):
|
||||
# 51.8.152.0 as integer is 864710656
|
||||
assert 856201216 in self.cidr_list
|
||||
|
||||
|
||||
class TestNetworkCompilation(unittest.TestCase):
|
||||
def test_simple_case(self):
|
||||
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("160.0.0.0/3"), IPv4Network("192.0.0.0/2")])
|
||||
|
||||
assert ipv6_filter == NetworkFilter(127)
|
||||
assert ipv4_filter == NetworkFilter(
|
||||
digit_position=31,
|
||||
digit2filter={
|
||||
0: False,
|
||||
1: NetworkFilter(
|
||||
digit_position=30,
|
||||
digit2filter={
|
||||
0: NetworkFilter(
|
||||
digit_position=29,
|
||||
digit2filter={
|
||||
0: False,
|
||||
1: True
|
||||
}
|
||||
),
|
||||
1: True
|
||||
},
|
||||
),
|
||||
},
|
||||
), ipv4_filter
|
||||
|
||||
def test_overwrite_with_bigger_network(self):
|
||||
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("192.0.0.0/2"), IPv4Network("128.0.0.0/1")])
|
||||
|
||||
assert ipv6_filter == NetworkFilter(127)
|
||||
assert ipv4_filter == NetworkFilter(
|
||||
digit_position=31,
|
||||
digit2filter={
|
||||
0: False,
|
||||
1: True,
|
||||
},
|
||||
), ipv4_filter
|
||||
|
||||
def test_dont_overwrite_with_smaller_network(self):
|
||||
ipv4_filter, ipv6_filter = compile_network_filters([IPv4Network("128.0.0.0/1"), IPv4Network("192.0.0.0/2")])
|
||||
|
||||
assert ipv6_filter == NetworkFilter(127)
|
||||
assert ipv4_filter == NetworkFilter(
|
||||
digit_position=31,
|
||||
digit2filter={
|
||||
0: False,
|
||||
1: True,
|
||||
},
|
||||
), ipv4_filter
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
import random
|
||||
from datetime import datetime
|
||||
|
||||
from pymispwarninglists import WarningLists
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
warning_lists = WarningLists(slow_search=True)
|
||||
warning_lists.warninglists = {
|
||||
name: warning_list
|
||||
for name, warning_list in warning_lists.warninglists.items()
|
||||
if warning_list.type == "cidr"
|
||||
}
|
||||
|
||||
print(f"Loaded {len(warning_lists)} warning lists in {datetime.now() - start_time}")
|
||||
|
||||
random_ip_v4 = [f"{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(0,255)}" for _ in range(100)]
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
for ip in random_ip_v4:
|
||||
warning_lists.search(ip)
|
||||
|
||||
print(f"Searched for {len(random_ip_v4)} IPs in {len(warning_lists)} lists in {datetime.now() - start_time}")
|
Loading…
Reference in New Issue