fix: fix overloaded function

pull/25/head
Christophe Vandeplas 2024-06-18 10:43:17 +02:00
parent b5de7b54d4
commit 9d5c6a1b5d
No known key found for this signature in database
GPG Key ID: BDC48619FFDC5A5B
2 changed files with 31 additions and 6 deletions

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from .api import Galaxies, Galaxy, Clusters, Cluster, EncodeGalaxies, EncodeClusters, UnableToRevertMachinetag from .api import Galaxies, Galaxy, Clusters, Cluster, ClusterValue, EncodeGalaxies, EncodeClusters, UnableToRevertMachinetag

View File

@ -345,6 +345,7 @@ class Cluster(Mapping): # type: ignore
Attributes: Attributes:
cluster (Dict[str, Any]): The dictionary containing the cluster data. cluster (Dict[str, Any]): The dictionary containing the cluster data.
cluster (str): The name of the existing cluster to load from the data folder.
name (str): The name of the cluster. name (str): The name of the cluster.
type (str): The type of the cluster. type (str): The type of the cluster.
source (str): The source of the cluster. source (str): The source of the cluster.
@ -368,20 +369,34 @@ class Cluster(Mapping): # type: ignore
to_json(self) -> str: Converts the Cluster object to a JSON string. to_json(self) -> str: Converts the Cluster object to a JSON string.
to_dict(self) -> Dict[str, Any]: Converts the Cluster object to a dictionary. to_dict(self) -> Dict[str, Any]: Converts the Cluster object to a dictionary.
""" """
def __init__(self, cluster: Dict[str, Any] | str, skip_duplicates: bool = False): @overload
def __init__(self, cluster: str, skip_duplicates: bool = False):
""" """
Initializes a Cluster object. Initializes a Cluster object from an existing cluster.
Args: Args:
cluster (Dict[str, Any] | str): A dictionary containing the cluster data, or the name of the existing cluster to load from the data folder. cluster (str): The name of the existing cluster to load from the data folder.
skip_duplicates (bool, optional): Flag indicating whether to skip duplicate values. Defaults to False. skip_duplicates (bool, optional): Flag indicating whether to skip duplicate values. Defaults to False.
""" """
...
@overload
def __init__(self, cluster: Dict[str, Any], skip_duplicates: bool = False):
"""
Initializes a Cluster object from a dictionary.
Args:
cluster (Dict[str, Any]): A dictionary containing the cluster data.
skip_duplicates (bool, optional): Flag indicating whether to skip duplicate values. Defaults to False.
"""
...
def __init__(self, cluster, skip_duplicates: bool = False):
if isinstance(cluster, str): if isinstance(cluster, str):
root_dir_clusters = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'clusters') root_dir_clusters = os.path.join(os.path.abspath(os.path.dirname(sys.modules['pymispgalaxies'].__file__)), 'data', 'misp-galaxy', 'clusters')
cluster_file = os.path.join(root_dir_clusters, f"{cluster}.json") cluster_file = os.path.join(root_dir_clusters, f"{cluster}.json")
with open(cluster_file, 'r') as f: with open(cluster_file, 'r') as f:
self.__init__(json.load(f), skip_duplicates=skip_duplicates) self.__init__(json.load(f), skip_duplicates=skip_duplicates)
else: else:
self.cluster = cluster self.cluster = cluster
self.name = self.cluster['name'] self.name = self.cluster['name']
@ -463,10 +478,20 @@ class Cluster(Mapping): # type: ignore
return value return value
raise KeyError('No value with external_id: {}'.format(external_id)) raise KeyError('No value with external_id: {}'.format(external_id))
def add(self, cv: ClusterValue, skip_duplicates: bool) -> None: @overload
def add(self, cv: dict, skip_duplicates: bool = False) -> None:
...
@overload
def add(self, cv: ClusterValue, skip_duplicates: bool = False) -> None:
...
def add(self, cv, skip_duplicates: bool = False) -> None:
""" """
Adds a cluster value to the cluster. Adds a cluster value to the cluster.
""" """
if isinstance(cv, dict):
cv = ClusterValue(cv)
if self.get(cv.value): if self.get(cv.value):
if skip_duplicates: if skip_duplicates:
self.duplicates.append((self.name, cv.value)) self.duplicates.append((self.name, cv.value))