Changes to handlers to support fetching events for appservices

hs/super-wip-edus-down-sync
Will Hunt 2020-09-21 15:10:06 +01:00
parent 78911ca46a
commit ae724db899
4 changed files with 143 additions and 0 deletions

View File

@ -20,6 +20,20 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse
from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
from synapse.types import RoomStreamToken
from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import (
@ -43,6 +57,7 @@ class ApplicationServicesHandler:
self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
self.event_sources = hs.get_event_sources()
self.current_max = 0
self.is_processing = False
@ -158,6 +173,40 @@ class ApplicationServicesHandler:
finally:
self.is_processing = False
async def notify_interested_services_ephemeral(self, stream_key: str, new_token: Union[int, RoomStreamToken]):
services = [service for service in self.store.get_app_services() if service.supports_ephemeral]
if not services or not self.notify_appservices:
return
logger.info("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
events = []
if stream_key == "typing_key":
from_key = new_token - 1
typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current
typing, _typing_key = await typing_source.get_new_events_as(
service=service,
from_key=from_key
)
events = typing
elif stream_key == "receipt_key":
from_key = new_token - 1
receipts_source = self.event_sources.sources["receipt"]
receipts, _receipts_key = await receipts_source.get_new_events_as(
service=service,
from_key=from_key
)
events = receipts
elif stream_key == "presence":
# TODO: This. Presence means trying to determine all the
# users the appservice cares about, which means checking
# all the rooms the appservice is in.
if events:
# TODO: Do in background?
await self.scheduler.submit_ephemeral_events_for_as(service, events)
async def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.

View File

@ -140,5 +140,27 @@ class ReceiptEventSource:
return (events, to_key)
async def get_new_events_as(self, from_key, service, **kwargs):
from_key = int(from_key)
to_key = self.get_current_key()
if from_key == to_key:
return [], to_key
# We first need to fetch all new receipts
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)
# Then filter down to rooms that the AS can read
events = []
for room_id, event in rooms_to_events.items():
if not await service.matches_user_in_member_list(room_id, self.store):
continue
events.append(event)
return (events, to_key)
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()

View File

@ -19,6 +19,7 @@ from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
@ -430,6 +431,24 @@ class TypingNotificationEventSource:
"content": {"user_ids": list(typing)},
}
async def get_new_events_as(self, from_key, service, **kwargs):
with Measure(self.clock, "typing.get_new_events_as"):
from_key = int(from_key)
handler = self.get_typing_handler()
events = []
for room_id in handler._room_serials.keys():
if handler._room_serials[room_id] <= from_key:
print("Key too old")
continue
# XXX: Store gut wrenching
if not await service.matches_user_in_member_list(room_id, handler.store):
continue
events.append(self._make_event_for(room_id))
return (events, handler._latest_room_serial)
async def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)

View File

@ -123,6 +123,15 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
for row in rows
}
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
results = await self._get_linearized_receipts_for_all_rooms(
to_key, from_key=from_key
)
return results
async def get_linearized_receipts_for_rooms(
self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
@ -274,6 +283,50 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
}
return results
@cached(
num_args=2,
)
async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None):
def f(txn):
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
"""
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ?
"""
txn.execute(sql, [to_key])
return self.db_pool.cursor_to_dict(txn)
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_all_rooms", f
)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
{"type": "m.receipt", "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
return results
async def get_users_sent_receipts_between(
self, last_id: int, current_id: int
) -> List[str]: