From a5edb31f2b571ef467f267fd8330b6a1ee3fa03c Mon Sep 17 00:00:00 2001 From: Christian Keil Date: Fri, 5 Nov 2021 10:30:45 +0100 Subject: [PATCH] Adds overloaded signatures for Cluster.search. Fixes #4 --- pymispgalaxies/api.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pymispgalaxies/api.py b/pymispgalaxies/api.py index 3aafbe2..f310054 100644 --- a/pymispgalaxies/api.py +++ b/pymispgalaxies/api.py @@ -8,7 +8,8 @@ import sys from collections.abc import Mapping from glob import glob import re -from typing import List, Dict, Optional, Any, Tuple, Iterator +from typing import List, Dict, Optional, Any, Tuple, Iterator, overload, Union +from typing_extensions import Literal try: import jsonschema # type: ignore @@ -225,7 +226,14 @@ class Cluster(Mapping): # type: ignore raise PyMISPGalaxiesError("Duplicate value ({}) in cluster: {}".format(new_cluster_value.value, self.name)) self.cluster_values[new_cluster_value.value] = new_cluster_value - def search(self, query: str, return_tags: bool=False) -> List[str]: + @overload + def search(self, query: str, return_tags: Literal[False]=False) -> List[ClusterValue]: ... + @overload + def search(self, query: str, return_tags: Literal[True]) -> List[str]: ... + @overload + def search(self, query: str, return_tags: bool) -> Union[List[ClusterValue], List[str]]: ... + + def search(self, query: str, return_tags: bool=False) -> Union[List[ClusterValue], List[str]]: matching = [] for v in self.values(): if [s for s in v.searchable if query.lower() in s.lower()]: