diff --git a/changelog.d/8655.misc b/changelog.d/8655.misc new file mode 100644 index 0000000000..b588bdd3e2 --- /dev/null +++ b/changelog.d/8655.misc @@ -0,0 +1 @@ +Add more type hints to the application services code. diff --git a/mypy.ini b/mypy.ini index 1fbd8decf8..1ece2ba082 100644 --- a/mypy.ini +++ b/mypy.ini @@ -57,6 +57,7 @@ files = synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, @@ -82,6 +83,9 @@ ignore_missing_imports = True [mypy-zope] ignore_missing_imports = True +[mypy-bcrypt] +ignore_missing_imports = True + [mypy-constantly] ignore_missing_imports = True diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3ed29a2c16..9fc8444228 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from prometheus_client import Counter @@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.types import Collection, JsonDict, RoomStreamToken, UserID +from synapse.storage.databases.main.directory import RoomAliasMapping +from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") class ApplicationServicesHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() @@ -247,7 +250,9 @@ class ApplicationServicesHandler: service, "presence", new_token ) - async def _handle_typing(self, service: ApplicationService, new_token: int): + async def _handle_typing( + self, service: ApplicationService, new_token: int + ) -> List[JsonDict]: typing_source = self.event_sources.sources["typing"] # Get the typing events from just before current typing, _ = await typing_source.get_new_events_as( @@ -259,7 +264,7 @@ class ApplicationServicesHandler: ) return typing - async def _handle_receipts(self, service: ApplicationService): + async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) @@ -271,7 +276,7 @@ class ApplicationServicesHandler: async def _handle_presence( self, service: ApplicationService, users: Collection[Union[str, UserID]] - ): + ) -> List[JsonDict]: events = [] # type: List[JsonDict] presence_source = self.event_sources.sources["presence"] from_key = await self.store.get_type_stream_id_for_appservice( @@ -301,11 +306,11 @@ class ApplicationServicesHandler: return events - async def query_user_exists(self, user_id): + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. Args: - user_id(str): The user to query if they exist on any AS. + user_id: The user to query if they exist on any AS. Returns: True if this user exists on at least one application service. """ @@ -316,11 +321,13 @@ class ApplicationServicesHandler: return True return False - async def query_room_alias_exists(self, room_alias): + async def query_room_alias_exists( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: """Check if an application service knows this room alias exists. Args: - room_alias(RoomAlias): The room alias to query. + room_alias: The room alias to query. Returns: namedtuple: with keys "room_id" and "servers" or None if no association can be found. @@ -336,10 +343,13 @@ class ApplicationServicesHandler: ) if is_known_alias: # the alias exists now so don't query more ASes. - result = await self.store.get_association_from_room_alias(room_alias) - return result + return await self.store.get_association_from_room_alias(room_alias) - async def query_3pe(self, kind, protocol, fields): + return None + + async def query_3pe( + self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]] + ) -> List[JsonDict]: services = self._get_services_for_3pn(protocol) results = await make_deferred_yieldable( @@ -361,7 +371,9 @@ class ApplicationServicesHandler: return ret - async def get_3pe_protocols(self, only_protocol=None): + async def get_3pe_protocols( + self, only_protocol: Optional[str] = None + ) -> Dict[str, JsonDict]: services = self.store.get_app_services() protocols = {} # type: Dict[str, List[JsonDict]] @@ -379,7 +391,7 @@ class ApplicationServicesHandler: if info is not None: protocols[p].append(info) - def _merge_instances(infos): + def _merge_instances(infos: List[JsonDict]) -> JsonDict: if not infos: return {} @@ -394,19 +406,17 @@ class ApplicationServicesHandler: return combined - for p in protocols.keys(): - protocols[p] = _merge_instances(protocols[p]) + return {p: _merge_instances(protocols[p]) for p in protocols.keys()} - return protocols - - async def _get_services_for_event(self, event): + async def _get_services_for_event( + self, event: EventBase + ) -> List[ApplicationService]: """Retrieve a list of application services interested in this event. Args: - event(Event): The event to check. Can be None if alias_list is not. + event: The event to check. Can be None if alias_list is not. Returns: - list: A list of services interested in this - event based on the service regex. + A list of services interested in this event based on the service regex. """ services = self.store.get_app_services() @@ -420,17 +430,15 @@ class ApplicationServicesHandler: return interested_list - def _get_services_for_user(self, user_id): + def _get_services_for_user(self, user_id: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if (s.is_interested_in_user(user_id))] - return interested_list + return [s for s in services if (s.is_interested_in_user(user_id))] - def _get_services_for_3pn(self, protocol): + def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] - return interested_list + return [s for s in services if s.is_interested_in_protocol(protocol)] - async def _is_unknown_user(self, user_id): + async def _is_unknown_user(self, user_id: str) -> bool: if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. @@ -445,9 +453,8 @@ class ApplicationServicesHandler: service_list = [s for s in services if s.sender == user_id] return len(service_list) == 0 - async def _check_user_exists(self, user_id): + async def _check_user_exists(self, user_id: str) -> bool: unknown_user = await self._is_unknown_user(user_id) if unknown_user: - exists = await self.query_user_exists(user_id) - return exists + return await self.query_user_exists(user_id) return True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index dd14ab69d7..276594f3d9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,10 +18,20 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) import attr -import bcrypt # type: ignore[import] +import bcrypt import pymacaroons from synapse.api.constants import LoginType @@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -149,11 +162,7 @@ class SsoLoginExtraAttributes: class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 637a938bac..26eef6eb61 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -15,21 +15,31 @@ # limitations under the License. import logging import re -from typing import List +from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple -from synapse.appservice import ApplicationService, AppServiceTransaction +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + AppServiceTransaction, +) from synapse.config.appservice import load_appservices from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.types import Connection from synapse.types import JsonDict from synapse.util import json_encoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) -def _make_exclusive_regex(services_cache): +def _make_exclusive_regex( + services_cache: List[ApplicationService], +) -> Optional[Pattern]: # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ @@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache): ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_regex = re.compile(exclusive_user_regex) + exclusive_user_pattern = re.compile( + exclusive_user_regex + ) # type: Optional[Pattern] else: # We handle this case specially otherwise the constructed regex # will always match - exclusive_user_regex = None + exclusive_user_pattern = None - return exclusive_user_regex + return exclusive_user_pattern class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) @@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): def get_app_services(self): return self.services_cache - def get_if_app_services_interested_in_user(self, user_id): + def get_if_app_services_interested_in_user(self, user_id: str) -> bool: """Check if the user is one associated with an app service (exclusively) """ if self.exclusive_user_regex: @@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): else: return False - def get_app_service_by_user_id(self, user_id): + def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]: """Retrieve an application service from their user ID. All application services have associated with them a particular user ID. @@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore): a user ID to an application service. Args: - user_id(str): The user ID to see if it is an application service. + user_id: The user ID to see if it is an application service. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.sender == user_id: return service return None - def get_app_service_by_token(self, token): + def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]: """Get the application service with the given appservice token. Args: - token (str): The application service token. + token: The application service token. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.token == token: return service return None - def get_app_service_by_id(self, as_id): + def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]: """Get the application service with the given appservice ID. Args: - as_id (str): The application service ID. + as_id: The application service ID. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.id == as_id: @@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - async def get_appservices_by_state(self, state): + async def get_appservices_by_state( + self, state: ApplicationServiceState + ) -> List[ApplicationService]: """Get a list of application services based on their state. Args: - state(ApplicationServiceState): The state to filter on. + state: The state to filter on. Returns: A list of ApplicationServices, which may be empty. """ @@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - async def get_appservice_state(self, service): + async def get_appservice_state( + self, service: ApplicationService + ) -> Optional[ApplicationServiceState]: """Get the application service state. Args: - service(ApplicationService): The service whose state to set. + service: The service whose state to set. Returns: - An ApplicationServiceState. + An ApplicationServiceState or none. """ result = await self.db_pool.simple_select_one( "application_services_state", @@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - async def set_appservice_state(self, service, state) -> None: + async def set_appservice_state( + self, service: ApplicationService, state: ApplicationServiceState + ) -> None: """Set the application service state. Args: - service(ApplicationService): The service whose state to set. - state(ApplicationServiceState): The connectivity state to apply. + service: The service whose state to set. + state: The connectivity state to apply. """ await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} @@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore( "create_appservice_txn", _create_appservice_txn ) - async def complete_appservice_txn(self, txn_id, service) -> None: + async def complete_appservice_txn( + self, txn_id: int, service: ApplicationService + ) -> None: """Completes an application service transaction. Args: - txn_id(str): The transaction ID being completed. - service(ApplicationService): The application service which was sent - this transaction. + txn_id: The transaction ID being completed. + service: The application service which was sent this transaction. """ txn_id = int(txn_id) @@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore( # has probably missed some events), so whine loudly but still continue, # since it shouldn't fail completion of the transaction. last_txn_id = self._get_last_txn(txn, service.id) - if (last_txn_id + 1) != txn_id: + if (txn_id + 1) != txn_id: logger.error( "appservice: Completing a transaction which has an ID > 1 from " "the last ID sent to this AS. We've either dropped events or " @@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - async def get_oldest_unsent_txn(self, service): - """Get the oldest transaction which has not been sent for this - service. + async def get_oldest_unsent_txn( + self, service: ApplicationService + ) -> Optional[AppServiceTransaction]: + """Get the oldest transaction which has not been sent for this service. Args: - service(ApplicationService): The app service to get the oldest txn. + service: The app service to get the oldest txn. Returns: An AppServiceTransaction or None. """ @@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore( service=service, id=entry["txn_id"], events=events, ephemeral=[] ) - def _get_last_txn(self, txn, service_id): + def _get_last_txn(self, txn, service_id: Optional[str]) -> int: txn.execute( "SELECT last_txn FROM application_services_state WHERE as_id=?", (service_id,), @@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - async def set_appservice_last_pos(self, pos) -> None: + async def set_appservice_last_pos(self, pos: int) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) @@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - async def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice( + self, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: """Get all new events for an appservice""" def get_new_events_for_appservice_txn(txn): @@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore( ) async def set_type_stream_id_for_appservice( - self, service: ApplicationService, type: str, pos: int + self, service: ApplicationService, type: str, pos: Optional[int] ) -> None: if type not in ("read_receipt", "presence"): raise ValueError(