diff --git a/misp_modules/modules/expansion/virustotal.py b/misp_modules/modules/expansion/virustotal.py index 65623fb..1839bb3 100644 --- a/misp_modules/modules/expansion/virustotal.py +++ b/misp_modules/modules/expansion/virustotal.py @@ -1,167 +1,206 @@ +from pymisp import MISPAttribute, MISPEvent, MISPObject import json import requests -from requests import HTTPError -import base64 -from collections import defaultdict misperrors = {'error': 'Error'} -mispattributes = {'input': ['hostname', 'domain', "ip-src", "ip-dst", "md5", "sha1", "sha256", "sha512"], - 'output': ['domain', "ip-src", "ip-dst", "text", "md5", "sha1", "sha256", "sha512", "ssdeep", - "authentihash", "filename"]} +mispattributes = {'input': ['hostname', 'domain', "ip-src", "ip-dst", "md5", "sha1", "sha256", "sha512", "url"], + 'format': 'misp_standard'} # possible module-types: 'expansion', 'hover' or both -moduleinfo = {'version': '3', 'author': 'Hannah Ward', +moduleinfo = {'version': '4', 'author': 'Hannah Ward', 'description': 'Get information from virustotal', 'module-type': ['expansion']} # config fields that your code expects from the site admin -moduleconfig = ["apikey", "event_limit"] -comment = '{}: Enriched via VirusTotal' -hash_types = ["md5", "sha1", "sha256", "sha512"] +moduleconfig = ["apikey"] -class VirusTotalRequest(object): - def __init__(self, config): - self.apikey = config['apikey'] - self.limit = int(config.get('event_limit', 5)) +class VirusTotalParser(object): + def __init__(self, apikey): + self.apikey = apikey self.base_url = "https://www.virustotal.com/vtapi/v2/{}/report" - self.results = defaultdict(set) - self.to_return = [] - self.input_types_mapping = {'ip-src': self.get_ip, 'ip-dst': self.get_ip, - 'domain': self.get_domain, 'hostname': self.get_domain, - 'md5': self.get_hash, 'sha1': self.get_hash, - 'sha256': self.get_hash, 'sha512': self.get_hash} - self.output_types_mapping = {'submission_names': 'filename', 'ssdeep': 'ssdeep', - 'authentihash': 'authentihash', 'ITW_urls': 'url'} + self.misp_event = MISPEvent() + self.parsed_objects = {} + self.input_types_mapping = {'ip-src': self.parse_ip, 'ip-dst': self.parse_ip, + 'domain': self.parse_domain, 'hostname': self.parse_domain, + 'md5': self.parse_hash, 'sha1': self.parse_hash, + 'sha256': self.parse_hash, 'sha512': self.parse_hash, + 'url': self.parse_url} - def parse_request(self, q): - req_values = set() - for attribute_type, attribute_value in q.items(): - req_values.add(attribute_value) - try: - error = self.input_types_mapping[attribute_type](attribute_value) - except KeyError: - continue - if error is not None: - return error - for key, values in self.results.items(): - values = values.difference(req_values) - if values: - if isinstance(key, tuple): - types, comment = key - self.to_return.append({'types': list(types), 'values': list(values), 'comment': comment}) - else: - self.to_return.append({'types': key, 'values': list(values)}) - return self.to_return + def query_api(self, attribute): + self.attribute = MISPAttribute() + self.attribute.from_dict(**attribute) + return self.input_types_mapping[self.attribute.type](self.attribute.value, recurse=True) - def get_domain(self, domain, do_not_recurse=False): - req = requests.get(self.base_url.format('domain'), params={'domain': domain, 'apikey': self.apikey}) - try: - req.raise_for_status() + def get_result(self): + event = json.loads(self.misp_event.to_json())['Event'] + results = {key: event[key] for key in ('Attribute', 'Object') if (key in event and event[key])} + return {'results': results} + + ################################################################################ + #### Main parsing functions #### + ################################################################################ + + def parse_domain(self, domain, recurse=False): + req = requests.get(self.base_url.format('domain'), params={'apikey': self.apikey, 'domain': domain}) + if req.status_code != 200: + return req.status_code + req = req.json() + hash_type = 'sha256' + whois = 'whois' + feature_types = {'communicating': 'communicates-with', + 'downloaded': 'downloaded-from', + 'referrer': 'referring'} + siblings = (self.parse_siblings(domain) for domain in req['domain_siblings']) + uuid = self.parse_resolutions(req['resolutions'], req['subdomains'], siblings) + for feature_type, relationship in feature_types.items(): + for feature in ('undetected_{}_samples', 'detected_{}_samples'): + for sample in req.get(feature.format(feature_type), []): + status_code = self.parse_hash(sample[hash_type], False, uuid, relationship) + if status_code != 200: + return status_code + if req.get(whois): + whois_object = MISPObject(whois) + whois_object.add_attribute('text', type='text', value=req[whois]) + self.misp_event.add_object(**whois_object) + return self.parse_related_urls(req, recurse, uuid) + + def parse_hash(self, sample, recurse=False, uuid=None, relationship=None): + req = requests.get(self.base_url.format('file'), params={'apikey': self.apikey, 'resource': sample}) + status_code = req.status_code + if req.status_code == 200: req = req.json() - except HTTPError as e: - return str(e) - if req["response_code"] == 0: - # Nothing found - return [] - if "resolutions" in req: - for res in req["resolutions"][:self.limit]: - ip_address = res["ip_address"] - self.results[(("ip-dst", "ip-src"), comment.format(domain))].add(ip_address) - # Pivot from here to find all domain info - if not do_not_recurse: - error = self.get_ip(ip_address, True) - if error is not None: - return error - self.get_more_info(req) + vt_uuid = self.parse_vt_object(req) + file_attributes = [] + for hash_type in ('md5', 'sha1', 'sha256'): + if req.get(hash_type): + file_attributes.append({'type': hash_type, 'object_relation': hash_type, + 'value': req[hash_type]}) + if file_attributes: + file_object = MISPObject('file') + for attribute in file_attributes: + file_object.add_attribute(**attribute) + file_object.add_reference(vt_uuid, 'analyzed-with') + if uuid and relationship: + file_object.add_reference(uuid, relationship) + self.misp_event.add_object(**file_object) + return status_code - def get_hash(self, _hash): - req = requests.get(self.base_url.format('file'), params={'resource': _hash, 'apikey': self.apikey, 'allinfo': 1}) - try: - req.raise_for_status() + def parse_ip(self, ip, recurse=False): + req = requests.get(self.base_url.format('ip-address'), params={'apikey': self.apikey, 'ip': ip}) + if req.status_code != 200: + return req.status_code + req = req.json() + if req.get('asn'): + asn_mapping = {'network': ('ip-src', 'subnet-announced'), + 'country': ('text', 'country')} + asn_object = MISPObject('asn') + asn_object.add_attribute('asn', type='AS', value=req['asn']) + for key, value in asn_mapping.items(): + if req.get(key): + attribute_type, relation = value + asn_object.add_attribute(relation, type=attribute_type, value=req[key]) + self.misp_event.add_object(**asn_object) + uuid = self.parse_resolutions(req['resolutions']) if req.get('resolutions') else None + return self.parse_related_urls(req, recurse, uuid) + + def parse_url(self, url, recurse=False, uuid=None): + req = requests.get(self.base_url.format('url'), params={'apikey': self.apikey, 'resource': url}) + status_code = req.status_code + if req.status_code == 200: req = req.json() - except HTTPError as e: - return str(e) - if req["response_code"] == 0: - # Nothing found - return [] - self.get_more_info(req) + vt_uuid = self.parse_vt_object(req) + if not recurse: + feature = 'url' + url_object = MISPObject(feature) + url_object.add_attribute(feature, type=feature, value=url) + url_object.add_reference(vt_uuid, 'analyzed-with') + if uuid: + url_object.add_reference(uuid, 'hosted-in') + self.misp_event.add_object(**url_object) + return status_code - def get_ip(self, ip, do_not_recurse=False): - req = requests.get(self.base_url.format('ip-address'), params={'ip': ip, 'apikey': self.apikey}) - try: - req.raise_for_status() - req = req.json() - except HTTPError as e: - return str(e) - if req["response_code"] == 0: - # Nothing found - return [] - if "resolutions" in req: - for res in req["resolutions"][:self.limit]: - hostname = res["hostname"] - self.results[(("domain",), comment.format(ip))].add(hostname) - # Pivot from here to find all domain info - if not do_not_recurse: - error = self.get_domain(hostname, True) - if error is not None: - return error - self.get_more_info(req) + ################################################################################ + #### Additional parsing functions #### + ################################################################################ - def find_all(self, data): - hashes = [] - if isinstance(data, dict): - for key, value in data.items(): - if key in hash_types: - self.results[key].add(value) - hashes.append(value) - else: - if isinstance(value, (dict, list)): - hashes.extend(self.find_all(value)) - elif isinstance(data, list): - for d in data: - hashes.extend(self.find_all(d)) - return hashes + def parse_related_urls(self, query_result, recurse, uuid=None): + if recurse: + for feature in ('detected_urls', 'undetected_urls'): + if feature in query_result: + for url in query_result[feature]: + value = url['url'] if isinstance(url, dict) else url[0] + status_code = self.parse_url(value, False, uuid) + if status_code != 200: + return status_code + else: + for feature in ('detected_urls', 'undetected_urls'): + if feature in query_result: + for url in query_result[feature]: + value = url['url'] if isinstance(url, dict) else url[0] + self.misp_event.add_attribute('url', value) + return 200 - def get_more_info(self, req): - # Get all hashes first - hashes = self.find_all(req) - for h in hashes[:self.limit]: - # Search VT for some juicy info - try: - data = requests.get(self.base_url.format('file'), params={'resource': h, 'apikey': self.apikey, 'allinfo': 1}).json() - except Exception: - continue - # Go through euch key and check if it exists - for VT_type, MISP_type in self.output_types_mapping.items(): - if VT_type in data: - try: - self.results[((MISP_type,), comment.format(h))].add(data[VT_type]) - except TypeError: - self.results[((MISP_type,), comment.format(h))].update(data[VT_type]) - # Get the malware sample - sample = requests.get(self.base_url[:-6].format('file/download'), params={'hash': h, 'apikey': self.apikey}) - malsample = sample.content - # It is possible for VT to not give us any submission names - if "submission_names" in data: - self.to_return.append({"types": ["malware-sample"], "categories": ["Payload delivery"], - "values": data["submimssion_names"], "data": str(base64.b64encore(malsample), 'utf-8')}) + def parse_resolutions(self, resolutions, subdomains=None, uuids=None): + domain_ip_object = MISPObject('domain-ip') + if self.attribute.type == 'domain': + domain_ip_object.add_attribute('domain', type='domain', value=self.attribute.value) + attribute_type, relation, key = ('ip-dst', 'ip', 'ip_address') + else: + domain_ip_object.add_attribute('ip', type='ip-dst', value=self.attribute.value) + attribute_type, relation, key = ('domain', 'domain', 'hostname') + for resolution in resolutions: + domain_ip_object.add_attribute(relation, type=attribute_type, value=resolution[key]) + if subdomains: + for subdomain in subdomains: + attribute = MISPAttribute() + attribute.from_dict(**dict(type='domain', value=subdomain)) + self.misp_event.add_attribute(**attribute) + domain_ip_object.add_reference(attribute.uuid, 'subdomain') + if uuids: + for uuid in uuids: + domain_ip_object.add_reference(uuid, 'sibling-of') + self.misp_event.add_object(**domain_ip_object) + return domain_ip_object.uuid + + def parse_siblings(self, domain): + attribute = MISPAttribute() + attribute.from_dict(**dict(type='domain', value=domain)) + self.misp_event.add_attribute(**attribute) + return attribute.uuid + + def parse_vt_object(self, query_result): + vt_object = MISPObject('virustotal-report') + vt_object.add_attribute('permalink', type='link', value=query_result['permalink']) + detection_ratio = '{}/{}'.format(query_result['positives'], query_result['total']) + vt_object.add_attribute('detection-ratio', type='text', value=detection_ratio) + self.misp_event.add_object(**vt_object) + return vt_object.uuid + + +def parse_error(status_code): + status_mapping = {204: 'VirusTotal request rate limit exceeded.', + 400: 'Incorrect request, please check the arguments.', + 403: 'You don\'t have enough privileges to make the request.'} + if status_code in status_mapping: + return status_mapping[status_code] + return "VirusTotal may not be accessible." def handler(q=False): if q is False: return False - q = json.loads(q) - if not q.get('config') or not q['config'].get('apikey'): + request = json.loads(q) + if not request.get('config') or not request['config'].get('apikey'): misperrors['error'] = "A VirusTotal api key is required for this module." return misperrors - del q['module'] - query = VirusTotalRequest(q.pop('config')) - r = query.parse_request(q) - if isinstance(r, str): - misperrors['error'] = r + parser = VirusTotalParser(request['config']['apikey']) + attribute = request['attribute'] + status = parser.query_api(attribute) + if status != 200: + misperrors['error'] = parse_error(status) return misperrors - return {'results': r} + return parser.get_result() def introspection():