diff --git a/misp_modules/modules/expansion/virustotal.py b/misp_modules/modules/expansion/virustotal.py index 2c82787..7f18b54 100644 --- a/misp_modules/modules/expansion/virustotal.py +++ b/misp_modules/modules/expansion/virustotal.py @@ -1,6 +1,6 @@ import json -import requests from urllib.parse import urlparse +import vt from . import check_input_attribute, standard_error_message from pymisp import MISPAttribute, MISPEvent, MISPObject @@ -9,20 +9,23 @@ mispattributes = {'input': ['hostname', 'domain', "ip-src", "ip-dst", "md5", "sh 'format': 'misp_standard'} # possible module-types: 'expansion', 'hover' or both -moduleinfo = {'version': '4', 'author': 'Hannah Ward', - 'description': 'Get information from VirusTotal', +moduleinfo = {'version': '5', 'author': 'Hannah Ward', + 'description': 'Enrich observables with the VirusTotal v3 API', 'module-type': ['expansion']} # config fields that your code expects from the site admin moduleconfig = ["apikey", "event_limit", 'proxy_host', 'proxy_port', 'proxy_username', 'proxy_password'] -class VirusTotalParser(object): - def __init__(self, apikey, limit): - self.apikey = apikey - self.limit = limit - self.base_url = "https://www.virustotal.com/vtapi/v2/{}/report" +DEFAULT_RESULTS_LIMIT = 10 + + +class VirusTotalParser: + def __init__(self, client: vt.Client, limit: int) -> None: + self.client = client + self.limit = limit or DEFAULT_RESULTS_LIMIT self.misp_event = MISPEvent() + self.attribute = MISPAttribute() self.parsed_objects = {} self.input_types_mapping = {'ip-src': self.parse_ip, 'ip-dst': self.parse_ip, 'domain': self.parse_domain, 'hostname': self.parse_domain, @@ -30,196 +33,187 @@ class VirusTotalParser(object): 'sha256': self.parse_hash, 'url': self.parse_url} self.proxies = None - def query_api(self, attribute): - self.attribute = MISPAttribute() - self.attribute.from_dict(**attribute) - return self.input_types_mapping[self.attribute.type](self.attribute.value, recurse=True) + @staticmethod + def get_total_analysis(analysis: dict, known_distributors: dict = None) -> int: + if not analysis: + return 0 + count = sum([analysis['undetected'], analysis['suspicious'], analysis['harmless']]) + return count if known_distributors else count + analysis['malicious'] - def get_result(self): + def query_api(self, attribute: dict) -> None: + self.attribute.from_dict(**attribute) + self.input_types_mapping[self.attribute.type](self.attribute.value) + + def get_result(self) -> dict: event = json.loads(self.misp_event.to_json()) results = {key: event[key] for key in ('Attribute', 'Object') if (key in event and event[key])} return {'results': results} + def add_vt_report(self, report: vt.Object) -> str: + analysis = report.get('last_analysis_stats') + total = self.get_total_analysis(analysis, report.get('known_distributors')) + permalink = f'https://www.virustotal.com/gui/{report.type}/{report.id}' + + vt_object = MISPObject('virustotal-report') + vt_object.add_attribute('permalink', type='link', value=permalink) + detection_ratio = f"{analysis['malicious']}/{total}" if analysis else '-/-' + vt_object.add_attribute('detection-ratio', type='text', value=detection_ratio, disable_correlation=True) + self.misp_event.add_object(**vt_object) + return vt_object.uuid + + def create_misp_object(self, report: vt.Object) -> MISPObject: + misp_object = None + vt_uuid = self.add_vt_report(report) + if report.type == 'file': + misp_object = MISPObject('file') + for hash_type in ('md5', 'sha1', 'sha256'): + misp_object.add_attribute(**{'type': hash_type, + 'object_relation': hash_type, + 'value': report.get(hash_type)}) + elif report.type == 'domain': + misp_object = MISPObject('domain-ip') + misp_object.add_attribute('domain', type='domain', value=report.id) + elif report.type == 'ip_address': + misp_object = MISPObject('domain-ip') + misp_object.add_attribute('ip', type='ip-dst', value=report.id) + elif report.type == 'url': + misp_object = MISPObject('url') + misp_object.add_attribute('url', type='url', value=report.url) + misp_object.add_reference(vt_uuid, 'analyzed-with') + return misp_object + ################################################################################ #### Main parsing functions #### # noqa ################################################################################ - def parse_domain(self, domain, recurse=False): - req = requests.get(self.base_url.format('domain'), params={'apikey': self.apikey, 'domain': domain}, proxies=self.proxies) - if req.status_code != 200: - return req.status_code - req = req.json() - hash_type = 'sha256' - whois = 'whois' - feature_types = {'communicating': 'communicates-with', - 'downloaded': 'downloaded-from', - 'referrer': 'referring'} - siblings = (self.parse_siblings(domain) for domain in req['domain_siblings']) - uuid = self.parse_resolutions(req['resolutions'], req['subdomains'] if 'subdomains' in req else None, siblings) - for feature_type, relationship in feature_types.items(): - for feature in ('undetected_{}_samples', 'detected_{}_samples'): - for sample in req.get(feature.format(feature_type), [])[:self.limit]: - status_code = self.parse_hash(sample[hash_type], False, uuid, relationship) - if status_code != 200: - return status_code - if req.get(whois): - whois_object = MISPObject(whois) - whois_object.add_attribute('text', type='text', value=req[whois]) + def parse_domain(self, domain: str) -> str: + domain_report = self.client.get_object(f'/domains/{domain}') + + # DOMAIN + domain_object = self.create_misp_object(domain_report) + + # WHOIS + if domain_report.whois: + whois_object = MISPObject('whois') + whois_object.add_attribute('text', type='text', value=domain_report.whois) self.misp_event.add_object(**whois_object) - return self.parse_related_urls(req, recurse, uuid) - def parse_hash(self, sample, recurse=False, uuid=None, relationship=None): - req = requests.get(self.base_url.format('file'), params={'apikey': self.apikey, 'resource': sample}, proxies=self.proxies) - status_code = req.status_code - if req.status_code == 200: - req = req.json() - vt_uuid = self.parse_vt_object(req) - file_attributes = [] - for hash_type in ('md5', 'sha1', 'sha256'): - if req.get(hash_type): - file_attributes.append({'type': hash_type, 'object_relation': hash_type, - 'value': req[hash_type]}) - if file_attributes: - file_object = MISPObject('file') - for attribute in file_attributes: - file_object.add_attribute(**attribute) - file_object.add_reference(vt_uuid, 'analyzed-with') - if uuid and relationship: - file_object.add_reference(uuid, relationship) + # SIBLINGS AND SUBDOMAINS + for relationship_name, misp_name in [('siblings', 'sibling-of'), ('subdomains', 'subdomain')]: + rel_iterator = self.client.iterator(f'/domains/{domain_report.id}/{relationship_name}', limit=self.limit) + for item in rel_iterator: + attr = MISPAttribute() + attr.from_dict(**dict(type='domain', value=item.id)) + self.misp_event.add_attribute(**attr) + domain_object.add_reference(attr.uuid, misp_name) + + # RESOLUTIONS + resolutions_iterator = self.client.iterator(f'/domains/{domain_report.id}/resolutions', limit=self.limit) + for resolution in resolutions_iterator: + domain_object.add_attribute('ip', type='ip-dst', value=resolution.ip_address) + + # COMMUNICATING, DOWNLOADED AND REFERRER FILES + for relationship_name, misp_name in [ + ('communicating_files', 'communicates-with'), + ('downloaded_files', 'downloaded-from'), + ('referrer_files', 'referring') + ]: + files_iterator = self.client.iterator(f'/domains/{domain_report.id}/{relationship_name}', limit=self.limit) + for file in files_iterator: + file_object = self.create_misp_object(file) + file_object.add_reference(domain_object.uuid, misp_name) self.misp_event.add_object(**file_object) - return status_code - def parse_ip(self, ip, recurse=False): - req = requests.get(self.base_url.format('ip-address'), params={'apikey': self.apikey, 'ip': ip}, proxies=self.proxies) - if req.status_code != 200: - return req.status_code - req = req.json() - if req.get('asn'): - asn_mapping = {'network': ('ip-src', 'subnet-announced'), - 'country': ('text', 'country')} - asn_object = MISPObject('asn') - asn_object.add_attribute('asn', type='AS', value=req['asn']) - for key, value in asn_mapping.items(): - if req.get(key): - attribute_type, relation = value - asn_object.add_attribute(relation, type=attribute_type, value=req[key]) - self.misp_event.add_object(**asn_object) - uuid = self.parse_resolutions(req['resolutions']) if req.get('resolutions') else None - return self.parse_related_urls(req, recurse, uuid) + # URLS + urls_iterator = self.client.iterator(f'/domains/{domain_report.id}/urls', limit=self.limit) + for url in urls_iterator: + url_object = self.create_misp_object(url) + url_object.add_reference(domain_object.uuid, 'hosted-in') + self.misp_event.add_object(**url_object) - def parse_url(self, url, recurse=False, uuid=None): - req = requests.get(self.base_url.format('url'), params={'apikey': self.apikey, 'resource': url}, proxies=self.proxies) - status_code = req.status_code - if req.status_code == 200: - req = req.json() - vt_uuid = self.parse_vt_object(req) - if not recurse: - feature = 'url' - url_object = MISPObject(feature) - url_object.add_attribute(feature, type=feature, value=url) - url_object.add_reference(vt_uuid, 'analyzed-with') - if uuid: - url_object.add_reference(uuid, 'hosted-in') - self.misp_event.add_object(**url_object) - return status_code + self.misp_event.add_object(**domain_object) + return domain_object.uuid - ################################################################################ - #### Additional parsing functions #### # noqa - ################################################################################ + def parse_hash(self, file_hash: str) -> str: + file_report = self.client.get_object(f'files/{file_hash}') + file_object = self.create_misp_object(file_report) + self.misp_event.add_object(**file_object) + return file_object.uuid - def parse_related_urls(self, query_result, recurse, uuid=None): - if recurse: - for feature in ('detected_urls', 'undetected_urls'): - if feature in query_result: - for url in query_result[feature]: - value = url['url'] if isinstance(url, dict) else url[0] - status_code = self.parse_url(value, False, uuid) - if status_code != 200: - return status_code + def parse_ip(self, ip: str) -> str: + ip_report = self.client.get_object(f'/ip_addresses/{ip}') + + # IP + ip_object = self.create_misp_object(ip_report) + + # ASN + asn_object = MISPObject('asn') + asn_object.add_attribute('asn', type='AS', value=ip_report.asn) + asn_object.add_attribute('subnet-announced', type='ip-src', value=ip_report.network) + asn_object.add_attribute('country', type='text', value=ip_report.country) + self.misp_event.add_object(**asn_object) + + # RESOLUTIONS + resolutions_iterator = self.client.iterator(f'/ip_addresses/{ip_report.id}/resolutions', limit=self.limit) + for resolution in resolutions_iterator: + ip_object.add_attribute('domain', type='domain', value=resolution.host_name) + + # URLS + urls_iterator = self.client.iterator(f'/ip_addresses/{ip_report.id}/urls', limit=self.limit) + for url in urls_iterator: + url_object = self.create_misp_object(url) + url_object.add_reference(ip_object.uuid, 'hosted-in') + self.misp_event.add_object(**url_object) + + self.misp_event.add_object(**ip_object) + return ip_object.uuid + + def parse_url(self, url: str) -> str: + url_id = vt.url_id(url) + url_report = self.client.get_object(f'/urls/{url_id}') + url_object = self.create_misp_object(url_report) + self.misp_event.add_object(**url_object) + return url_object.uuid + + +def get_proxy_settings(config: dict) -> dict: + """Returns proxy settings in the requests format. + If no proxy settings are set, return None.""" + proxies = None + host = config.get('proxy_host') + port = config.get('proxy_port') + username = config.get('proxy_username') + password = config.get('proxy_password') + + if host: + if not port: + misperrors['error'] = 'The virustotal_proxy_host config is set, ' \ + 'please also set the virustotal_proxy_port.' + raise KeyError + parsed = urlparse(host) + if 'http' in parsed.scheme: + scheme = 'http' else: - for feature in ('detected_urls', 'undetected_urls'): - if feature in query_result: - for url in query_result[feature]: - value = url['url'] if isinstance(url, dict) else url[0] - self.misp_event.add_attribute('url', value) - return 200 + scheme = parsed.scheme + netloc = parsed.netloc + host = f'{netloc}:{port}' - def parse_resolutions(self, resolutions, subdomains=None, uuids=None): - domain_ip_object = MISPObject('domain-ip') - if self.attribute.type in ('domain', 'hostname'): - domain_ip_object.add_attribute('domain', type='domain', value=self.attribute.value) - attribute_type, relation, key = ('ip-dst', 'ip', 'ip_address') - else: - domain_ip_object.add_attribute('ip', type='ip-dst', value=self.attribute.value) - attribute_type, relation, key = ('domain', 'domain', 'hostname') - for resolution in resolutions: - domain_ip_object.add_attribute(relation, type=attribute_type, value=resolution[key]) - if subdomains: - for subdomain in subdomains: - attribute = MISPAttribute() - attribute.from_dict(**dict(type='domain', value=subdomain)) - self.misp_event.add_attribute(**attribute) - domain_ip_object.add_reference(attribute.uuid, 'subdomain') - if uuids: - for uuid in uuids: - domain_ip_object.add_reference(uuid, 'sibling-of') - self.misp_event.add_object(**domain_ip_object) - return domain_ip_object.uuid - - def parse_siblings(self, domain): - attribute = MISPAttribute() - attribute.from_dict(**dict(type='domain', value=domain)) - self.misp_event.add_attribute(**attribute) - return attribute.uuid - - def parse_vt_object(self, query_result): - if query_result['response_code'] == 1: - vt_object = MISPObject('virustotal-report') - vt_object.add_attribute('permalink', type='link', value=query_result['permalink']) - detection_ratio = '{}/{}'.format(query_result['positives'], query_result['total']) - vt_object.add_attribute('detection-ratio', type='text', value=detection_ratio, disable_correlation=True) - self.misp_event.add_object(**vt_object) - return vt_object.uuid - - def set_proxy_settings(self, config: dict) -> dict: - """Returns proxy settings in the requests format. - If no proxy settings are set, return None.""" - proxies = None - host = config.get('proxy_host') - port = config.get('proxy_port') - username = config.get('proxy_username') - password = config.get('proxy_password') - - if host: - if not port: - misperrors['error'] = 'The virustotal_proxy_host config is set, ' \ - 'please also set the virustotal_proxy_port.' + if username: + if not password: + misperrors['error'] = 'The virustotal_proxy_username config is set, ' \ + 'please also set the virustotal_proxy_password.' raise KeyError - parsed = urlparse(host) - if 'http' in parsed.scheme: - scheme = 'http' - else: - scheme = parsed.scheme - netloc = parsed.netloc - host = f'{netloc}:{port}' + auth = f'{username}:{password}' + host = auth + '@' + host - if username: - if not password: - misperrors['error'] = 'The virustotal_proxy_username config is set, ' \ - 'please also set the virustotal_proxy_password.' - raise KeyError - auth = f'{username}:{password}' - host = auth + '@' + host - - proxies = { - 'http': f'{scheme}://{host}', - 'https': f'{scheme}://{host}' - } - self.proxies = proxies - return True + proxies = { + 'http': f'{scheme}://{host}', + 'https': f'{scheme}://{host}' + } + return proxies -def parse_error(status_code): +def parse_error(status_code: int) -> str: status_mapping = {204: 'VirusTotal request rate limit exceeded.', 400: 'Incorrect request, please check the arguments.', 403: 'You don\'t have enough privileges to make the request.'} @@ -233,7 +227,7 @@ def handler(q=False): return False request = json.loads(q) if not request.get('config') or not request['config'].get('apikey'): - misperrors['error'] = "A VirusTotal api key is required for this module." + misperrors['error'] = 'A VirusTotal api key is required for this module.' return misperrors if not request.get('attribute') or not check_input_attribute(request['attribute']): return {'error': f'{standard_error_message}, which should contain at least a type, a value and an uuid.'} @@ -241,15 +235,21 @@ def handler(q=False): return {'error': 'Unsupported attribute type.'} event_limit = request['config'].get('event_limit') - if not isinstance(event_limit, int): - event_limit = 5 - parser = VirusTotalParser(request['config']['apikey'], event_limit) - parser.set_proxy_settings(request.get('config')) attribute = request['attribute'] - status = parser.query_api(attribute) - if status != 200: - misperrors['error'] = parse_error(status) + proxy_settings = get_proxy_settings(request.get('config')) + + try: + client = vt.Client(request['config']['apikey'], + headers={ + 'x-tool': 'MISPModuleVirusTotalExpansion', + }, + proxy=proxy_settings['http'] if proxy_settings else None) + parser = VirusTotalParser(client, int(event_limit) if event_limit else None) + parser.query_api(attribute) + except vt.APIError as ex: + misperrors['error'] = ex.message return misperrors + return parser.get_result() @@ -259,4 +259,4 @@ def introspection(): def version(): moduleinfo['config'] = moduleconfig - return moduleinfo + return moduleinfo \ No newline at end of file diff --git a/misp_modules/modules/expansion/virustotal_public.py b/misp_modules/modules/expansion/virustotal_public.py index c10f4d2..f72bda4 100644 --- a/misp_modules/modules/expansion/virustotal_public.py +++ b/misp_modules/modules/expansion/virustotal_public.py @@ -1,6 +1,6 @@ import json import logging -import requests +import vt from . import check_input_attribute, standard_error_message from urllib.parse import urlparse from pymisp import MISPAttribute, MISPEvent, MISPObject @@ -8,8 +8,8 @@ from pymisp import MISPAttribute, MISPEvent, MISPObject misperrors = {'error': 'Error'} mispattributes = {'input': ['hostname', 'domain', "ip-src", "ip-dst", "md5", "sha1", "sha256", "url"], 'format': 'misp_standard'} -moduleinfo = {'version': '1', 'author': 'Christian Studer', - 'description': 'Get information from VirusTotal public API v2.', +moduleinfo = {'version': '2', 'author': 'Christian Studer', + 'description': 'Enrich observables with the VirusTotal v3 public API', 'module-type': ['expansion', 'hover']} moduleconfig = ['apikey', 'proxy_host', 'proxy_port', 'proxy_username', 'proxy_password'] @@ -18,191 +18,188 @@ LOGGER = logging.getLogger('virus_total_public') LOGGER.setLevel(logging.INFO) -class VirusTotalParser(): - def __init__(self): - super(VirusTotalParser, self).__init__() +DEFAULT_RESULTS_LIMIT = 10 + + +class VirusTotalParser: + def __init__(self, client: vt.Client, limit: int) -> None: + self.client = client + self.limit = limit or DEFAULT_RESULTS_LIMIT self.misp_event = MISPEvent() + self.attribute = MISPAttribute() + self.parsed_objects = {} + self.input_types_mapping = {'ip-src': self.parse_ip, 'ip-dst': self.parse_ip, + 'domain': self.parse_domain, 'hostname': self.parse_domain, + 'md5': self.parse_hash, 'sha1': self.parse_hash, + 'sha256': self.parse_hash, 'url': self.parse_url} self.proxies = None - def declare_variables(self, apikey, attribute): - self.attribute = MISPAttribute() - self.attribute.from_dict(**attribute) - self.apikey = apikey + @staticmethod + def get_total_analysis(analysis: dict, known_distributors: dict = None) -> int: + if not analysis: + return 0 + count = sum([analysis['undetected'], analysis['suspicious'], analysis['harmless']]) + return count if known_distributors else count + analysis['malicious'] - def get_result(self): + def query_api(self, attribute: dict) -> None: + self.attribute.from_dict(**attribute) + self.input_types_mapping[self.attribute.type](self.attribute.value) + + def get_result(self) -> dict: event = json.loads(self.misp_event.to_json()) results = {key: event[key] for key in ('Attribute', 'Object') if (key in event and event[key])} return {'results': results} - def parse_urls(self, query_result): - for feature in ('detected_urls', 'undetected_urls'): - if feature in query_result: - for url in query_result[feature]: - value = url['url'] if isinstance(url, dict) else url[0] - self.misp_event.add_attribute('url', value) + def add_vt_report(self, report: vt.Object) -> str: + analysis = report.get('last_analysis_stats') + total = self.get_total_analysis(analysis, report.get('known_distributors')) + permalink = f'https://www.virustotal.com/gui/{report.type}/{report.id}' - def parse_resolutions(self, resolutions, subdomains=None, uuids=None): - domain_ip_object = MISPObject('domain-ip') - if self.attribute.type in ('domain', 'hostname'): - domain_ip_object.add_attribute('domain', type='domain', value=self.attribute.value) - attribute_type, relation, key = ('ip-dst', 'ip', 'ip_address') - else: - domain_ip_object.add_attribute('ip', type='ip-dst', value=self.attribute.value) - attribute_type, relation, key = ('domain', 'domain', 'hostname') - for resolution in resolutions: - domain_ip_object.add_attribute(relation, type=attribute_type, value=resolution[key]) - if subdomains: - for subdomain in subdomains: - attribute = MISPAttribute() - attribute.from_dict(**dict(type='domain', value=subdomain)) - self.misp_event.add_attribute(**attribute) - domain_ip_object.add_reference(attribute.uuid, 'subdomain') - if uuids: - for uuid in uuids: - domain_ip_object.add_reference(uuid, 'sibling-of') - self.misp_event.add_object(**domain_ip_object) + vt_object = MISPObject('virustotal-report') + vt_object.add_attribute('permalink', type='link', value=permalink) + detection_ratio = f"{analysis['malicious']}/{total}" if analysis else '-/-' + vt_object.add_attribute('detection-ratio', type='text', value=detection_ratio, disable_correlation=True) + self.misp_event.add_object(**vt_object) + return vt_object.uuid - def parse_vt_object(self, query_result): - if query_result['response_code'] == 1: - vt_object = MISPObject('virustotal-report') - vt_object.add_attribute('permalink', type='link', value=query_result['permalink']) - detection_ratio = '{}/{}'.format(query_result['positives'], query_result['total']) - vt_object.add_attribute('detection-ratio', type='text', value=detection_ratio) - self.misp_event.add_object(**vt_object) + def create_misp_object(self, report: vt.Object) -> MISPObject: + misp_object = None + vt_uuid = self.add_vt_report(report) + if report.type == 'file': + misp_object = MISPObject('file') + for hash_type in ('md5', 'sha1', 'sha256'): + misp_object.add_attribute(**{'type': hash_type, + 'object_relation': hash_type, + 'value': report.get(hash_type)}) + elif report.type == 'domain': + misp_object = MISPObject('domain-ip') + misp_object.add_attribute('domain', type='domain', value=report.id) + elif report.type == 'ip_address': + misp_object = MISPObject('domain-ip') + misp_object.add_attribute('ip', type='ip-dst', value=report.id) + elif report.type == 'url': + misp_object = MISPObject('url') + misp_object.add_attribute('url', type='url', value=report.url) + misp_object.add_reference(vt_uuid, 'analyzed-with') + return misp_object - def get_query_result(self, query_type): - params = {query_type: self.attribute.value, 'apikey': self.apikey} - return requests.get(self.base_url, params=params, proxies=self.proxies) + ################################################################################ + #### Main parsing functions #### # noqa + ################################################################################ - def set_proxy_settings(self, config: dict) -> dict: - """Returns proxy settings in the requests format. - If no proxy settings are set, return None.""" - proxies = None - host = config.get('proxy_host') - port = config.get('proxy_port') - username = config.get('proxy_username') - password = config.get('proxy_password') + def parse_domain(self, domain: str) -> str: + domain_report = self.client.get_object(f'/domains/{domain}') - if host: - if not port: - misperrors['error'] = 'The virustotal_public_proxy_host config is set, ' \ - 'please also set the virustotal_public_proxy_port.' - raise KeyError - parsed = urlparse(host) - if 'http' in parsed.scheme: - scheme = 'http' - else: - scheme = parsed.scheme - netloc = parsed.netloc - host = f'{netloc}:{port}' + # DOMAIN + domain_object = self.create_misp_object(domain_report) - if username: - if not password: - misperrors['error'] = 'The virustotal_public_proxy_username config is set, ' \ - 'please also set the virustotal_public_proxy_password.' - raise KeyError - auth = f'{username}:{password}' - host = auth + '@' + host - - proxies = { - 'http': f'{scheme}://{host}', - 'https': f'{scheme}://{host}' - } - self.proxies = proxies - return True - - -class DomainQuery(VirusTotalParser): - def __init__(self, apikey, attribute): - super(DomainQuery, self).__init__() - self.base_url = "https://www.virustotal.com/vtapi/v2/domain/report" - self.declare_variables(apikey, attribute) - - def parse_report(self, query_result): - hash_type = 'sha256' - whois = 'whois' - for feature_type in ('referrer', 'downloaded', 'communicating'): - for feature in ('undetected_{}_samples', 'detected_{}_samples'): - for sample in query_result.get(feature.format(feature_type), []): - self.misp_event.add_attribute(hash_type, sample[hash_type]) - if query_result.get(whois): - whois_object = MISPObject(whois) - whois_object.add_attribute('text', type='text', value=query_result[whois]) + # WHOIS + if domain_report.whois: + whois_object = MISPObject('whois') + whois_object.add_attribute('text', type='text', value=domain_report.whois) self.misp_event.add_object(**whois_object) - if 'domain_siblings' in query_result: - siblings = (self.parse_siblings(domain) for domain in query_result['domain_siblings']) - if 'subdomains' in query_result: - self.parse_resolutions(query_result['resolutions'], query_result['subdomains'], siblings) - self.parse_urls(query_result) - def parse_siblings(self, domain): - attribute = MISPAttribute() - attribute.from_dict(**dict(type='domain', value=domain)) - self.misp_event.add_attribute(**attribute) - return attribute.uuid + # SIBLINGS AND SUBDOMAINS + for relationship_name, misp_name in [('siblings', 'sibling-of'), ('subdomains', 'subdomain')]: + rel_iterator = self.client.iterator(f'/domains/{domain_report.id}/{relationship_name}', limit=self.limit) + for item in rel_iterator: + attr = MISPAttribute() + attr.from_dict(**dict(type='domain', value=item.id)) + self.misp_event.add_attribute(**attr) + domain_object.add_reference(attr.uuid, misp_name) + + # RESOLUTIONS + resolutions_iterator = self.client.iterator(f'/domains/{domain_report.id}/resolutions', limit=self.limit) + for resolution in resolutions_iterator: + domain_object.add_attribute('ip', type='ip-dst', value=resolution.ip_address) + + # COMMUNICATING AND REFERRER FILES + for relationship_name, misp_name in [ + ('communicating_files', 'communicates-with'), + ('referrer_files', 'referring') + ]: + files_iterator = self.client.iterator(f'/domains/{domain_report.id}/{relationship_name}', limit=self.limit) + for file in files_iterator: + file_object = self.create_misp_object(file) + file_object.add_reference(domain_object.uuid, misp_name) + self.misp_event.add_object(**file_object) + + self.misp_event.add_object(**domain_object) + return domain_object.uuid + + def parse_hash(self, file_hash: str) -> str: + file_report = self.client.get_object(f'files/{file_hash}') + file_object = self.create_misp_object(file_report) + self.misp_event.add_object(**file_object) + return file_object.uuid + + def parse_ip(self, ip: str) -> str: + ip_report = self.client.get_object(f'/ip_addresses/{ip}') + + # IP + ip_object = self.create_misp_object(ip_report) + + # ASN + asn_object = MISPObject('asn') + asn_object.add_attribute('asn', type='AS', value=ip_report.asn) + asn_object.add_attribute('subnet-announced', type='ip-src', value=ip_report.network) + asn_object.add_attribute('country', type='text', value=ip_report.country) + self.misp_event.add_object(**asn_object) + + # RESOLUTIONS + resolutions_iterator = self.client.iterator(f'/ip_addresses/{ip_report.id}/resolutions', limit=self.limit) + for resolution in resolutions_iterator: + ip_object.add_attribute('domain', type='domain', value=resolution.host_name) + + self.misp_event.add_object(**ip_object) + return ip_object.uuid + + def parse_url(self, url: str) -> str: + url_id = vt.url_id(url) + url_report = self.client.get_object(f'/urls/{url_id}') + url_object = self.create_misp_object(url_report) + self.misp_event.add_object(**url_object) + return url_object.uuid -class HashQuery(VirusTotalParser): - def __init__(self, apikey, attribute): - super(HashQuery, self).__init__() - self.base_url = "https://www.virustotal.com/vtapi/v2/file/report" - self.declare_variables(apikey, attribute) +def get_proxy_settings(config: dict) -> dict: + """Returns proxy settings in the requests format. + If no proxy settings are set, return None.""" + proxies = None + host = config.get('proxy_host') + port = config.get('proxy_port') + username = config.get('proxy_username') + password = config.get('proxy_password') - def parse_report(self, query_result): - file_attributes = [] - for hash_type in ('md5', 'sha1', 'sha256'): - if query_result.get(hash_type): - file_attributes.append({'type': hash_type, 'object_relation': hash_type, - 'value': query_result[hash_type]}) - if file_attributes: - file_object = MISPObject('file') - for attribute in file_attributes: - file_object.add_attribute(**attribute) - self.misp_event.add_object(**file_object) - self.parse_vt_object(query_result) + if host: + if not port: + misperrors['error'] = 'The virustotal_proxy_host config is set, ' \ + 'please also set the virustotal_proxy_port.' + raise KeyError + parsed = urlparse(host) + if 'http' in parsed.scheme: + scheme = 'http' + else: + scheme = parsed.scheme + netloc = parsed.netloc + host = f'{netloc}:{port}' + + if username: + if not password: + misperrors['error'] = 'The virustotal_proxy_username config is set, ' \ + 'please also set the virustotal_proxy_password.' + raise KeyError + auth = f'{username}:{password}' + host = auth + '@' + host + + proxies = { + 'http': f'{scheme}://{host}', + 'https': f'{scheme}://{host}' + } + return proxies -class IpQuery(VirusTotalParser): - def __init__(self, apikey, attribute): - super(IpQuery, self).__init__() - self.base_url = "https://www.virustotal.com/vtapi/v2/ip-address/report" - self.declare_variables(apikey, attribute) - - def parse_report(self, query_result): - if query_result.get('asn'): - asn_mapping = {'network': ('ip-src', 'subnet-announced'), - 'country': ('text', 'country')} - asn_object = MISPObject('asn') - asn_object.add_attribute('asn', type='AS', value=query_result['asn']) - for key, value in asn_mapping.items(): - if query_result.get(key): - attribute_type, relation = value - asn_object.add_attribute(relation, type=attribute_type, value=query_result[key]) - self.misp_event.add_object(**asn_object) - self.parse_urls(query_result) - if query_result.get('resolutions'): - self.parse_resolutions(query_result['resolutions']) - - -class UrlQuery(VirusTotalParser): - def __init__(self, apikey, attribute): - super(UrlQuery, self).__init__() - self.base_url = "https://www.virustotal.com/vtapi/v2/url/report" - self.declare_variables(apikey, attribute) - - def parse_report(self, query_result): - self.parse_vt_object(query_result) - - -domain = ('domain', DomainQuery) -ip = ('ip', IpQuery) -file = ('resource', HashQuery) -misp_type_mapping = {'domain': domain, 'hostname': domain, 'ip-src': ip, - 'ip-dst': ip, 'md5': file, 'sha1': file, 'sha256': file, - 'url': ('resource', UrlQuery)} - - -def parse_error(status_code): +def parse_error(status_code: int) -> str: status_mapping = {204: 'VirusTotal request rate limit exceeded.', 400: 'Incorrect request, please check the arguments.', 403: 'You don\'t have enough privileges to make the request.'} @@ -216,23 +213,29 @@ def handler(q=False): return False request = json.loads(q) if not request.get('config') or not request['config'].get('apikey'): - misperrors['error'] = "A VirusTotal api key is required for this module." + misperrors['error'] = 'A VirusTotal api key is required for this module.' return misperrors if not request.get('attribute') or not check_input_attribute(request['attribute']): return {'error': f'{standard_error_message}, which should contain at least a type, a value and an uuid.'} - attribute = request['attribute'] - if attribute['type'] not in mispattributes['input']: + if request['attribute']['type'] not in mispattributes['input']: return {'error': 'Unsupported attribute type.'} - query_type, to_call = misp_type_mapping[attribute['type']] - parser = to_call(request['config']['apikey'], attribute) - parser.set_proxy_settings(request.get('config')) - query_result = parser.get_query_result(query_type) - status_code = query_result.status_code - if status_code == 200: - parser.parse_report(query_result.json()) - else: - misperrors['error'] = parse_error(status_code) + + event_limit = request['config'].get('event_limit') + attribute = request['attribute'] + proxy_settings = get_proxy_settings(request.get('config')) + + try: + client = vt.Client(request['config']['apikey'], + headers={ + 'x-tool': 'MISPModuleVirusTotalPublicExpansion', + }, + proxy=proxy_settings['http'] if proxy_settings else None) + parser = VirusTotalParser(client, int(event_limit) if event_limit else None) + parser.query_api(attribute) + except vt.APIError as ex: + misperrors['error'] = ex.message return misperrors + return parser.get_result() @@ -242,4 +245,4 @@ def introspection(): def version(): moduleinfo['config'] = moduleconfig - return moduleinfo + return moduleinfo \ No newline at end of file