Some refactors around receipts stream (#16426)

pull/16429/head
Erik Johnston 2023-10-04 18:28:40 +03:00 committed by GitHub
parent a01ee24734
commit 80ec81dcc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 111 additions and 80 deletions

1
changelog.d/16426.misc Normal file
View File

@ -0,0 +1 @@
Refactor some code to simplify and better type receipts stream adjacent code.

View File

@ -216,7 +216,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: str,
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
@ -326,7 +326,7 @@ class ApplicationServicesHandler:
async def _notify_interested_services_ephemeral(
self,
services: List[ApplicationService],
stream_key: str,
stream_key: StreamKeyType,
new_token: int,
users: Collection[Union[str, UserID]],
) -> None:

View File

@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.push_rule import RuleNotFoundException
from synapse.synapse_rust.push import get_base_rule_ids
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -114,7 +114,9 @@ class PushRulesHandler:
user_id: the user ID the change is for.
"""
stream_id = self._main_store.get_max_push_rules_stream_id()
self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
self._notifier.on_new_event(
StreamKeyType.PUSH_RULES, stream_id, users=[user_id]
)
async def push_rules_for_user(
self, user: UserID

View File

@ -130,11 +130,10 @@ class ReceiptsHandler:
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier."""
min_batch_id: Optional[int] = None
max_batch_id: Optional[int] = None
receipts_persisted: List[ReadReceipt] = []
for receipt in receipts:
res = await self.store.insert_receipt(
stream_id = await self.store.insert_receipt(
receipt.room_id,
receipt.receipt_type,
receipt.user_id,
@ -143,30 +142,26 @@ class ReceiptsHandler:
receipt.data,
)
if not res:
# res will be None if this receipt is 'old'
if stream_id is None:
# stream_id will be None if this receipt is 'old'
continue
stream_id, max_persisted_id = res
receipts_persisted.append(receipt)
if min_batch_id is None or stream_id < min_batch_id:
min_batch_id = stream_id
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id
# Either both of these should be None or neither.
if min_batch_id is None or max_batch_id is None:
if not receipts_persisted:
# no new receipts
return False
affected_room_ids = list({r.room_id for r in receipts})
max_batch_id = self.store.get_max_receipt_stream_id()
affected_room_ids = list({r.room_id for r in receipts_persisted})
self.notifier.on_new_event(
StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
)
# Note that the min here shouldn't be relied upon to be accurate.
await self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
{r.user_id for r in receipts_persisted}
)
return True

View File

@ -126,7 +126,7 @@ class _NotifierUserStream:
def notify(
self,
stream_key: str,
stream_key: StreamKeyType,
stream_id: Union[int, RoomStreamToken],
time_now_ms: int,
) -> None:
@ -454,7 +454,7 @@ class Notifier:
def on_new_event(
self,
stream_key: str,
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
@ -655,30 +655,29 @@ class Notifier:
events: List[Union[JsonDict, EventBase]] = []
end_token = from_token
for name, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
for keyname, source in self.event_sources.sources.get_sources():
before_id = before_token.get_field(keyname)
after_id = after_token.get_field(keyname)
if before_id == after_id:
continue
new_events, new_key = await source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
from_key=from_token.get_field(keyname),
limit=limit,
is_guest=is_peeking,
room_ids=room_ids,
explicit_room_id=explicit_room_id,
)
if name == "room":
if keyname == StreamKeyType.ROOM:
new_events = await filter_events_for_client(
self._storage_controllers,
user.to_string(),
new_events,
is_peeking=is_peeking,
)
elif name == "presence":
elif keyname == StreamKeyType.PRESENCE:
now = self.clock.time_msec()
new_events[:] = [
{

View File

@ -182,7 +182,7 @@ class Pusher(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
def on_new_receipts(self) -> None:
raise NotImplementedError()
@abc.abstractmethod

View File

@ -99,7 +99,7 @@ class EmailPusher(Pusher):
pass
self.timed_call = None
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
def on_new_receipts(self) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire

View File

@ -160,7 +160,7 @@ class HttpPusher(Pusher):
if should_check_for_notifs:
self._start_processing()
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
def on_new_receipts(self) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,

View File

@ -292,20 +292,12 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
async def on_new_receipts(
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
) -> None:
async def on_new_receipts(self, users_affected: StrCollection) -> None:
if not self.pushers:
# nothing to do here.
return
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
users_affected = await self.store.get_users_sent_receipts_between(
min_stream_id - 1, max_stream_id
)
for u in users_affected:
# Don't push if the user account has expired
expired = await self._account_validity_handler.is_user_expired(u)
@ -314,7 +306,7 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id)
p.on_new_receipts()
except Exception:
logger.exception("Exception in pusher on_new_receipts")

View File

@ -129,9 +129,7 @@ class ReplicationDataHandler:
self.notifier.on_new_event(
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
)
await self._pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:

View File

@ -208,7 +208,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
StreamKeyType.ROOM: room_key,
StreamKeyType.ROOM.value: room_key,
}
)

View File

@ -742,7 +742,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
) -> Optional[Tuple[int, int]]:
) -> Optional[int]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@ -804,9 +804,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
max_persisted_id = self._receipts_id_gen.get_current_token()
return stream_id, max_persisted_id
return stream_id
async def _insert_graph_receipt(
self,

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Iterator, Tuple
from typing import TYPE_CHECKING, Sequence, Tuple
import attr
@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
from synapse.types import StreamToken
from synapse.types import StreamKeyType, StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -37,9 +37,14 @@ class _EventSourcesInner:
receipt: ReceiptEventSource
account_data: AccountDataEventSource
def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
for attribute in attr.fields(_EventSourcesInner):
yield attribute.name, getattr(self, attribute.name)
def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]:
return [
(StreamKeyType.ROOM, self.room),
(StreamKeyType.PRESENCE, self.presence),
(StreamKeyType.TYPING, self.typing),
(StreamKeyType.RECEIPT, self.receipt),
(StreamKeyType.ACCOUNT_DATA, self.account_data),
]
class EventSources:

View File

@ -22,8 +22,8 @@ from typing import (
Any,
ClassVar,
Dict,
Final,
List,
Literal,
Mapping,
Match,
MutableMapping,
@ -34,6 +34,7 @@ from typing import (
Type,
TypeVar,
Union,
overload,
)
import attr
@ -649,20 +650,20 @@ class RoomStreamToken:
return "s%d" % (self.stream,)
class StreamKeyType:
class StreamKeyType(Enum):
"""Known stream types.
A stream is a list of entities ordered by an incrementing "stream token".
"""
ROOM: Final = "room_key"
PRESENCE: Final = "presence_key"
TYPING: Final = "typing_key"
RECEIPT: Final = "receipt_key"
ACCOUNT_DATA: Final = "account_data_key"
PUSH_RULES: Final = "push_rules_key"
TO_DEVICE: Final = "to_device_key"
DEVICE_LIST: Final = "device_list_key"
ROOM = "room_key"
PRESENCE = "presence_key"
TYPING = "typing_key"
RECEIPT = "receipt_key"
ACCOUNT_DATA = "account_data_key"
PUSH_RULES = "push_rules_key"
TO_DEVICE = "to_device_key"
DEVICE_LIST = "device_list_key"
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
@ -784,7 +785,7 @@ class StreamToken:
def room_stream_id(self) -> int:
return self.room_key.stream
def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
@ -797,16 +798,44 @@ class StreamToken:
return new_token
new_token = self.copy_and_replace(key, new_value)
new_id = int(getattr(new_token, key))
old_id = int(getattr(self, key))
new_id = new_token.get_field(key)
old_id = self.get_field(key)
if old_id < new_id:
return new_token
else:
return self
def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
return attr.evolve(self, **{key: new_value})
def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
return attr.evolve(self, **{key.value: new_value})
@overload
def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken:
...
@overload
def get_field(
self,
key: Literal[
StreamKeyType.ACCOUNT_DATA,
StreamKeyType.DEVICE_LIST,
StreamKeyType.PRESENCE,
StreamKeyType.PUSH_RULES,
StreamKeyType.RECEIPT,
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
],
) -> int:
...
@overload
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
...
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
"""Returns the stream ID for the given key."""
return getattr(self, key.value)
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)

View File

@ -31,7 +31,7 @@ from synapse.appservice import (
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomStreamToken
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
from synapse.util import Clock
from synapse.util.stringutils import random_string
@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[event]
@ -332,7 +332,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
)
# This method will be called, but with an empty list of events
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@ -634,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
services=[interested_appservice],
stream_key="receipt_key",
stream_key=StreamKeyType.RECEIPT,
new_token=stream_token,
users=[self.exclusive_as_user],
)

View File

@ -28,7 +28,7 @@ from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.typing import TypingWriterHandler
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@ -203,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
)
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
@ -273,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
)
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
@ -349,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
)
self.mock_federation_client.put_json.assert_called_once_with(
"farm",
@ -399,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
)
self.on_new_event.reset_mock()
self.assertEqual(self.event_source.get_current_key(), 1)
@ -425,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.reactor.pump([16])
self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])]
)
self.assertEqual(self.event_source.get_current_key(), 2)
events = self.get_success(
@ -459,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])])
self.on_new_event.assert_has_calls(
[call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])]
)
self.on_new_event.reset_mock()
self.assertEqual(self.event_source.get_current_key(), 3)