Fixup types
parent
d91053493c
commit
97d1739910
|
@ -14,12 +14,13 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.appservice.api import ApplicationServiceApi
|
||||
from synapse.types import GroupID, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import GroupID, UserID, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
@ -35,7 +36,13 @@ class ApplicationServiceState:
|
|||
class AppServiceTransaction:
|
||||
"""Represents an application service transaction."""
|
||||
|
||||
def __init__(self, service, id, events, ephemeral=None):
|
||||
def __init__(
|
||||
self,
|
||||
service: ApplicationService,
|
||||
id: int,
|
||||
events: List[EventBase],
|
||||
ephemeral=None,
|
||||
):
|
||||
self.service = service
|
||||
self.id = id
|
||||
self.events = events
|
||||
|
@ -198,9 +205,11 @@ class ApplicationService:
|
|||
return does_match
|
||||
|
||||
@cached(num_args=1, cache_context=True)
|
||||
async def matches_user_in_member_list(self, room_id, store, cache_context):
|
||||
async def matches_user_in_member_list(
|
||||
self, room_id: str, store: DataStore, cache_context: _CacheContext
|
||||
):
|
||||
member_list = await store.get_users_in_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
room_id
|
||||
)
|
||||
|
||||
# check joined member events
|
||||
|
@ -246,7 +255,9 @@ class ApplicationService:
|
|||
return False
|
||||
|
||||
@cached(num_args=1, cache_context=True)
|
||||
async def is_interested_in_presence(self, user_id, store, cache_context):
|
||||
async def is_interested_in_presence(
|
||||
self, user_id: UserID, store: DataStore, cache_context: _CacheContext
|
||||
):
|
||||
# Find all the rooms the sender is in
|
||||
if self.is_interested_in_user(user_id.to_string()):
|
||||
return True
|
||||
|
@ -254,7 +265,7 @@ class ApplicationService:
|
|||
|
||||
# Then find out if the appservice is interested in any of those rooms
|
||||
for room_id in room_ids:
|
||||
if await self.matches_user_in_member_list(room_id, store, cache_context):
|
||||
if await self.matches_user_in_member_list(room_id, store):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
|
|
@ -14,12 +14,13 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import urllib
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
|
||||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.types import JsonDict, ThirdPartyInstanceID
|
||||
|
@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
key = (service.id, protocol)
|
||||
return await self.protocol_meta_cache.wrap(key, _get)
|
||||
|
||||
async def push_bulk(self, service, events, ephemeral=None, txn_id=None):
|
||||
async def push_bulk(
|
||||
self,
|
||||
service: ApplicationService,
|
||||
events: List[EventBase],
|
||||
ephemeral: Optional[Any] = None,
|
||||
txn_id: Optional[int] = None,
|
||||
):
|
||||
if service.url is None:
|
||||
return True
|
||||
|
||||
|
@ -211,10 +218,9 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
logger.warning(
|
||||
"push_bulk: Missing txn ID sending events to %s", service.url
|
||||
)
|
||||
txn_id = str(0)
|
||||
txn_id = str(txn_id)
|
||||
txn_id = 0
|
||||
|
||||
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
|
||||
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
|
||||
body = {"events": events}
|
||||
if ephemeral:
|
||||
body["de.sorunome.msc2409.ephemeral"] = ephemeral
|
||||
|
|
|
@ -49,8 +49,10 @@ This is all tied together by the AppServiceScheduler which DIs the required
|
|||
components.
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from synapse.appservice import ApplicationServiceState
|
||||
from synapse.appservice import ApplicationService, ApplicationServiceState
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
|
||||
|
@ -82,10 +84,12 @@ class ApplicationServiceScheduler:
|
|||
for service in services:
|
||||
self.txn_ctrl.start_recoverer(service)
|
||||
|
||||
def submit_event_for_as(self, service, event):
|
||||
def submit_event_for_as(self, service: ApplicationService, event: EventBase):
|
||||
self.queuer.enqueue(service, event)
|
||||
|
||||
def submit_ephemeral_events_for_as(self, service, events):
|
||||
def submit_ephemeral_events_for_as(
|
||||
self, service: ApplicationService, events: List[Any]
|
||||
):
|
||||
self.queuer.enqueue_ephemeral(service, events)
|
||||
|
||||
|
||||
|
@ -99,7 +103,7 @@ class _ServiceQueuer:
|
|||
|
||||
def __init__(self, txn_ctrl, clock):
|
||||
self.queued_events = {} # dict of {service_id: [events]}
|
||||
self.queued_ephemeral = {} # dict of {service_id: [events]}
|
||||
self.queued_ephemeral = {} # dict of {service_id: [events]}
|
||||
|
||||
# the appservices which currently have a transaction in flight
|
||||
self.requests_in_flight = set()
|
||||
|
@ -118,7 +122,7 @@ class _ServiceQueuer:
|
|||
"as-sender-%s" % (service.id), self._send_request, service
|
||||
)
|
||||
|
||||
def enqueue_ephemeral(self, service, events):
|
||||
def enqueue_ephemeral(self, service: ApplicationService, events: List[Any]):
|
||||
self.queued_ephemeral.setdefault(service.id, []).extend(events)
|
||||
|
||||
# start a sender for this appservice if we don't already have one
|
||||
|
@ -130,7 +134,9 @@ class _ServiceQueuer:
|
|||
"as-sender-%s" % (service.id), self._send_request, service
|
||||
)
|
||||
|
||||
async def _send_request(self, service, ephemeral=None):
|
||||
async def _send_request(
|
||||
self, service: ApplicationService, ephemeral: Optional[Any] = None
|
||||
):
|
||||
# sanity-check: we shouldn't get here if this service already has a sender
|
||||
# running.
|
||||
assert service.id not in self.requests_in_flight
|
||||
|
@ -175,9 +181,16 @@ class _TransactionController:
|
|||
# for UTs
|
||||
self.RECOVERER_CLASS = _Recoverer
|
||||
|
||||
async def send(self, service, events, ephemeral=None):
|
||||
async def send(
|
||||
self,
|
||||
service: ApplicationService,
|
||||
events: List[EventBase],
|
||||
ephemeral: Optional[Any] = None,
|
||||
):
|
||||
try:
|
||||
txn = await self.store.create_appservice_txn(service=service, events=events, ephemeral=ephemeral)
|
||||
txn = await self.store.create_appservice_txn(
|
||||
service=service, events=events, ephemeral=ephemeral
|
||||
)
|
||||
service_is_up = await self.is_service_up(service)
|
||||
if service_is_up:
|
||||
sent = await txn.send(self.as_api)
|
||||
|
@ -221,7 +234,7 @@ class _TransactionController:
|
|||
recoverer.recover()
|
||||
logger.info("Now %i active recoverers", len(self.recoverers))
|
||||
|
||||
async def is_service_up(self, service):
|
||||
async def is_service_up(self, service: ApplicationService):
|
||||
state = await self.store.get_appservice_state(service)
|
||||
return state == ApplicationServiceState.UP or state is None
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.types import ReadReceipt, get_domain_from_id
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
@ -140,7 +141,9 @@ class ReceiptEventSource:
|
|||
|
||||
return (events, to_key)
|
||||
|
||||
async def get_new_events_as(self, from_key, service, **kwargs):
|
||||
async def get_new_events_as(
|
||||
self, from_key: int, service: ApplicationService, **kwargs
|
||||
):
|
||||
from_key = int(from_key)
|
||||
to_key = self.get_current_key()
|
||||
|
||||
|
|
|
@ -15,9 +15,11 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from synapse.appservice import AppServiceTransaction
|
||||
from synapse.appservice import ApplicationService, AppServiceTransaction
|
||||
from synapse.config.appservice import load_appservices
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
|
@ -172,7 +174,12 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||
)
|
||||
|
||||
async def create_appservice_txn(self, service, events, ephemeral=None):
|
||||
async def create_appservice_txn(
|
||||
self,
|
||||
service: ApplicationService,
|
||||
events: List[EventBase],
|
||||
ephemeral: Optional[Any] = None,
|
||||
):
|
||||
"""Atomically creates a new transaction for this application service
|
||||
with the given list of events.
|
||||
|
||||
|
@ -353,7 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
|
||||
return upper_bound, events
|
||||
|
||||
async def get_type_stream_id_for_appservice(self, service, type: str) -> int:
|
||||
async def get_type_stream_id_for_appservice(
|
||||
self, service: ApplicationService, type: str
|
||||
) -> int:
|
||||
def get_type_stream_id_for_appservice_txn(txn):
|
||||
stream_id_type = "%s_stream_id" % type
|
||||
txn.execute(
|
||||
|
@ -371,7 +380,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
)
|
||||
|
||||
async def set_type_stream_id_for_appservice(
|
||||
self, service, type: str, pos: int
|
||||
self, service: ApplicationService, type: str, pos: int
|
||||
) -> None:
|
||||
def set_type_stream_id_for_appservice_txn(txn):
|
||||
stream_id_type = "%s_stream_id" % type
|
||||
|
|
|
@ -284,7 +284,9 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
return results
|
||||
|
||||
@cached(num_args=2,)
|
||||
async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None):
|
||||
async def _get_linearized_receipts_for_all_rooms(
|
||||
self, to_key: int, from_key: Optional[int] = None
|
||||
):
|
||||
def f(txn):
|
||||
if from_key:
|
||||
sql = """
|
||||
|
|
Loading…
Reference in New Issue