Fixup types

hs/push-reports-to-as
Will Hunt 2020-10-01 14:50:29 +01:00
parent d91053493c
commit 97d1739910
6 changed files with 72 additions and 28 deletions

View File

@ -14,12 +14,13 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List
from synapse.api.constants import EventTypes
from synapse.appservice.api import ApplicationServiceApi
from synapse.types import GroupID, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.events import EventBase
from synapse.types import GroupID, UserID, get_domain_from_id
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
@ -35,7 +36,13 @@ class ApplicationServiceState:
class AppServiceTransaction:
"""Represents an application service transaction."""
def __init__(self, service, id, events, ephemeral=None):
def __init__(
self,
service: ApplicationService,
id: int,
events: List[EventBase],
ephemeral=None,
):
self.service = service
self.id = id
self.events = events
@ -198,9 +205,11 @@ class ApplicationService:
return does_match
@cached(num_args=1, cache_context=True)
async def matches_user_in_member_list(self, room_id, store, cache_context):
async def matches_user_in_member_list(
self, room_id: str, store: DataStore, cache_context: _CacheContext
):
member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
room_id
)
# check joined member events
@ -246,7 +255,9 @@ class ApplicationService:
return False
@cached(num_args=1, cache_context=True)
async def is_interested_in_presence(self, user_id, store, cache_context):
async def is_interested_in_presence(
self, user_id: UserID, store: DataStore, cache_context: _CacheContext
):
# Find all the rooms the sender is in
if self.is_interested_in_user(user_id.to_string()):
return True
@ -254,7 +265,7 @@ class ApplicationService:
# Then find out if the appservice is interested in any of those rooms
for room_id in room_ids:
if await self.matches_user_in_member_list(room_id, store, cache_context):
if await self.matches_user_in_member_list(room_id, store):
return True
return False

View File

@ -14,12 +14,13 @@
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, List, Optional
from prometheus_client import Counter
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID
@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get)
async def push_bulk(self, service, events, ephemeral=None, txn_id=None):
async def push_bulk(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[Any] = None,
txn_id: Optional[int] = None,
):
if service.url is None:
return True
@ -211,10 +218,9 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning(
"push_bulk: Missing txn ID sending events to %s", service.url
)
txn_id = str(0)
txn_id = str(txn_id)
txn_id = 0
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
body = {"events": events}
if ephemeral:
body["de.sorunome.msc2409.ephemeral"] = ephemeral

View File

@ -49,8 +49,10 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
from typing import Any, List, Optional
from synapse.appservice import ApplicationServiceState
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@ -82,10 +84,12 @@ class ApplicationServiceScheduler:
for service in services:
self.txn_ctrl.start_recoverer(service)
def submit_event_for_as(self, service, event):
def submit_event_for_as(self, service: ApplicationService, event: EventBase):
self.queuer.enqueue(service, event)
def submit_ephemeral_events_for_as(self, service, events):
def submit_ephemeral_events_for_as(
self, service: ApplicationService, events: List[Any]
):
self.queuer.enqueue_ephemeral(service, events)
@ -99,7 +103,7 @@ class _ServiceQueuer:
def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
self.queued_ephemeral = {} # dict of {service_id: [events]}
self.queued_ephemeral = {} # dict of {service_id: [events]}
# the appservices which currently have a transaction in flight
self.requests_in_flight = set()
@ -118,7 +122,7 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id), self._send_request, service
)
def enqueue_ephemeral(self, service, events):
def enqueue_ephemeral(self, service: ApplicationService, events: List[Any]):
self.queued_ephemeral.setdefault(service.id, []).extend(events)
# start a sender for this appservice if we don't already have one
@ -130,7 +134,9 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id), self._send_request, service
)
async def _send_request(self, service, ephemeral=None):
async def _send_request(
self, service: ApplicationService, ephemeral: Optional[Any] = None
):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
@ -175,9 +181,16 @@ class _TransactionController:
# for UTs
self.RECOVERER_CLASS = _Recoverer
async def send(self, service, events, ephemeral=None):
async def send(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[Any] = None,
):
try:
txn = await self.store.create_appservice_txn(service=service, events=events, ephemeral=ephemeral)
txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral
)
service_is_up = await self.is_service_up(service)
if service_is_up:
sent = await txn.send(self.as_api)
@ -221,7 +234,7 @@ class _TransactionController:
recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers))
async def is_service_up(self, service):
async def is_service_up(self, service: ApplicationService):
state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None

View File

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
@ -140,7 +141,9 @@ class ReceiptEventSource:
return (events, to_key)
async def get_new_events_as(self, from_key, service, **kwargs):
async def get_new_events_as(
self, from_key: int, service: ApplicationService, **kwargs
):
from_key = int(from_key)
to_key = self.get_current_key()

View File

@ -15,9 +15,11 @@
# limitations under the License.
import logging
import re
from typing import Any, List, Optional
from synapse.appservice import AppServiceTransaction
from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
@ -172,7 +174,12 @@ class ApplicationServiceTransactionWorkerStore(
"application_services_state", {"as_id": service.id}, {"state": state}
)
async def create_appservice_txn(self, service, events, ephemeral=None):
async def create_appservice_txn(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[Any] = None,
):
"""Atomically creates a new transaction for this application service
with the given list of events.
@ -353,7 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, events
async def get_type_stream_id_for_appservice(self, service, type: str) -> int:
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
def get_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
txn.execute(
@ -371,7 +380,7 @@ class ApplicationServiceTransactionWorkerStore(
)
async def set_type_stream_id_for_appservice(
self, service, type: str, pos: int
self, service: ApplicationService, type: str, pos: int
) -> None:
def set_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type

View File

@ -284,7 +284,9 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
return results
@cached(num_args=2,)
async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None):
async def _get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
):
def f(txn):
if from_key:
sql = """