diff --git a/changelog.d/9702.misc b/changelog.d/9702.misc deleted file mode 100644 index c6e63450a9..0000000000 --- a/changelog.d/9702.misc +++ /dev/null @@ -1 +0,0 @@ -Speed up federation transmission by using fewer database calls. Contributed by @ShadowJonathan. diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 5dd172052b..31b8a68225 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -224,16 +224,14 @@ class HomeServer(ReplicationHandler): destinations = yield self.get_servers_for_context(room_name) try: - yield self.replication_layer.send_pdus( - [ - Pdu.create_new( - context=room_name, - pdu_type="sy.room.message", - content={"sender": sender, "body": body}, - origin=self.server_name, - destinations=destinations, - ) - ] + yield self.replication_layer.send_pdu( + Pdu.create_new( + context=room_name, + pdu_type="sy.room.message", + content={"sender": sender, "body": body}, + origin=self.server_name, + destinations=destinations, + ) ) except Exception as e: logger.exception(e) @@ -255,7 +253,7 @@ class HomeServer(ReplicationHandler): origin=self.server_name, destinations=destinations, ) - yield self.replication_layer.send_pdus([pdu]) + yield self.replication_layer.send_pdu(pdu) except Exception as e: logger.exception(e) @@ -267,18 +265,16 @@ class HomeServer(ReplicationHandler): destinations = yield self.get_servers_for_context(room_name) try: - yield self.replication_layer.send_pdus( - [ - Pdu.create_new( - context=room_name, - is_state=True, - pdu_type="sy.room.member", - state_key=invitee, - content={"membership": "invite"}, - origin=self.server_name, - destinations=destinations, - ) - ] + yield self.replication_layer.send_pdu( + Pdu.create_new( + context=room_name, + is_state=True, + pdu_type="sy.room.member", + state_key=invitee, + content={"membership": "invite"}, + origin=self.server_name, + destinations=destinations, + ) ) except Exception as e: logger.exception(e) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 022bbf7dad..deb40f4610 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,26 +14,19 @@ import abc import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Hashable, - Iterable, - List, - Optional, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter +from twisted.internet import defer + import synapse.metrics from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import ( LaterGauge, event_processing_loop_counter, @@ -262,27 +255,15 @@ class FederationSender(AbstractFederationSender): if not events and next_token >= self._last_poked_id: break - async def get_destinations_for_event( - event: EventBase, - ) -> Collection[str]: - """Computes the destinations to which this event must be sent. - - This returns an empty tuple when there are no destinations to send to, - or if this event is not from this homeserver and it is not sending - it on behalf of another server. - - Will also filter out destinations which this sender is not responsible for, - if multiple federation senders exist. - """ - + async def handle_event(event: EventBase) -> None: # Only send events for this server. send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.sender) if not is_mine and send_on_behalf_of is None: - return () + return if not event.internal_metadata.should_proactively_send(): - return () + return destinations = None # type: Optional[Set[str]] if not event.prev_event_ids(): @@ -317,7 +298,7 @@ class FederationSender(AbstractFederationSender): "Failed to calculate hosts in room for event: %s", event.event_id, ) - return () + return destinations = { d @@ -327,15 +308,17 @@ class FederationSender(AbstractFederationSender): ) } - destinations.discard(self.server_name) - if send_on_behalf_of is not None: # If we are sending the event on behalf of another server # then it already has the event and there is no reason to # send the event to it. destinations.discard(send_on_behalf_of) + logger.debug("Sending %s to %r", event, destinations) + if destinations: + await self._send_pdu(event, destinations) + now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) @@ -343,29 +326,24 @@ class FederationSender(AbstractFederationSender): "federation_sender" ).observe((now - ts) / 1000) - return destinations - return () + async def handle_room_events(events: Iterable[EventBase]) -> None: + with Measure(self.clock, "handle_room_events"): + for event in events: + await handle_event(event) - async def get_federatable_events_and_destinations( - events: Iterable[EventBase], - ) -> List[Tuple[EventBase, Collection[str]]]: - with Measure(self.clock, "get_destinations_for_events"): - # Fetch federation destinations per event, - # skip if get_destinations_for_event returns an empty collection, - # return list of event->destinations pairs. - return [ - (event, dests) - for (event, dests) in [ - (event, await get_destinations_for_event(event)) - for event in events - ] - if dests - ] + events_by_room = {} # type: Dict[str, List[EventBase]] + for event in events: + events_by_room.setdefault(event.room_id, []).append(event) - events_and_dests = await get_federatable_events_and_destinations(events) - - # Send corresponding events to each destination queue - await self._distribute_events(events_and_dests) + await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(handle_room_events, evs) + for evs in events_by_room.values() + ], + consumeErrors=True, + ) + ) await self.store.update_federation_out_pos("events", next_token) @@ -383,7 +361,7 @@ class FederationSender(AbstractFederationSender): events_processed_counter.inc(len(events)) event_processing_loop_room_count.labels("federation_sender").inc( - len({event.room_id for event in events}) + len(events_by_room) ) event_processing_loop_counter.labels("federation_sender").inc() @@ -395,53 +373,34 @@ class FederationSender(AbstractFederationSender): finally: self._is_processing = False - async def _distribute_events( - self, - events_and_dests: Iterable[Tuple[EventBase, Collection[str]]], - ) -> None: - """Distribute events to the respective per_destination queues. + async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: + # We loop through all destinations to see whether we already have + # a transaction in progress. If we do, stick it in the pending_pdus + # table and we'll get back to it later. - Also persists last-seen per-room stream_ordering to 'destination_rooms'. + destinations = set(destinations) + destinations.discard(self.server_name) + logger.debug("Sending to: %s", str(destinations)) - Args: - events_and_dests: A list of tuples, which are (event: EventBase, destinations: Collection[str]). - Every event is paired with its intended destinations (in federation). - """ - # Tuples of room_id + destination to their max-seen stream_ordering - room_with_dest_stream_ordering = {} # type: Dict[Tuple[str, str], int] + if not destinations: + return - # List of events to send to each destination - events_by_dest = {} # type: Dict[str, List[EventBase]] + sent_pdus_destination_dist_total.inc(len(destinations)) + sent_pdus_destination_dist_count.inc() - # For each event-destinations pair... - for event, destinations in events_and_dests: + assert pdu.internal_metadata.stream_ordering - # (we got this from the database, it's filled) - assert event.internal_metadata.stream_ordering - - sent_pdus_destination_dist_total.inc(len(destinations)) - sent_pdus_destination_dist_count.inc() - - # ...iterate over those destinations.. - for destination in destinations: - # ...update their stream-ordering... - room_with_dest_stream_ordering[(event.room_id, destination)] = max( - event.internal_metadata.stream_ordering, - room_with_dest_stream_ordering.get((event.room_id, destination), 0), - ) - - # ...and add the event to each destination queue. - events_by_dest.setdefault(destination, []).append(event) - - # Bulk-store destination_rooms stream_ids - await self.store.bulk_store_destination_rooms_entries( - room_with_dest_stream_ordering + # track the fact that we have a PDU for these destinations, + # to allow us to perform catch-up later on if the remote is unreachable + # for a while. + await self.store.store_destination_rooms_entries( + destinations, + pdu.room_id, + pdu.internal_metadata.stream_ordering, ) - for destination, pdus in events_by_dest.items(): - logger.debug("Sending %d pdus to %s", len(pdus), destination) - - self._get_per_destination_queue(destination).send_pdus(pdus) + for destination in destinations: + self._get_per_destination_queue(destination).send_pdu(pdu) async def send_read_receipt(self, receipt: ReadReceipt) -> None: """Send a RR to any other servers in the room diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 3bb66bce32..3b053ebcfb 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -154,22 +154,19 @@ class PerDestinationQueue: + len(self._pending_edus_keyed) ) - def send_pdus(self, pdus: Iterable[EventBase]) -> None: - """Add PDUs to the queue, and start the transmission loop if necessary + def send_pdu(self, pdu: EventBase) -> None: + """Add a PDU to the queue, and start the transmission loop if necessary Args: - pdus: pdus to send + pdu: pdu to send """ if not self._catching_up or self._last_successful_stream_ordering is None: # only enqueue the PDU if we are not catching up (False) or do not # yet know if we have anything to catch up (None) - self._pending_pdus.extend(pdus) + self._pending_pdus.append(pdu) else: - self._catchup_last_skipped = max( - pdu.internal_metadata.stream_ordering - for pdu in pdus - if pdu.internal_metadata.stream_ordering is not None - ) + assert pdu.internal_metadata.stream_ordering + self._catchup_last_skipped = pdu.internal_metadata.stream_ordering self.attempt_new_transaction() diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index b28ca61f80..82335e7a9d 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,7 @@ import logging from collections import namedtuple -from typing import Dict, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -295,33 +295,37 @@ class TransactionStore(TransactionWorkerStore): }, ) - async def bulk_store_destination_rooms_entries( - self, room_and_destination_to_ordering: Dict[Tuple[str, str], int] - ): + async def store_destination_rooms_entries( + self, + destinations: Iterable[str], + room_id: str, + stream_ordering: int, + ) -> None: """ - Updates or creates `destination_rooms` entries for a number of events. + Updates or creates `destination_rooms` entries in batch for a single event. Args: - room_and_destination_to_ordering: A mapping of (room, destination) -> stream_id + destinations: list of destinations + room_id: the room_id of the event + stream_ordering: the stream_ordering of the event """ await self.db_pool.simple_upsert_many( table="destinations", key_names=("destination",), - key_values={(d,) for _, d in room_and_destination_to_ordering.keys()}, + key_values=[(d,) for d in destinations], value_names=[], value_values=[], desc="store_destination_rooms_entries_dests", ) + rows = [(destination, room_id) for destination in destinations] await self.db_pool.simple_upsert_many( table="destination_rooms", - key_names=("room_id", "destination"), - key_values=list(room_and_destination_to_ordering.keys()), + key_names=("destination", "room_id"), + key_values=rows, value_names=["stream_ordering"], - value_values=[ - (stream_id,) for stream_id in room_and_destination_to_ordering.values() - ], + value_values=[(stream_ordering,)] * len(rows), desc="store_destination_rooms_entries_rooms", )