fix: [typing] fix typing

pull/25/head
Christophe Vandeplas 2024-06-18 15:03:10 +02:00
parent 1ab1976343
commit 3827e3ed93
No known key found for this signature in database
GPG Key ID: BDC48619FFDC5A5B
1 changed files with 36 additions and 70 deletions

View File

@ -62,42 +62,29 @@ class Galaxy():
kill_chain_order (str, optional): The kill chain order of the galaxy.
"""
@overload
def __init__(self, galaxy: str):
def __init__(self, galaxy: Union[str, Dict[str, str]]):
"""
Initializes a Galaxy object from an existing galaxy.
Args:
galaxy (str): The name of the existing galaxy to load from the data folder.
"""
...
@overload
def __init__(self, galaxy: Dict[str, str]):
"""
Initializes a new instance of the Galaxy class.
Args:
galaxy (Dict[str, str]): The dictionary containing the galaxy data.
"""
...
def __init__(self, galaxy):
if isinstance(galaxy, str):
root_dir_galaxies = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'galaxies')
root_dir_galaxies = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'galaxies') # type: ignore [type-var, arg-type]
galaxy_file = os.path.join(root_dir_galaxies, f"{galaxy}.json")
with open(galaxy_file, 'r') as f:
self.__init__(json.load(f))
self.galaxy = json.load(f)
else:
self.galaxy = galaxy
self.type = self.galaxy['type']
self.name = self.galaxy['name']
self.icon = self.galaxy['icon']
self.description = self.galaxy['description']
self.version = self.galaxy['version']
self.uuid = self.galaxy['uuid']
self.namespace = self.galaxy.pop('namespace', None)
self.kill_chain_order = self.galaxy.pop('kill_chain_order', None)
self.type = self.galaxy['type']
self.name = self.galaxy['name']
self.icon = self.galaxy['icon']
self.description = self.galaxy['description']
self.version = self.galaxy['version']
self.uuid = self.galaxy['uuid']
self.namespace = self.galaxy.pop('namespace', None)
self.kill_chain_order = self.galaxy.pop('kill_chain_order', None)
def save(self, name: str) -> None:
"""
@ -106,7 +93,7 @@ class Galaxy():
Args:
name (str): The name of the file to save the galaxy to.
"""
root_dir_galaxies = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'galaxies')
root_dir_galaxies = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'galaxies') # type: ignore [type-var, arg-type]
galaxy_file = os.path.join(root_dir_galaxies, f"{name}.json")
with open(galaxy_file, 'w') as f:
json.dump(self, f, cls=EncodeGalaxies, indent=2, sort_keys=True, ensure_ascii=False)
@ -411,52 +398,39 @@ class Cluster(Mapping): # type: ignore
to_json(self) -> str: Converts the Cluster object to a JSON string.
to_dict(self) -> Dict[str, Any]: Converts the Cluster object to a dictionary.
"""
@overload
def __init__(self, cluster: str, skip_duplicates: bool = False):
def __init__(self, cluster: Union[Dict[str, Any], str], skip_duplicates: bool = False):
"""
Initializes a Cluster object from an existing cluster.
Args:
cluster (str): The name of the existing cluster to load from the data folder.
skip_duplicates (bool, optional): Flag indicating whether to skip duplicate values. Defaults to False.
"""
...
@overload
def __init__(self, cluster: Dict[str, Any], skip_duplicates: bool = False):
"""
Initializes a Cluster object from a dictionary.
Args:
cluster (Dict[str, Any]): A dictionary containing the cluster data.
skip_duplicates (bool, optional): Flag indicating whether to skip duplicate values. Defaults to False.
"""
...
def __init__(self, cluster, skip_duplicates: bool = False):
if isinstance(cluster, str):
root_dir_clusters = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'clusters')
root_dir_clusters = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'clusters') # type: ignore [type-var, arg-type]
cluster_file = os.path.join(root_dir_clusters, f"{cluster}.json")
with open(cluster_file, 'r') as f:
self.__init__(json.load(f), skip_duplicates=skip_duplicates)
self.cluster = json.load(f)
else:
self.cluster = cluster
self.name = self.cluster['name']
self.type = self.cluster['type']
self.source = self.cluster['source']
self.authors = self.cluster['authors']
self.description = self.cluster['description']
self.uuid = self.cluster['uuid']
self.version = self.cluster['version']
self.category = self.cluster['category']
self.cluster_values = {}
self.duplicates = []
try:
for value in self.cluster['values']:
new_cluster_value = ClusterValue(value)
self.append(new_cluster_value, skip_duplicates)
except KeyError:
pass
self.name = self.cluster['name']
self.type = self.cluster['type']
self.source = self.cluster['source']
self.authors = self.cluster['authors']
self.description = self.cluster['description']
self.uuid = self.cluster['uuid']
self.version = self.cluster['version']
self.category = self.cluster['category']
self.cluster_values: Dict[str, Any] = {}
self.duplicates: List[Tuple[str, str]] = []
try:
for value in self.cluster['values']:
new_cluster_value = ClusterValue(value)
self.append(new_cluster_value, skip_duplicates)
except KeyError:
pass
@overload
def search(self, query: str, return_tags: Literal[False] = False) -> List[ClusterValue]:
@ -520,7 +494,7 @@ class Cluster(Mapping): # type: ignore
return value
raise KeyError('No value with external_id: {}'.format(external_id))
def get_kill_chain_tactics(self) -> dict:
def get_kill_chain_tactics(self) -> Dict[str, List[str]]:
"""
Returns the sorted kill chain tactics associated with the cluster.
@ -532,7 +506,7 @@ class Cluster(Mapping): # type: ignore
if v.meta and v.meta.additional_properties and v.meta.additional_properties.get('kill_chain'):
for item in v.meta.additional_properties.get('kill_chain'):
items.add(item)
result = {}
result: Dict[str, List[str]] = {}
for item in items:
key, value = item.split(':')
if key not in result:
@ -543,15 +517,7 @@ class Cluster(Mapping): # type: ignore
result[key] = sorted(result[key])
return result
@overload
def append(self, cv: dict, skip_duplicates: bool = False) -> None:
...
@overload
def append(self, cv: ClusterValue, skip_duplicates: bool = False) -> None:
...
def append(self, cv, skip_duplicates: bool = False) -> None:
def append(self, cv: Union[Dict[str, Any], ClusterValue], skip_duplicates: bool = False) -> None:
"""
Adds a cluster value to the cluster.
"""
@ -571,7 +537,7 @@ class Cluster(Mapping): # type: ignore
Args:
name (str): The name of the file to save the cluster to.
"""
root_dir_clusters = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'clusters')
root_dir_clusters = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'clusters') # type: ignore [type-var, arg-type]
cluster_file = os.path.join(root_dir_clusters, f"{name}.json")
with open(cluster_file, 'w') as f:
json.dump(self, f, cls=EncodeClusters, indent=2, sort_keys=True, ensure_ascii=False)