misp-warninglists/tools/generator.py

310 lines
10 KiB
Python
Raw Normal View History

import datetime
2021-06-12 12:13:23 +02:00
import ipaddress
import json
2020-07-21 13:42:50 +02:00
import logging
from inspect import currentframe, getframeinfo, getmodulename, stack
from os import mkdir, path
from typing import List, Union
import gzip
import requests
import dns.exception
import dns.resolver
2023-08-25 09:55:52 +02:00
import dns.reversename
from dateutil.parser import parse as parsedate
def init_logging():
    """Attach a file handler for ``generators.log`` to the root logger.

    The log file is located one directory above this module. The root logger
    level is set to INFO. Only a file handler is installed; no console handler
    is added.

    The function is safe to call more than once: a handler for the log file is
    added only when the root logger does not already have one, so repeated
    calls do not make every record appear several times in the file.

    Returns:
        The root logger.
    """
    module_file = getframeinfo(currentframe()).filename
    module_dir = path.dirname(path.abspath(module_file))
    # FileHandler stores an absolute path in ``baseFilename``; normalise ours
    # the same way so the duplicate check below compares like with like.
    log_file = path.abspath(path.join(module_dir, '../generators.log'))

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == log_file
        for handler in root_logger.handlers
    )
    if not already_attached:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(
            logging.Formatter(
                "[%(asctime)s] %(levelname)s::%(funcName)s()::%(message)s"
            )
        )
        root_logger.addHandler(file_handler)

    return root_logger
# Runs at import time: importing this module attaches the generators.log file
# handler to the root logger.
init_logging()
def download_to_file(url, file, gzip_enable=False):
    """Download ``url`` into the tmp source file ``file`` unless it is up to date.

    A HEAD request is sent first and its ``Last-Modified`` header is compared
    with the modification time of the local copy. The file is downloaded when
    the server copy is newer, and also whenever that comparison cannot be made
    (header missing, local file missing, or any other error during the check).

    Only the freshness check is guarded by the ``try``; errors raised by the
    download itself propagate to the caller instead of being reported as a
    header/file problem and retried.

    Args:
        url: Address of the remote file.
        file: Name of the local file, resolved via ``get_abspath_source_file``.
        gzip_enable: Passed through to ``actual_download_to_file``.
    """
    caller_file = stack()[1].filename
    # getmodulename() returns None when the calling file is not a module
    # (e.g. "<stdin>" or a script without a module suffix); fall back to the
    # bare file name so the log prefix never fails on ``None.upper()``.
    caller = (
        getmodulename(caller_file) or path.basename(caller_file) or 'UNKNOWN'
    ).upper()

    user_agent = {
        "User-agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:46.0) Gecko/20100101 Firefox/46.0"
    }

    try:
        logging.info(f'download_to_file - fetching url: {url}')
        # NOTE(review): no timeout is passed to requests here.
        r = requests.head(url, headers=user_agent)
        url_datetime = parsedate(r.headers['Last-Modified']).astimezone()
        local_file = get_abspath_source_file(file)
        file_datetime = datetime.datetime.fromtimestamp(
            path.getmtime(local_file)
        ).astimezone()
    except KeyError as exc:
        logging.warning(
            '{} KeyError in the headers. the {} header was not sent by server {}. Downloading file'.format(
                caller, str(exc), url
            )
        )
    except FileNotFoundError:
        logging.info(
            "{} File didn't exist, so downloading {} from {}".format(caller, file, url)
        )
    except Exception as exc:
        # Deliberately broad: any failure of the freshness check falls back
        # to downloading the file.
        logging.warning('{} General exception occured: {}.'.format(caller, str(exc)))
    else:
        if url_datetime <= file_datetime:
            logging.info('{} File on server is older, nothing to do'.format(caller))
            return
        logging.info(
            '{} File on server is newer, so downloading update to {}'.format(
                caller, local_file
            )
        )

    actual_download_to_file(url, file, user_agent, gzip_enable=gzip_enable)
2020-07-21 13:42:50 +02:00
def actual_download_to_file(url, file, user_agent, gzip_enable=False):
    """Download ``url`` and store it as the tmp source file ``file``.

    Args:
        url: Address of the remote file.
        file: Name of the local file, resolved via ``get_abspath_source_file``.
        user_agent: Mapping of HTTP headers sent with the request.
        gzip_enable: When true, the downloaded file is treated as gzip data
            and replaced by its decompressed content.

    Raises:
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx
            status. The check happens before the target file is opened, so
            an already downloaded copy is left untouched.
    """
    target = get_abspath_source_file(file)

    # stream=True lets the body be written chunk by chunk instead of being
    # loaded into memory first; the context manager releases the connection.
    with requests.get(url, headers=user_agent, stream=True) as response:
        response.raise_for_status()
        with open(target, 'wb') as fd:
            for chunk in response.iter_content(4096):
                fd.write(chunk)

    if gzip_enable:
        # Decompress in place: read the whole archive, then overwrite it.
        with gzip.open(target, 'rb') as fgzip:
            file_content = fgzip.read()
        with open(target, 'wb') as fd:
            fd.write(file_content)
2020-07-21 13:42:50 +02:00
def process_stream(url):
    """Fetch ``url`` as a stream and return its meaningful lines.

    Every line is decoded as UTF-8. Empty lines and lines starting with ``#``
    are left out.

    Returns:
        A list of the remaining lines, in the order they were received.
    """
    response = requests.get(url, stream=True)
    decoded_lines = (raw.decode('utf-8') for raw in response.iter_lines())
    return [text for text in decoded_lines if text and not text.startswith("#")]
def download(url):
    """Issue a GET request for ``url`` with a fixed Firefox User-agent header.

    Returns:
        The response object produced by ``requests.get``.
    """
    return requests.get(
        url,
        headers={
            "User-agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:46.0) Gecko/20100101 Firefox/46.0"
        },
    )
def get_abspath_list_file(dst):
    """Return the absolute path of ``../lists/<dst>/list.json``.

    The path is resolved relative to the directory containing this module.
    """
    module_dir = path.dirname(path.abspath(getframeinfo(currentframe()).filename))
    list_file = path.join(module_dir, f'../lists/{dst}/list.json')
    return path.abspath(path.realpath(list_file))
def get_abspath_source_file(dst):
    """Return the absolute path of ``../tmp/<dst>``, creating ``../tmp`` if needed.

    The path is resolved relative to the directory containing this module.

    The directory is created by attempting ``mkdir`` and tolerating the
    failure when the path already exists, rather than checking first and
    creating afterwards. This avoids an error when another process creates
    the directory between the check and the ``mkdir`` call.

    Raises:
        OSError: If the tmp directory does not exist and cannot be created.
    """
    module_dir = path.dirname(path.abspath(getframeinfo(currentframe()).filename))
    tmp_path = path.join(module_dir, '../tmp/')
    try:
        mkdir(tmp_path)
    except OSError:
        # Only a genuine failure to create the directory is an error; an
        # already existing path is the normal case.
        if not path.exists(tmp_path):
            raise
    return path.abspath(path.realpath(path.join(tmp_path, '{dst}'.format(dst=dst))))
def get_version():
    """Return today's local date as an integer of the form YYYYMMDD."""
    today = datetime.date.today()
    return today.year * 10000 + today.month * 100 + today.day
def unique_sorted_warninglist(warninglist):
    """Deduplicate and sort the ``'list'`` entry of ``warninglist`` in place.

    Returns:
        The same ``warninglist`` object that was passed in.
    """
    entries = list(set(warninglist['list']))
    entries.sort()
    warninglist['list'] = entries
    return warninglist
def write_to_file(warninglist, dst):
    """Write ``warninglist`` as JSON to the list file of warninglist ``dst``.

    The ``'list'`` entry is deduplicated and sorted first (this modifies the
    passed dict). The JSON is written with a two-space indent, sorted keys and
    a trailing newline to the path given by ``get_abspath_list_file(dst)``.

    Any exception raised while preparing or writing the file is logged as an
    error and not re-raised.

    Args:
        warninglist: Dict holding the warninglist, including a ``'list'`` key.
        dst: Name of the warninglist directory under ``lists``.
    """
    caller_file = stack()[1].filename
    # getmodulename() returns None when the calling file is not a module
    # (e.g. "<stdin>" or a script without a module suffix); fall back to the
    # bare file name so the log prefix never fails on ``None.upper()``.
    caller = (
        getmodulename(caller_file) or path.basename(caller_file) or 'UNKNOWN'
    ).upper()

    try:
        warninglist = unique_sorted_warninglist(warninglist)
        list_file = get_abspath_list_file(dst)
        with open(list_file, 'w') as data_file:
            json.dump(warninglist, data_file, indent=2, sort_keys=True)
            data_file.write("\n")
        logging.info('New warninglist written to {}.'.format(list_file))
    except Exception as exc:
        logging.error('{} General exception occurred: {}.'.format(caller, str(exc)))
2021-06-12 12:13:23 +02:00
def consolidate_networks(networks):
    """Collapse a mixed collection of IP networks into a minimal list of strings.

    Strings are converted with ``ipaddress.ip_network``; other items are used
    as they are. IPv4 and IPv6 entries are collapsed separately.

    Returns:
        A list of network strings: the collapsed IPv4 networks first, followed
        by the collapsed IPv6 networks.
    """
    by_family = {True: [], False: []}  # True -> IPv4, False -> everything else
    for item in networks:
        parsed = ipaddress.ip_network(item) if isinstance(item, str) else item
        by_family[parsed.version == 4].append(parsed)

    collapsed_v4 = [str(net) for net in ipaddress.collapse_addresses(by_family[True])]
    collapsed_v6 = [str(net) for net in ipaddress.collapse_addresses(by_family[False])]
    return collapsed_v4 + collapsed_v6
def create_resolver() -> dns.resolver.Resolver:
    """Build a resolver with fixed nameservers, an LRU cache and limits of 30.

    The resolver is created with ``configure=False``; ``timeout`` and
    ``lifetime`` are both set to 30.
    """
    limit = 30
    nameservers = ["193.17.47.1", "185.43.135.1"]  # CZ.NIC nameservers

    resolver = dns.resolver.Resolver(configure=False)
    resolver.nameservers = nameservers
    resolver.cache = dns.resolver.LRUCache()
    resolver.lifetime = limit
    resolver.timeout = limit
    return resolver
class Dns:
    """DNS lookups (A/AAAA, MX, SPF, PTR) performed through one resolver.

    For every query, the failures NoAnswer, NXDOMAIN, Timeout and
    NoNameservers are treated as "no data": the affected lookup contributes
    nothing to the result instead of raising.
    """

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, resolver: "dns.resolver.Resolver"):
        self.__resolver = resolver

    def _parse_spf(self, domain: str, spf: str) -> dict:
        """Split an SPF record into its include, range, ``a`` and ``mx`` parts.

        Args:
            domain: Domain the record belongs to; used for bare ``a`` / ``mx``.
            spf: Text of the SPF record, terms separated by single spaces.

        Returns:
            A dict with the keys ``include`` (targets of ``include:`` and
            ``redirect=``), ``ranges`` (networks from ``ip4:`` / ``ip6:``),
            ``a`` and ``mx`` (host names). Terms that match none of these
            forms are ignored.
        """
        output = {"include": [], "ranges": [], "a": [], "mx": []}
        for part in spf.split(" "):
            if part.startswith("include:"):
                output["include"].append(part.split(":", 1)[1])
            elif part.startswith("redirect="):
                output["include"].append(part.split("=", 1)[1])
            elif part == "a":
                output["a"].append(domain)
            elif part.startswith("a:"):
                output["a"].append(part.split(":", 1)[1])
            elif part == "mx":
                output["mx"].append(domain)
            elif part.startswith("mx:"):
                output["mx"].append(part.split(":", 1)[1])
            elif part.startswith(("ip4:", "ip6:")):
                # strict=False accepts values that have host bits set.
                output["ranges"].append(
                    ipaddress.ip_network(part.split(":", 1)[1], strict=False)
                )
        return output

    def get_ip_for_domain(
        self, domain: str
    ) -> List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]:
        """Return the addresses from the A records, then the AAAA records, of ``domain``."""
        addresses = []
        lookups = (("a", ipaddress.IPv4Address), ("aaaa", ipaddress.IPv6Address))
        for record_type, address_cls in lookups:
            try:
                for rdata in self.__resolver.query(domain, record_type):
                    addresses.append(address_cls(str(rdata)))
            except (
                dns.resolver.NoAnswer,
                dns.resolver.NXDOMAIN,
                dns.exception.Timeout,
                dns.resolver.NoNameservers,
            ):
                # Best effort: a failed lookup for one record type must not
                # discard the addresses found for the other.
                continue
        return addresses

    def get_mx_ips_for_domain(
        self, domain: str
    ) -> List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]:
        """Return the addresses of every mail exchanger listed for ``domain``."""
        addresses = []
        try:
            for rdata in self.__resolver.query(domain, "mx"):
                addresses += self.get_ip_for_domain(rdata.exchange)
        except (
            dns.resolver.NoAnswer,
            dns.resolver.NXDOMAIN,
            dns.exception.Timeout,
            dns.resolver.NoNameservers,
        ):
            # Best effort: no MX data simply means no addresses.
            pass
        return addresses

    def get_ip_ranges_from_spf(
        self, domain: str, _ancestors: frozenset = frozenset()
    ) -> List[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
        """Collect the networks authorised by the SPF records of ``domain``.

        For every TXT record starting with ``v=spf1`` the result receives, in
        this order: the literal ``ip4:``/``ip6:`` ranges, the ranges of all
        ``include:``/``redirect=`` targets (resolved recursively), the
        addresses of the ``a`` hosts and the addresses of the ``mx`` hosts.
        Duplicates are not removed.

        Args:
            domain: Domain whose TXT records are queried.
            _ancestors: Internal recursion state (domains currently being
                expanded); callers should not pass it.

        Returns:
            A list of networks; empty when the TXT lookup fails.
        """
        # A domain that includes itself (directly or through other domains)
        # would otherwise recurse until RecursionError. Only the current chain
        # of includes is tracked, so a domain reached through two independent
        # branches is still expanded for each of them.
        key = domain.lower().rstrip(".")
        if key in _ancestors:
            logging.info("Skipping SPF include loop at domain {}".format(domain))
            return []

        try:
            txt_records = self.__resolver.query(domain, "TXT")
        except (
            dns.resolver.NoAnswer,
            dns.resolver.NXDOMAIN,
            dns.exception.Timeout,
            dns.resolver.NoNameservers,
        ) as e:
            logging.info(
                "Could not fetch TXT record for domain {}: {}".format(domain, str(e))
            )
            return []

        chain = _ancestors | {key}
        ranges = []
        for rdata in txt_records:
            record = "".join(s.decode("utf-8") for s in rdata.strings)
            if not record.startswith("v=spf1"):
                continue

            parsed = self._parse_spf(domain, record)
            ranges += parsed["ranges"]

            for include in parsed["include"]:
                ranges += self.get_ip_ranges_from_spf(include, chain)

            # The loop variable must not be called ``domain``: that would
            # overwrite the parameter and make later records of this same
            # domain be parsed against the wrong name.
            for host in parsed["a"]:
                ranges += map(ipaddress.ip_network, self.get_ip_for_domain(host))

            for mx in parsed["mx"]:
                ranges += map(ipaddress.ip_network, self.get_mx_ips_for_domain(mx))

        return ranges

    def get_domain_from_ip(self, ip: str) -> Union[str, list]:
        """Return the first PTR name of ``ip`` without its trailing dot.

        The PTR query runs through the resolver passed to the constructor and
        is made inside the ``try`` so that lookup failures are handled.

        Returns:
            The host name, or an empty list when the PTR lookup fails (the
            empty list is kept for compatibility with existing callers).
        """
        reverse_name = dns.reversename.from_address(ip)
        try:
            answer = self.__resolver.resolve(reverse_name, "PTR")
        except (
            dns.resolver.NoAnswer,
            dns.resolver.NXDOMAIN,
            dns.exception.Timeout,
            dns.resolver.NoNameservers,
        ) as e:
            logging.info("Could not fetch PTR record for IP {}: {}".format(ip, str(e)))
            return []

        return str(answer[0]).rstrip('.')
2023-08-25 09:55:52 +02:00
def main():
    """Script entry point; its only action is to call ``init_logging()``.

    NOTE: the module also calls ``init_logging()`` at import time, so running
    this file as a script invokes it a second time.
    """
    init_logging()
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()