PyMISPWarningLists/pymispwarninglists/api.py

228 lines
8.3 KiB
Python
Raw Normal View History

2017-10-29 01:40:41 +02:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import logging
2017-10-29 01:40:41 +02:00
import sys
2022-01-13 11:22:05 +01:00
from collections.abc import Mapping
from contextlib import suppress
2017-10-29 01:40:41 +02:00
from glob import glob
from ipaddress import ip_network, IPv6Address, IPv4Address, IPv4Network, _BaseNetwork, \
AddressValueError, NetmaskValueError
2019-07-25 17:11:52 +02:00
from pathlib import Path
from typing import Union, Dict, Any, List, Optional, Tuple, Sequence
from urllib.parse import urlparse
from . import tools
from .exceptions import PyMISPWarningListsError
2017-10-29 01:40:41 +02:00
# jsonschema is an optional dependency; it is used by WarningLists.validate_with_schema().
try:
    import jsonschema  # type: ignore
except ImportError:
    HAS_JSONSCHEMA = False
else:
    HAS_JSONSCHEMA = True

logger = logging.getLogger(__name__)

def json_default(obj: Any) -> Dict[str, Any]:
    """``default`` hook for :func:`json.dumps` that serialises WarningList objects.

    :obj: The object ``json.dumps`` could not encode natively.
    :returns: ``obj.to_dict()`` for a WarningList.
    :raises TypeError: for any other object, as the ``default=`` contract of
        :func:`json.dumps` requires.
    """
    if isinstance(obj, WarningList):
        return obj.to_dict()
    # Falling through (returning None) would make json.dumps silently write
    # ``null`` for an unsupported object instead of reporting it.
    raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

class WarningList():
    """One MISP warning list and the logic to test values against it.

    By default a value matches only if it is exactly one of the list entries
    (set lookup). With ``slow_search`` enabled the match depends on the list
    ``type``; see :meth:`_slow_search`.
    """

    # Values accepted for the ``type`` key of a warning list.
    expected_types = ['string', 'substring', 'hostname', 'cidr', 'regex']

    def __init__(self, warninglist: Dict[str, Any], slow_search: bool = False):
        """Wrap a parsed warning list.

        :warninglist: The warning list as a dict; reads ``list``, ``description``,
            ``version``, ``name``, ``type`` and, when present and non-empty,
            ``matching_attributes``.
        :slow_search: If true, ``in`` uses the type-specific matching of
            :meth:`_slow_search` instead of exact match.
        :raises PyMISPWarningListsError: if ``type`` is not one of :attr:`expected_types`.
        """
        self.warninglist = warninglist
        self.list = self.warninglist['list']
        # Set copy of the entries for O(1) exact-match lookups.
        self.set = set(self.list)

        self.description = self.warninglist['description']
        self.version = int(self.warninglist['version'])
        self.name = self.warninglist['name']
        if self.warninglist['type'] not in self.expected_types:
            raise PyMISPWarningListsError(f'Unexpected type ({self.warninglist["type"]}), please update the expected_type list')
        self.type = self.warninglist['type']
        if self.warninglist.get('matching_attributes'):
            self.matching_attributes = self.warninglist['matching_attributes']

        self.slow_search = slow_search
        if self.slow_search and self.type == 'cidr':
            # Build the IPv4/IPv6 network filters once instead of on every lookup.
            self._ipv4_filter, self._ipv6_filter = compile_network_filters(self.list)

    def __repr__(self) -> str:
        return f'<{self.__class__.__name__}(name="{self.name}", type="{self.type}", version="{self.version}", description="{self.description}")>'

    def __contains__(self, value: str) -> bool:
        """Return True if *value* matches this list (exact or type-specific, see ``slow_search``)."""
        if self.slow_search:
            return self._slow_search(value)
        return self._fast_search(value)

    def to_dict(self) -> Dict:
        """Return a JSON-compatible dict; entries are stringified, ``matching_attributes`` only if set."""
        to_return = {'list': [str(e) for e in self.list], 'name': self.name,
                     'description': self.description, 'version': self.version,
                     'type': self.type}
        if hasattr(self, 'matching_attributes'):
            to_return['matching_attributes'] = self.matching_attributes
        return to_return

    def to_json(self) -> str:
        """Serialise this list to a JSON string via ``json_default``."""
        return json.dumps(self, default=json_default)

    def _fast_search(self, value: Any) -> bool:
        """Exact-match lookup in the entry set."""
        return value in self.set

    def _slow_search(self, value: str) -> bool:
        """Type-specific match of *value* against the entries.

        - ``string``: exact match.
        - ``substring``: any entry occurs inside *value*.
        - ``hostname``: the hostname (taken from *value* when it parses as a URL)
          equals an entry or is a subdomain of it, compared case-insensitively.
        - ``cidr``: *value* is an IP address inside one of the listed networks;
          non-IP values fall back to exact match.
        - anything else (``regex``): exact match.
        """
        if self.type == 'string':
            # Exact match only, using fast search
            return self._fast_search(value)
        elif self.type == 'substring':
            # Expected to match on a part of the value
            # i.e.: value = 'blah.de' self.list == ['.fr', '.de']
            return any(v in value for v in self.list)
        elif self.type == 'hostname':
            # Expected to match on hostnames in URLs (i.e. the search query is a URL)
            # So we do a reverse search if any of the entries in the list are present in the URL
            # i.e.: value = 'http://foo.blah.de/meh' self.list == ['blah.de', 'blah.fr']
            parsed_url = urlparse(value)
            if parsed_url.hostname:
                value = parsed_url.hostname
            # Hostnames are case-insensitive; urlparse().hostname is already
            # lowercased, so normalise bare hostnames and entries the same way.
            value = value.lower()
            return any(value == entry or value.endswith("." + entry.lstrip("."))
                       for entry in (v.lower() for v in self.list))
        elif self.type == 'cidr':
            with suppress(AddressValueError, NetmaskValueError):
                return int(IPv4Address(value)) in self._ipv4_filter
            with suppress(AddressValueError, NetmaskValueError):
                return int(IPv6Address(value)) in self._ipv6_filter
            # The value to search isn't an IP address, falling back to default
            return self._fast_search(value)
        # No type-specific matcher ('regex'): fall back to exact match, as the
        # cidr branch does for non-IP values, rather than never matching.
        return self._fast_search(value)


class WarningLists(Mapping):
    """Read-only mapping from warning-list name to :class:`WarningList`."""

    def __init__(self, slow_search: bool = False, lists: Optional[List] = None,
                 from_xdg_home: bool = False, path_to_repo: Optional[Union[Path, str]] = None):
        """Load the warning lists.

        :slow_search: If true, uses the most appropriate search method. Can be slower. Default: exact match.
        :lists: A list of warning lists (typically fetched from a MISP instance). When empty or
            missing, the lists are read from ``<repo>/lists/*/list.json`` on disk instead.
        :from_xdg_home: If true (and *lists* is not given), read from the directory returned by
            ``tools.get_xdg_home_dir()``, calling ``tools.update_warninglists()`` first when that
            directory does not exist yet.
        :path_to_repo: Root of a misp-warninglists checkout (``Path`` or ``str``). When unset or
            missing on disk, the copy under the package's ``data/misp-warninglists`` is used.
        :raises PyMISPWarningListsError: if no list could be loaded from disk.
        """
        if not lists:
            if path_to_repo:
                # Accept plain strings as well as Path objects.
                path_to_repo = Path(path_to_repo)
            if from_xdg_home:
                path_to_repo = tools.get_xdg_home_dir()
                if not path_to_repo.exists():
                    tools.update_warninglists()
            if not path_to_repo or not path_to_repo.exists():
                path_to_repo = Path(sys.modules['pymispwarninglists'].__file__).parent / 'data' / 'misp-warninglists'  # type: ignore
            lists = []
            self.root_dir_warninglists = path_to_repo / 'lists'
            # Sorted so the load order (and so which list wins on a duplicate name)
            # does not depend on the filesystem's directory ordering.
            for warninglist_file in sorted(glob(str(self.root_dir_warninglists / '*' / 'list.json'))):
                with open(warninglist_file, mode='r', encoding="utf-8") as f:
                    lists.append(json.load(f))
            if not lists:
                raise PyMISPWarningListsError('Unable to load the lists. Do not forget to initialize the submodule (git submodule update --init).')

        # Later lists with the same name replace earlier ones.
        self.warninglists: Dict[str, 'WarningList'] = {
            warninglist['name']: WarningList(warninglist, slow_search) for warninglist in lists
        }

    def validate_with_schema(self) -> None:
        """Validate every loaded list against the misp-warninglists JSON schema.

        The ``schema.json`` of the repository the lists were read from is preferred; the copy
        bundled with the package is used when the lists were passed in directly or that
        repository has no schema file. Errors raised by ``jsonschema.validate`` propagate.

        :raises ImportError: if jsonschema is not installed.
        """
        if not HAS_JSONSCHEMA:
            raise ImportError('jsonschema is required: pip install jsonschema')
        schema = Path(sys.modules['pymispwarninglists'].__file__).parent / 'data' / 'misp-warninglists' / 'schema.json'  # type: ignore
        root_dir = getattr(self, 'root_dir_warninglists', None)
        if root_dir is not None and (root_dir.parent / 'schema.json').exists():
            schema = root_dir.parent / 'schema.json'
        with open(schema, 'r', encoding='utf-8') as f:
            loaded_schema = json.load(f)
        for w in self.warninglists.values():
            jsonschema.validate(w.warninglist, loaded_schema)

    def __getitem__(self, name: str) -> 'WarningList':
        return self.warninglists[name]

    def __iter__(self):
        return iter(self.warninglists)

    def search(self, value) -> List['WarningList']:
        """Return every loaded list that contains *value*, in load order."""
        return [wl for wl in self.warninglists.values() if value in wl]

    def __len__(self) -> int:
        return len(self.warninglists)

    def get_loaded_lists(self) -> Dict[str, 'WarningList']:
        """Return the underlying name -> WarningList dict (not a copy)."""
        return self.warninglists

class NetworkFilter:
    """Binary trie over the bits of an integer IP address, for CIDR membership tests.

    Each node inspects bit ``digit_position`` (0 = least significant) of the
    address. A child is ``True`` when every address reaching it with that bit
    value is covered by an appended network, ``False`` when none is, or a deeper
    ``NetworkFilter`` for the next lower bit.
    """

    def __init__(self, digit_position: int, digit2filter: Optional[Dict[int, Union[bool, "NetworkFilter"]]] = None):
        self.digit2filter: Dict[int, Union[bool, NetworkFilter]] = digit2filter or {0: False, 1: False}
        self.digit_position = digit_position

    def __contains__(self, ip: int) -> bool:
        child = self.digit2filter[self._get_digit(ip)]
        if isinstance(child, bool):
            return child
        return ip in child

    def append(self, net: _BaseNetwork) -> None:
        """Mark every address of *net* as covered."""
        host_bits = net.max_prefixlen - net.prefixlen
        if host_bits > self.digit_position:
            # The network leaves this node's bit free as well (e.g. 0.0.0.0/0 or ::/0
            # at the root), so both branches are fully covered. Recursing instead
            # would descend below bit 0 and fail with a negative shift count.
            self.digit2filter[0] = True
            self.digit2filter[1] = True
            return
        digit = self._get_digit(int(net.network_address))
        if host_bits == self.digit_position:
            # This bit is the last one fixed by the prefix: the whole branch matches.
            self.digit2filter[digit] = True
            return
        child = self.digit2filter[digit]
        if child is False:
            child = NetworkFilter(self.digit_position - 1)
            self.digit2filter[digit] = child
        # A True child already covers this (narrower) network.
        if child is not True:
            child.append(net)

    def _get_digit(self, ip: int) -> int:
        return (ip >> self.digit_position) & 1

    def __repr__(self):
        return f"NetworkFilter(digit_position={self.digit_position}, digit2filter={self.digit2filter})"

    def __eq__(self, other):
        return isinstance(other, NetworkFilter) and self.digit_position == other.digit_position and self.digit2filter == other.digit2filter

def compile_network_filters(values: list) -> Tuple[NetworkFilter, NetworkFilter]:
    """Build the IPv4 and IPv6 network filters for the CIDR entries in *values*.

    Entries that ``convert_networks`` cannot parse are left out.

    :returns: ``(ipv4_filter, ipv6_filter)``, rooted at bit 31 and bit 127 respectively.
    """
    filters_by_version = {4: NetworkFilter(31), 6: NetworkFilter(127)}
    for network in convert_networks(values):
        filters_by_version[network.version].append(network)
    return filters_by_version[4], filters_by_version[6]
def convert_networks(values: list) -> Sequence[_BaseNetwork]:
    """Parse each entry of *values* with ``ipaddress.ip_network``.

    Entries that raise ``ValueError`` (``ip_network`` is called in its default
    strict mode, so networks with host bits set are rejected too) are collected
    and reported in a single warning instead of aborting the whole conversion.

    :returns: the successfully parsed networks, in input order.
    """
    valid_ips: List[_BaseNetwork] = []
    invalid_ips = []
    for value in values:
        try:
            network = ip_network(value)
        except ValueError:
            invalid_ips.append(value)
        else:
            valid_ips.append(network)
    if invalid_ips:
        # Lazy %-formatting: the list is only rendered if the record is emitted.
        logger.warning('Invalid IPs found: %s', invalid_ips)
    return valid_ips