diff --git a/tests/test_expansions.py b/tests/test_expansions.py index a9714e4f..8099a62c 100644 --- a/tests/test_expansions.py +++ b/tests/test_expansions.py @@ -28,12 +28,15 @@ class TestExpansions(unittest.TestCase): return requests.post(urljoin(self.url, "query"), json=query) @staticmethod - def get_attribute(response): + def get_attribute_types(response): data = response.json() if not isinstance(data, dict): print(json.dumps(data, indent=2)) return data - return data['results']['Attribute'][0]['type'] + types = [] + for attribute in data['results']['Attribute']: + types.append(attribute['type']) + return types @staticmethod def get_data(response): @@ -52,7 +55,18 @@ class TestExpansions(unittest.TestCase): return data['error'] @staticmethod - def get_object(response): + def get_object_types(response): + data = response.json() + if not isinstance(data, dict): + print(json.dumps(data, indent=2)) + return data + names = [] + for obj in data['results']['Object']: + names.append(obj['name']) + return names + + @staticmethod + def get_first_object_type(response): data = response.json() if not isinstance(data, dict): print(json.dumps(data, indent=2)) @@ -95,7 +109,7 @@ class TestExpansions(unittest.TestCase): query['config'] = self.configs[module_name] response = self.misp_modules_post(query) try: - self.assertEqual(self.get_object(response), 'dns-record') + self.assertEqual(self.get_first_object_type(response), 'dns-record') except Exception: self.assertTrue(self.get_errors(response).startswith('You do not have enough APIVoid credits')) else: @@ -112,7 +126,7 @@ class TestExpansions(unittest.TestCase): } } response = self.misp_modules_post(query) - self.assertEqual(self.get_object(response), 'asn') + self.assertEqual(self.get_first_object_type(response), 'asn') def test_btc_steroids(self): if LiveCI: @@ -142,7 +156,7 @@ class TestExpansions(unittest.TestCase): query['config'] = self.configs[module_name] response = self.misp_modules_post(query) try: - self.assertEqual(self.get_object(response), 'passive-dns') + self.assertEqual(self.get_first_object_type(response), 'passive-dns') except Exception: self.assertTrue(self.get_errors(response).startswith('There is an authentication error')) else: @@ -160,7 +174,7 @@ class TestExpansions(unittest.TestCase): query['config'] = self.configs[module_name] response = self.misp_modules_post(query) try: - self.assertEqual(self.get_object(response), 'x509') + self.assertEqual(self.get_first_object_type(response), 'x509') except Exception: self.assertTrue(self.get_errors(response).startswith('There is an authentication error')) else: @@ -190,7 +204,7 @@ class TestExpansions(unittest.TestCase): "config": {}} response = self.misp_modules_post(query) try: - self.assertEqual(self.get_object(response), 'vulnerability') + self.assertEqual(self.get_first_object_type(response), 'vulnerability') except Exception: print(self.get_errors(response)) @@ -309,7 +323,7 @@ class TestExpansions(unittest.TestCase): "value": "149.13.33.14", "uuid": "ea89a33b-4ab7-4515-9f02-922a0bee333d"}} response = self.misp_modules_post(query) - self.assertEqual(self.get_object(response), 'asn') + self.assertEqual(self.get_first_object_type(response), 'asn') def test_ipqs_fraud_and_risk_scoring(self): module_name = "ipqs_fraud_and_risk_scoring" @@ -508,7 +522,7 @@ class TestExpansions(unittest.TestCase): if module_name in self.configs: query['config'] = self.configs[module_name] response = self.misp_modules_post(query) - self.assertEqual(self.get_object(response), 'ip-api-address') + self.assertEqual(self.get_first_object_type(response), 'ip-api-address') else: response = self.misp_modules_post(query) self.assertEqual(self.get_errors(response), 'Shodan authentication is missing') @@ -581,6 +595,7 @@ class TestExpansions(unittest.TestCase): 'a04ac6d98ad989312783d4fe3456c53730b212c79a426fb215708b6c6daa3de3', 'http://79.118.195.239:1924/.i') results = ('url', 'url', 'file', 'virustotal-report') + for query_type, query_value, result in zip(query_types[:2], query_values[:2], results[:2]): query = {"module": "urlhaus", "attribute": {"type": query_type, @@ -588,7 +603,8 @@ class TestExpansions(unittest.TestCase): "uuid": "ea89a33b-4ab7-4515-9f02-922a0bee333d"}} response = self.misp_modules_post(query) print(response.json()) - self.assertEqual(self.get_attribute(response), result) + self.assertIn(result, self.get_attribute_types(response)) + for query_type, query_value, result in zip(query_types[2:], query_values[2:], results[2:]): query = {"module": "urlhaus", "attribute": {"type": query_type, @@ -596,7 +612,7 @@ class TestExpansions(unittest.TestCase): "uuid": "ea89a33b-4ab7-4515-9f02-922a0bee333d"}} response = self.misp_modules_post(query) print(response.json()) - self.assertEqual(self.get_object(response), result) + self.assertIn(result, self.get_object_types(response)) def test_urlscan(self): module_name = "urlscan" @@ -641,7 +657,7 @@ class TestExpansions(unittest.TestCase): "config": self.configs[module_name]} response = self.misp_modules_post(query) try: - self.assertEqual(self.get_object(response), result) + self.assertEqual(self.get_first_object_type(response), result) except Exception: self.assertEqual(self.get_errors(response), "VirusTotal request rate limit exceeded.") else: @@ -684,7 +700,7 @@ class TestExpansions(unittest.TestCase): "config": self.configs[module_name]} response = self.misp_modules_post(query) try: - self.assertEqual(self.get_object(response), result) + self.assertEqual(self.get_first_object_type(response), result) except Exception: self.assertEqual(self.get_errors(response), "VirusTotal request rate limit exceeded.") else: @@ -730,7 +746,7 @@ class TestExpansions(unittest.TestCase): "uuid": "ea89a33b-4ab7-4515-9f02-922a0bee333d"}, "config": self.configs[module_name]} response = self.misp_modules_post(query) - self.assertEqual(self.get_object(response), result) + self.assertEqual(self.get_first_object_type(response), result) else: query = {"module": module_name, "attribute": {"type": query_types[0],