diff --git a/changelog.d/10117.feature b/changelog.d/10117.feature new file mode 100644 index 0000000000..e137e142c6 --- /dev/null +++ b/changelog.d/10117.feature @@ -0,0 +1 @@ +Significantly reduce memory usage of joining large remote rooms. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index c840ffca71..e5a4685ed4 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -233,41 +233,19 @@ class Keyring: for server_name, json_object, validity_time in server_and_json ] - def verify_events_for_server( - self, server_and_events: Iterable[Tuple[str, EventBase, int]] - ) -> List[defer.Deferred]: - """Bulk verification of signatures on events. - - Args: - server_and_events: - Iterable of `(server_name, event, validity_time)` tuples. - - `server_name` is which server we are verifying the signature for - on the event. - - `event` is the event that we'll verify the signatures of for - the given `server_name`. - - `validity_time` is a timestamp at which the signing key must be - valid. - - Returns: - List: for each input triplet, a deferred indicating success - or failure to verify each event's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. - """ - return [ - run_in_background( - self.process_request, - VerifyJsonRequest.from_event( - server_name, - event, - validity_time, - ), + async def verify_event_for_server( + self, + server_name: str, + event: EventBase, + validity_time: int, + ) -> None: + await self.process_request( + VerifyJsonRequest.from_event( + server_name, + event, + validity_time, ) - for server_name, event, validity_time in server_and_events - ] + ) async def process_request(self, verify_request: VerifyJsonRequest) -> None: """Processes the `VerifyJsonRequest`. Raises if the object is not signed diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 3fe496dcd3..c066617b92 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -14,11 +14,6 @@ # limitations under the License. import logging from collections import namedtuple -from typing import Iterable, List - -from twisted.internet import defer -from twisted.internet.defer import Deferred, DeferredList -from twisted.python.failure import Failure from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -28,11 +23,6 @@ from synapse.crypto.keyring import Keyring from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event, validate_canonicaljson from synapse.http.servlet import assert_params_in_dict -from synapse.logging.context import ( - PreserveLoggingContext, - current_context, - make_deferred_yieldable, -) from synapse.types import JsonDict, get_domain_from_id logger = logging.getLogger(__name__) @@ -48,112 +38,82 @@ class FederationBase: self.store = hs.get_datastore() self._clock = hs.get_clock() - def _check_sigs_and_hash( + async def _check_sigs_and_hash( self, room_version: RoomVersion, pdu: EventBase - ) -> Deferred: - return make_deferred_yieldable( - self._check_sigs_and_hashes(room_version, [pdu])[0] - ) - - def _check_sigs_and_hashes( - self, room_version: RoomVersion, pdus: List[EventBase] - ) -> List[Deferred]: - """Checks that each of the received events is correctly signed by the - sending server. + ) -> EventBase: + """Checks that event is correctly signed by the sending server. Args: - room_version: The room version of the PDUs - pdus: the events to be checked + room_version: The room version of the PDU + pdu: the event to be checked Returns: - For each input event, a deferred which: - * returns the original event if the checks pass - * returns a redacted version of the event (if the signature + * the original event if the checks pass + * a redacted version of the event (if the signature matched but the hash did not) - * throws a SynapseError if the signature check failed. - The deferreds run their callbacks in the sentinel - """ - deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) - - ctx = current_context() - - @defer.inlineCallbacks - def callback(_, pdu: EventBase): - with PreserveLoggingContext(ctx): - if not check_event_content_hash(pdu): - # let's try to distinguish between failures because the event was - # redacted (which are somewhat expected) vs actual ball-tampering - # incidents. - # - # This is just a heuristic, so we just assume that if the keys are - # about the same between the redacted and received events, then the - # received event was probably a redacted copy (but we then use our - # *actual* redacted copy to be on the safe side.) - redacted_event = prune_event(pdu) - if set(redacted_event.keys()) == set(pdu.keys()) and set( - redacted_event.content.keys() - ) == set(pdu.content.keys()): - logger.info( - "Event %s seems to have been redacted; using our redacted " - "copy", - pdu.event_id, - ) - else: - logger.warning( - "Event %s content has been tampered, redacting", - pdu.event_id, - ) - return redacted_event - - result = yield defer.ensureDeferred( - self.spam_checker.check_event_for_spam(pdu) - ) - - if result: - logger.warning( - "Event contains spam, redacting %s: %s", - pdu.event_id, - pdu.get_pdu_json(), - ) - return prune_event(pdu) - - return pdu - - def errback(failure: Failure, pdu: EventBase): - failure.trap(SynapseError) - with PreserveLoggingContext(ctx): - logger.warning( - "Signature check failed for %s: %s", - pdu.event_id, - failure.getErrorMessage(), - ) - return failure - - for deferred, pdu in zip(deferreds, pdus): - deferred.addCallbacks( - callback, errback, callbackArgs=[pdu], errbackArgs=[pdu] + * throws a SynapseError if the signature check failed.""" + try: + await _check_sigs_on_pdu(self.keyring, room_version, pdu) + except SynapseError as e: + logger.warning( + "Signature check failed for %s: %s", + pdu.event_id, + e, ) + raise - return deferreds + if not check_event_content_hash(pdu): + # let's try to distinguish between failures because the event was + # redacted (which are somewhat expected) vs actual ball-tampering + # incidents. + # + # This is just a heuristic, so we just assume that if the keys are + # about the same between the redacted and received events, then the + # received event was probably a redacted copy (but we then use our + # *actual* redacted copy to be on the safe side.) + redacted_event = prune_event(pdu) + if set(redacted_event.keys()) == set(pdu.keys()) and set( + redacted_event.content.keys() + ) == set(pdu.content.keys()): + logger.info( + "Event %s seems to have been redacted; using our redacted copy", + pdu.event_id, + ) + else: + logger.warning( + "Event %s content has been tampered, redacting", + pdu.event_id, + ) + return redacted_event + + result = await self.spam_checker.check_event_for_spam(pdu) + + if result: + logger.warning( + "Event contains spam, redacting %s: %s", + pdu.event_id, + pdu.get_pdu_json(), + ) + return prune_event(pdu) + + return pdu class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): pass -def _check_sigs_on_pdus( - keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase] -) -> List[Deferred]: +async def _check_sigs_on_pdu( + keyring: Keyring, room_version: RoomVersion, pdu: EventBase +) -> None: """Check that the given events are correctly signed + Raise a SynapseError if the event wasn't correctly signed. + Args: keyring: keyring object to do the checks room_version: the room version of the PDUs pdus: the events to be checked - - Returns: - A Deferred for each event in pdus, which will either succeed if - the signatures are valid, or fail (with a SynapseError) if not. """ # we want to check that the event is signed by: @@ -177,90 +137,47 @@ def _check_sigs_on_pdus( # let's start by getting the domain for each pdu, and flattening the event back # to JSON. - pdus_to_check = [ - PduToCheckSig( - pdu=p, - sender_domain=get_domain_from_id(p.sender), - deferreds=[], - ) - for p in pdus - ] - # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) - pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] - - more_deferreds = keyring.verify_events_for_server( - [ - ( - p.sender_domain, - p.pdu, - p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, + if not _is_invite_via_3pid(pdu): + try: + await keyring.verify_event_for_server( + get_domain_from_id(pdu.sender), + pdu, + pdu.origin_server_ts if room_version.enforce_key_validity else 0, ) - for p in pdus_to_check_sender - ] - ) - - def sender_err(e, pdu_to_check): - errmsg = "event id %s: unable to verify signature for sender %s: %s" % ( - pdu_to_check.pdu.event_id, - pdu_to_check.sender_domain, - e.getErrorMessage(), - ) - raise SynapseError(403, errmsg, Codes.FORBIDDEN) - - for p, d in zip(pdus_to_check_sender, more_deferreds): - d.addErrback(sender_err, p) - p.deferreds.append(d) + except Exception as e: + errmsg = "event id %s: unable to verify signature for sender %s: %s" % ( + pdu.event_id, + get_domain_from_id(pdu.sender), + e, + ) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) # now let's look for events where the sender's domain is different to the # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain # (ie, the room version uses old-style non-hash event IDs). - if room_version.event_format == EventFormatVersions.V1: - pdus_to_check_event_id = [ - p - for p in pdus_to_check - if p.sender_domain != get_domain_from_id(p.pdu.event_id) - ] - - more_deferreds = keyring.verify_events_for_server( - [ - ( - get_domain_from_id(p.pdu.event_id), - p.pdu, - p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - ) - for p in pdus_to_check_event_id - ] - ) - - def event_err(e, pdu_to_check): + if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id( + pdu.event_id + ) != get_domain_from_id(pdu.sender): + try: + await keyring.verify_event_for_server( + get_domain_from_id(pdu.event_id), + pdu, + pdu.origin_server_ts if room_version.enforce_key_validity else 0, + ) + except Exception as e: errmsg = ( - "event id %s: unable to verify signature for event id domain: %s" - % (pdu_to_check.pdu.event_id, e.getErrorMessage()) + "event id %s: unable to verify signature for event id domain %s: %s" + % ( + pdu.event_id, + get_domain_from_id(pdu.event_id), + e, + ) ) raise SynapseError(403, errmsg, Codes.FORBIDDEN) - for p, d in zip(pdus_to_check_event_id, more_deferreds): - d.addErrback(event_err, p) - p.deferreds.append(d) - - # replace lists of deferreds with single Deferreds - return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] - - -def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred: - """Given a list of deferreds, either return the single deferred, - combine into a DeferredList, or return an already resolved deferred. - """ - if len(deferreds) > 1: - return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True) - elif len(deferreds) == 1: - return deferreds[0] - else: - return defer.succeed(None) - def _is_invite_via_3pid(event: EventBase) -> bool: return ( diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index e0e9f5d0be..1076ebc036 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -21,6 +21,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -35,9 +36,6 @@ from typing import ( import attr from prometheus_client import Counter -from twisted.internet import defer -from twisted.internet.defer import Deferred - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( CodeMessageException, @@ -56,10 +54,9 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id -from synapse.util import unwrapFirstError +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -360,10 +357,9 @@ class FederationClient(FederationBase): async def _check_sigs_and_hash_and_fetch( self, origin: str, - pdus: List[EventBase], + pdus: Collection[EventBase], room_version: RoomVersion, outlier: bool = False, - include_none: bool = False, ) -> List[EventBase]: """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in @@ -380,57 +376,87 @@ class FederationClient(FederationBase): pdu room_version outlier: Whether the events are outliers or not - include_none: Whether to include None in the returned list - for events that have failed their checks Returns: A list of PDUs that have valid signatures and hashes. """ - deferreds = self._check_sigs_and_hashes(room_version, pdus) - async def handle_check_result(pdu: EventBase, deferred: Deferred): + # We limit how many PDUs we check at once, as if we try to do hundreds + # of thousands of PDUs at once we see large memory spikes. + + valid_pdus = [] + + async def _execute(pdu: EventBase) -> None: + valid_pdu = await self._check_sigs_and_hash_and_fetch_one( + pdu=pdu, + origin=origin, + outlier=outlier, + room_version=room_version, + ) + + if valid_pdu: + valid_pdus.append(valid_pdu) + + await concurrently_execute(_execute, pdus, 10000) + + return valid_pdus + + async def _check_sigs_and_hash_and_fetch_one( + self, + pdu: EventBase, + origin: str, + room_version: RoomVersion, + outlier: bool = False, + ) -> Optional[EventBase]: + """Takes a PDU and checks its signatures and hashes. If the PDU fails + its signature check then we check if we have it in the database and if + not then request if from the originating server of that PDU. + + If then PDU fails its content hash check then it is redacted. + + Args: + origin + pdu + room_version + outlier: Whether the events are outliers or not + include_none: Whether to include None in the returned list + for events that have failed their checks + + Returns: + The PDU (possibly redacted) if it has valid signatures and hashes. + """ + + res = None + try: + res = await self._check_sigs_and_hash(room_version, pdu) + except SynapseError: + pass + + if not res: + # Check local db. + res = await self.store.get_event( + pdu.event_id, allow_rejected=True, allow_none=True + ) + + pdu_origin = get_domain_from_id(pdu.sender) + if not res and pdu_origin != origin: try: - res = await make_deferred_yieldable(deferred) + res = await self.get_pdu( + destinations=[pdu_origin], + event_id=pdu.event_id, + room_version=room_version, + outlier=outlier, + timeout=10000, + ) except SynapseError: - res = None + pass - if not res: - # Check local db. - res = await self.store.get_event( - pdu.event_id, allow_rejected=True, allow_none=True - ) + if not res: + logger.warning( + "Failed to find copy of %s with valid signature", pdu.event_id + ) - pdu_origin = get_domain_from_id(pdu.sender) - if not res and pdu_origin != origin: - try: - res = await self.get_pdu( - destinations=[pdu_origin], - event_id=pdu.event_id, - room_version=room_version, - outlier=outlier, - timeout=10000, - ) - except SynapseError: - pass - - if not res: - logger.warning( - "Failed to find copy of %s with valid signature", pdu.event_id - ) - - return res - - handle = preserve_fn(handle_check_result) - deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] - - valid_pdus = await make_deferred_yieldable( - defer.gatherResults(deferreds2, consumeErrors=True) - ).addErrback(unwrapFirstError) - - if include_none: - return valid_pdus - else: - return [p for p in valid_pdus if p] + return res async def get_event_auth( self, destination: str, room_id: str, event_id: str @@ -671,8 +697,6 @@ class FederationClient(FederationBase): state = response.state auth_chain = response.auth_events - pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} - create_event = None for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): @@ -696,14 +720,29 @@ class FederationClient(FederationBase): % (create_room_version,) ) - valid_pdus = await self._check_sigs_and_hash_and_fetch( - destination, - list(pdus.values()), - outlier=True, - room_version=room_version, + logger.info( + "Processing from send_join %d events", len(state) + len(auth_chain) ) - valid_pdus_map = {p.event_id: p for p in valid_pdus} + # We now go and check the signatures and hashes for the event. Note + # that we limit how many events we process at a time to keep the + # memory overhead from exploding. + valid_pdus_map: Dict[str, EventBase] = {} + + async def _execute(pdu: EventBase) -> None: + valid_pdu = await self._check_sigs_and_hash_and_fetch_one( + pdu=pdu, + origin=destination, + outlier=True, + room_version=room_version, + ) + + if valid_pdu: + valid_pdus_map[valid_pdu.event_id] = valid_pdu + + await concurrently_execute( + _execute, itertools.chain(state, auth_chain), 10000 + ) # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 5c55bb0125..061102c3c8 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -15,6 +15,7 @@ import collections import inspect +import itertools import logging from contextlib import contextmanager from typing import ( @@ -160,8 +161,11 @@ class ObservableDeferred: ) +T = TypeVar("T") + + def concurrently_execute( - func: Callable, args: Iterable[Any], limit: int + func: Callable[[T], Any], args: Iterable[T], limit: int ) -> defer.Deferred: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -173,20 +177,27 @@ def concurrently_execute( limit: Maximum number of conccurent executions. Returns: - Deferred[list]: Resolved when all function invocations have finished. + Deferred: Resolved when all function invocations have finished. """ it = iter(args) - async def _concurrently_execute_inner(): + async def _concurrently_execute_inner(value: T) -> None: try: while True: - await maybe_awaitable(func(next(it))) + await maybe_awaitable(func(value)) + value = next(it) except StopIteration: pass + # We use `itertools.islice` to handle the case where the number of args is + # less than the limit, avoiding needlessly spawning unnecessary background + # tasks. return make_deferred_yieldable( defer.gatherResults( - [run_in_background(_concurrently_execute_inner) for _ in range(limit)], + [ + run_in_background(_concurrently_execute_inner, value) + for value in itertools.islice(it, limit) + ], consumeErrors=True, ) ).addErrback(unwrapFirstError)