diff --git a/pymispgalaxies/api.py b/pymispgalaxies/api.py index 3aafbe2..4f938fc 100644 --- a/pymispgalaxies/api.py +++ b/pymispgalaxies/api.py @@ -8,7 +8,12 @@ import sys from collections.abc import Mapping from glob import glob import re -from typing import List, Dict, Optional, Any, Tuple, Iterator +from typing import List, Dict, Optional, Any, Tuple, Iterator, overload, Union + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal try: import jsonschema # type: ignore @@ -225,7 +230,14 @@ class Cluster(Mapping): # type: ignore raise PyMISPGalaxiesError("Duplicate value ({}) in cluster: {}".format(new_cluster_value.value, self.name)) self.cluster_values[new_cluster_value.value] = new_cluster_value - def search(self, query: str, return_tags: bool=False) -> List[str]: + @overload + def search(self, query: str, return_tags: Literal[False]=False) -> List[ClusterValue]: ... + @overload + def search(self, query: str, return_tags: Literal[True]) -> List[str]: ... + @overload + def search(self, query: str, return_tags: bool) -> Union[List[ClusterValue], List[str]]: ... + + def search(self, query: str, return_tags: bool=False) -> Union[List[ClusterValue], List[str]]: matching = [] for v in self.values(): if [s for s in v.searchable if query.lower() in s.lower()]: