From a5edb31f2b571ef467f267fd8330b6a1ee3fa03c Mon Sep 17 00:00:00 2001 From: Christian Keil Date: Fri, 5 Nov 2021 10:30:45 +0100 Subject: [PATCH 1/2] Adds overloaded signatures for Cluster.search. Fixes #4 --- pymispgalaxies/api.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pymispgalaxies/api.py b/pymispgalaxies/api.py index 3aafbe2..f310054 100644 --- a/pymispgalaxies/api.py +++ b/pymispgalaxies/api.py @@ -8,7 +8,8 @@ import sys from collections.abc import Mapping from glob import glob import re -from typing import List, Dict, Optional, Any, Tuple, Iterator +from typing import List, Dict, Optional, Any, Tuple, Iterator, overload, Union +from typing_extensions import Literal try: import jsonschema # type: ignore @@ -225,7 +226,14 @@ class Cluster(Mapping): # type: ignore raise PyMISPGalaxiesError("Duplicate value ({}) in cluster: {}".format(new_cluster_value.value, self.name)) self.cluster_values[new_cluster_value.value] = new_cluster_value - def search(self, query: str, return_tags: bool=False) -> List[str]: + @overload + def search(self, query: str, return_tags: Literal[False]=False) -> List[ClusterValue]: ... + @overload + def search(self, query: str, return_tags: Literal[True]) -> List[str]: ... + @overload + def search(self, query: str, return_tags: bool) -> Union[List[ClusterValue], List[str]]: ... + + def search(self, query: str, return_tags: bool=False) -> Union[List[ClusterValue], List[str]]: matching = [] for v in self.values(): if [s for s in v.searchable if query.lower() in s.lower()]: From f4d3209f0af9a582d2e940bd7dfe34bde58c7a45 Mon Sep 17 00:00:00 2001 From: Christian Keil Date: Fri, 5 Nov 2021 11:04:28 +0100 Subject: [PATCH 2/2] Switch Literal import based on Python version. --- pymispgalaxies/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pymispgalaxies/api.py b/pymispgalaxies/api.py index f310054..4f938fc 100644 --- a/pymispgalaxies/api.py +++ b/pymispgalaxies/api.py @@ -9,7 +9,11 @@ from collections.abc import Mapping from glob import glob import re from typing import List, Dict, Optional, Any, Tuple, Iterator, overload, Union -from typing_extensions import Literal + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal try: import jsonschema # type: ignore