new: [ClusterValue] merge and Cluster.add() for merging

pull/28/head
Christophe Vandeplas 2024-06-24 13:10:23 +02:00
parent f4a679a35b
commit 21e1966ce2
No known key found for this signature in database
GPG Key ID: BDC48619FFDC5A5B
3 changed files with 103 additions and 6 deletions

View File

@ -339,6 +339,26 @@ class ClusterValue():
return None return None
return ClusterValueMeta(m) return ClusterValueMeta(m)
def merge(self, new: 'ClusterValue') -> None:
"""
Merges the new cluster value with the existing one. Practically it replaces the existing one but merges relations
"""
# backup relations
related_backup = self.related.copy()
# overwrite itself
self.__init__(new.to_dict()) # type: ignore [misc]
# merge relations with backup # LATER conver related to a class of Hashable type, as that would be much more efficient in keeping uniques
for rel in related_backup:
# if uuid exists, skip, as we already copied it
exists = False
for existing_item in self.related:
if rel['dest-uuid'] == existing_item['dest-uuid']:
exists = True
break
# else append rel to list
if not exists:
self.related.append(rel)
def to_json(self) -> str: def to_json(self) -> str:
""" """
Converts the ClusterValue object to a JSON string. Converts the ClusterValue object to a JSON string.
@ -361,7 +381,7 @@ class ClusterValue():
if self.description: if self.description:
to_return['description'] = self.description to_return['description'] = self.description
if self.meta: if self.meta:
to_return['meta'] = self.meta to_return['meta'] = self.meta.to_dict()
if self.related: if self.related:
to_return['related'] = self.related to_return['related'] = self.related
return to_return return to_return
@ -519,12 +539,21 @@ class Cluster(Mapping): # type: ignore
def append(self, cv: Union[Dict[str, Any], ClusterValue], skip_duplicates: bool = False) -> None: def append(self, cv: Union[Dict[str, Any], ClusterValue], skip_duplicates: bool = False) -> None:
""" """
Adds a cluster value to the cluster. Adds a cluster value to the cluster, and merge it if it already exists.
Args:
cv (Union[Dict[str, Any], ClusterValue]): The cluster value to add.
skip_duplicates (bool, optional): Flag indicating whether to skip duplicate values. Defaults to False.
""" """
if isinstance(cv, dict): if isinstance(cv, dict):
cv = ClusterValue(cv) cv = ClusterValue(cv)
if self.get(cv.value): existing = self.get(cv.value)
if skip_duplicates: if existing:
if cv.uuid == existing.uuid:
# merge the existing
self.cluster_values[cv.value.lower()].merge(cv)
return
elif skip_duplicates:
self.duplicates.append((self.name, cv.value)) self.duplicates.append((self.name, cv.value))
else: else:
raise PyMISPGalaxiesError("Duplicate value ({}) in cluster: {}".format(cv.value, self.name)) raise PyMISPGalaxiesError("Duplicate value ({}) in cluster: {}".format(cv.value, self.name))

View File

@ -2,13 +2,15 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import unittest import unittest
from pymispgalaxies import Galaxies, Clusters, UnableToRevertMachinetag from pymispgalaxies import Galaxies, Clusters, UnableToRevertMachinetag, Galaxy, Cluster
from glob import glob from glob import glob
import os import os
import json import json
from collections import Counter, defaultdict from collections import Counter, defaultdict
import warnings import warnings
from uuid import UUID from uuid import UUID
import filecmp
import tempfile
class TestPyMISPGalaxies(unittest.TestCase): class TestPyMISPGalaxies(unittest.TestCase):
@ -48,6 +50,16 @@ class TestPyMISPGalaxies(unittest.TestCase):
out = g.to_dict() out = g.to_dict()
self.assertDictEqual(out, galaxies_from_files[g.type]) self.assertDictEqual(out, galaxies_from_files[g.type])
@unittest.skip("We don't want to enforce it.")
def test_save_galaxies(self):
for galaxy_file in glob(os.path.join(self.galaxies.root_dir_galaxies, '*.json')):
with open(galaxy_file, 'r') as f:
galaxy = Galaxy(json.load(f))
with tempfile.NamedTemporaryFile(suffix='.json') as temp_file:
temp_file_no_suffix = temp_file.name[:-5]
galaxy.save(temp_file_no_suffix)
self.assertTrue(filecmp.cmp(galaxy_file, temp_file.name), msg=f"{galaxy_file} different when saving using Galaxy.save(). Maybe an sorting issue?")
def test_dump_clusters(self): def test_dump_clusters(self):
clusters_from_files = {} clusters_from_files = {}
for cluster_file in glob(os.path.join(self.clusters.root_dir_clusters, '*.json')): for cluster_file in glob(os.path.join(self.clusters.root_dir_clusters, '*.json')):
@ -59,6 +71,16 @@ class TestPyMISPGalaxies(unittest.TestCase):
print(name, c.name) print(name, c.name)
self.assertCountEqual(out, clusters_from_files[c.name]) self.assertCountEqual(out, clusters_from_files[c.name])
@unittest.skip("We don't want to enforce it.")
def test_save_clusters(self):
for cluster_file in glob(os.path.join(self.clusters.root_dir_clusters, '*.json')):
with open(cluster_file, 'r') as f:
cluster = Cluster(json.load(f))
with tempfile.NamedTemporaryFile(suffix='.json') as temp_file:
temp_file_no_suffix = temp_file.name[:-5]
cluster.save(temp_file_no_suffix)
self.assertTrue(filecmp.cmp(cluster_file, temp_file.name), msg=f"{cluster_file} different when saving using Cluster.save(). Maybe a sorting issue?")
def test_validate_schema_clusters(self): def test_validate_schema_clusters(self):
self.clusters.validate_with_schema() self.clusters.validate_with_schema()

View File

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import unittest import unittest
from pymispgalaxies import Galaxies, Clusters, Cluster from pymispgalaxies import Galaxies, Clusters, Cluster, ClusterValue
class TestPyMISPGalaxiesApi(unittest.TestCase): class TestPyMISPGalaxiesApi(unittest.TestCase):
@ -21,3 +21,49 @@ class TestPyMISPGalaxiesApi(unittest.TestCase):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
cluster.get_by_external_id('XXXXXX') cluster.get_by_external_id('XXXXXX')
def test_merge_cv(self):
cv_1 = ClusterValue({
'uuid': '1234',
'value': 'old value',
'description': 'old description',
'related': [
{
'dest-uuid': '1',
'type': 'subtechnique-of'
},
{
'dest-uuid': '2',
'type': 'old-type'
}
]
})
cv_2 = ClusterValue({
'uuid': '1234',
'value': 'new value',
'description': 'new description',
'related': [
{
'dest-uuid': '2',
'type': 'new-type'
},
{
'dest-uuid': '3',
'type': 'similar-to'
}
]
})
cv_1.merge(cv_2)
self.assertEqual(cv_1.value, 'new value')
self.assertEqual(cv_1.description, 'new description')
for rel in cv_1.related:
if rel['dest-uuid'] == '1':
self.assertEqual(rel['type'], 'subtechnique-of')
elif rel['dest-uuid'] == '2':
self.assertEqual(rel['type'], 'new-type')
elif rel['dest-uuid'] == '3':
self.assertEqual(rel['type'], 'similar-to')
else:
self.fail(f"Unexpected related: {rel}")