When joining a remote room limit the number of events we concurrently check signatures/hashes for (#10117)
If we do hundreds of thousands at once, the memory overhead can easily reach 500+ MB.
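In essence, instead of kicking off a signature/hash check for every event at once and gathering all the resulting deferreds, the checks now run through a bounded pool of workers that pull events from a shared iterator, so only a limited number of verifications (and their associated key fetches and buffers) are in flight at any time. A minimal standalone sketch of that pattern in plain asyncio (illustrative names only; the real helper is the Twisted-based `concurrently_execute` updated in the diff below):

```python
import asyncio
import itertools
from typing import Awaitable, Callable, Iterable, TypeVar

T = TypeVar("T")


async def bounded_execute(
    func: Callable[[T], Awaitable[None]], args: Iterable[T], limit: int
) -> None:
    """Run `func` over `args` with at most `limit` calls in flight."""
    it = iter(args)

    async def worker(value: T) -> None:
        # Each worker processes its seed value, then keeps pulling further
        # arguments from the shared iterator until it is exhausted.
        while True:
            await func(value)
            try:
                value = next(it)
            except StopIteration:
                return

    # Seed at most `limit` workers, so memory scales with the limit rather
    # than with the total number of events.
    await asyncio.gather(*(worker(v) for v in itertools.islice(it, limit)))


async def main() -> None:
    async def check_signature(event_id: str) -> None:
        await asyncio.sleep(0)  # stand-in for the real signature/hash check
        print("checked", event_id)

    await bounded_execute(check_signature, (f"$event{i}" for i in range(5)), limit=2)


asyncio.run(main())
```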
parent a0101fc021
commit c842c581ed
@@ -0,0 +1 @@
Significantly reduce memory usage of joining large remote rooms.
@@ -233,41 +233,19 @@ class Keyring:
            for server_name, json_object, validity_time in server_and_json
        ]

    def verify_events_for_server(
        self, server_and_events: Iterable[Tuple[str, EventBase, int]]
    ) -> List[defer.Deferred]:
        """Bulk verification of signatures on events.

        Args:
            server_and_events:
                Iterable of `(server_name, event, validity_time)` tuples.

                `server_name` is which server we are verifying the signature for
                on the event.

                `event` is the event that we'll verify the signatures of for
                the given `server_name`.

                `validity_time` is a timestamp at which the signing key must be
                valid.

        Returns:
            List<Deferred[None]>: for each input triplet, a deferred indicating success
                or failure to verify each event's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        return [
            run_in_background(
                self.process_request,
                VerifyJsonRequest.from_event(
                    server_name,
                    event,
                    validity_time,
                ),
    async def verify_event_for_server(
        self,
        server_name: str,
        event: EventBase,
        validity_time: int,
    ) -> None:
        await self.process_request(
            VerifyJsonRequest.from_event(
                server_name,
                event,
                validity_time,
            )
            for server_name, event, validity_time in server_and_events
        ]
        )

    async def process_request(self, verify_request: VerifyJsonRequest) -> None:
        """Processes the `VerifyJsonRequest`. Raises if the object is not signed
@@ -14,11 +14,6 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import Iterable, List

from twisted.internet import defer
from twisted.internet.defer import Deferred, DeferredList
from twisted.python.failure import Failure

from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -28,11 +23,6 @@ from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event, validate_canonicaljson
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
    PreserveLoggingContext,
    current_context,
    make_deferred_yieldable,
)
from synapse.types import JsonDict, get_domain_from_id

logger = logging.getLogger(__name__)
@@ -48,112 +38,82 @@ class FederationBase:
        self.store = hs.get_datastore()
        self._clock = hs.get_clock()

    def _check_sigs_and_hash(
    async def _check_sigs_and_hash(
        self, room_version: RoomVersion, pdu: EventBase
    ) -> Deferred:
        return make_deferred_yieldable(
            self._check_sigs_and_hashes(room_version, [pdu])[0]
        )

    def _check_sigs_and_hashes(
        self, room_version: RoomVersion, pdus: List[EventBase]
    ) -> List[Deferred]:
        """Checks that each of the received events is correctly signed by the
        sending server.
    ) -> EventBase:
        """Checks that event is correctly signed by the sending server.

        Args:
            room_version: The room version of the PDUs
            pdus: the events to be checked
            room_version: The room version of the PDU
            pdu: the event to be checked

        Returns:
            For each input event, a deferred which:
              * returns the original event if the checks pass
              * returns a redacted version of the event (if the signature
              * the original event if the checks pass
              * a redacted version of the event (if the signature
                matched but the hash did not)
              * throws a SynapseError if the signature check failed.
            The deferreds run their callbacks in the sentinel
        """
        deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)

        ctx = current_context()

        @defer.inlineCallbacks
        def callback(_, pdu: EventBase):
            with PreserveLoggingContext(ctx):
                if not check_event_content_hash(pdu):
                    # let's try to distinguish between failures because the event was
                    # redacted (which are somewhat expected) vs actual ball-tampering
                    # incidents.
                    #
                    # This is just a heuristic, so we just assume that if the keys are
                    # about the same between the redacted and received events, then the
                    # received event was probably a redacted copy (but we then use our
                    # *actual* redacted copy to be on the safe side.)
                    redacted_event = prune_event(pdu)
                    if set(redacted_event.keys()) == set(pdu.keys()) and set(
                        redacted_event.content.keys()
                    ) == set(pdu.content.keys()):
                        logger.info(
                            "Event %s seems to have been redacted; using our redacted "
                            "copy",
                            pdu.event_id,
                        )
                    else:
                        logger.warning(
                            "Event %s content has been tampered, redacting",
                            pdu.event_id,
                        )
                    return redacted_event

                result = yield defer.ensureDeferred(
                    self.spam_checker.check_event_for_spam(pdu)
                )

                if result:
                    logger.warning(
                        "Event contains spam, redacting %s: %s",
                        pdu.event_id,
                        pdu.get_pdu_json(),
                    )
                    return prune_event(pdu)

                return pdu

        def errback(failure: Failure, pdu: EventBase):
            failure.trap(SynapseError)
            with PreserveLoggingContext(ctx):
                logger.warning(
                    "Signature check failed for %s: %s",
                    pdu.event_id,
                    failure.getErrorMessage(),
                )
            return failure

        for deferred, pdu in zip(deferreds, pdus):
            deferred.addCallbacks(
                callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
              * throws a SynapseError if the signature check failed."""
        try:
            await _check_sigs_on_pdu(self.keyring, room_version, pdu)
        except SynapseError as e:
            logger.warning(
                "Signature check failed for %s: %s",
                pdu.event_id,
                e,
            )
            raise

        return deferreds
        if not check_event_content_hash(pdu):
            # let's try to distinguish between failures because the event was
            # redacted (which are somewhat expected) vs actual ball-tampering
            # incidents.
            #
            # This is just a heuristic, so we just assume that if the keys are
            # about the same between the redacted and received events, then the
            # received event was probably a redacted copy (but we then use our
            # *actual* redacted copy to be on the safe side.)
            redacted_event = prune_event(pdu)
            if set(redacted_event.keys()) == set(pdu.keys()) and set(
                redacted_event.content.keys()
            ) == set(pdu.content.keys()):
                logger.info(
                    "Event %s seems to have been redacted; using our redacted copy",
                    pdu.event_id,
                )
            else:
                logger.warning(
                    "Event %s content has been tampered, redacting",
                    pdu.event_id,
                )
            return redacted_event

        result = await self.spam_checker.check_event_for_spam(pdu)

        if result:
            logger.warning(
                "Event contains spam, redacting %s: %s",
                pdu.event_id,
                pdu.get_pdu_json(),
            )
            return prune_event(pdu)

        return pdu


class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
    pass


def _check_sigs_on_pdus(
    keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
) -> List[Deferred]:
async def _check_sigs_on_pdu(
    keyring: Keyring, room_version: RoomVersion, pdu: EventBase
) -> None:
    """Check that the given events are correctly signed

    Raise a SynapseError if the event wasn't correctly signed.

    Args:
        keyring: keyring object to do the checks
        room_version: the room version of the PDUs
        pdus: the events to be checked

    Returns:
        A Deferred for each event in pdus, which will either succeed if
           the signatures are valid, or fail (with a SynapseError) if not.
    """

    # we want to check that the event is signed by:
@@ -177,90 +137,47 @@ def _check_sigs_on_pdus(
    # let's start by getting the domain for each pdu, and flattening the event back
    # to JSON.

    pdus_to_check = [
        PduToCheckSig(
            pdu=p,
            sender_domain=get_domain_from_id(p.sender),
            deferreds=[],
        )
        for p in pdus
    ]

    # First we check that the sender event is signed by the sender's domain
    # (except if its a 3pid invite, in which case it may be sent by any server)
    pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]

    more_deferreds = keyring.verify_events_for_server(
        [
            (
                p.sender_domain,
                p.pdu,
                p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
    if not _is_invite_via_3pid(pdu):
        try:
            await keyring.verify_event_for_server(
                get_domain_from_id(pdu.sender),
                pdu,
                pdu.origin_server_ts if room_version.enforce_key_validity else 0,
            )
            for p in pdus_to_check_sender
        ]
    )

    def sender_err(e, pdu_to_check):
        errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
            pdu_to_check.pdu.event_id,
            pdu_to_check.sender_domain,
            e.getErrorMessage(),
        )
        raise SynapseError(403, errmsg, Codes.FORBIDDEN)

    for p, d in zip(pdus_to_check_sender, more_deferreds):
        d.addErrback(sender_err, p)
        p.deferreds.append(d)
        except Exception as e:
            errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
                pdu.event_id,
                get_domain_from_id(pdu.sender),
                e,
            )
            raise SynapseError(403, errmsg, Codes.FORBIDDEN)

    # now let's look for events where the sender's domain is different to the
    # event id's domain (normally only the case for joins/leaves), and add additional
    # checks. Only do this if the room version has a concept of event ID domain
    # (ie, the room version uses old-style non-hash event IDs).
    if room_version.event_format == EventFormatVersions.V1:
        pdus_to_check_event_id = [
            p
            for p in pdus_to_check
            if p.sender_domain != get_domain_from_id(p.pdu.event_id)
        ]

        more_deferreds = keyring.verify_events_for_server(
            [
                (
                    get_domain_from_id(p.pdu.event_id),
                    p.pdu,
                    p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                )
                for p in pdus_to_check_event_id
            ]
        )

        def event_err(e, pdu_to_check):
    if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id(
        pdu.event_id
    ) != get_domain_from_id(pdu.sender):
        try:
            await keyring.verify_event_for_server(
                get_domain_from_id(pdu.event_id),
                pdu,
                pdu.origin_server_ts if room_version.enforce_key_validity else 0,
            )
        except Exception as e:
            errmsg = (
                "event id %s: unable to verify signature for event id domain: %s"
                % (pdu_to_check.pdu.event_id, e.getErrorMessage())
                "event id %s: unable to verify signature for event id domain %s: %s"
                % (
                    pdu.event_id,
                    get_domain_from_id(pdu.event_id),
                    e,
                )
            )
            raise SynapseError(403, errmsg, Codes.FORBIDDEN)

        for p, d in zip(pdus_to_check_event_id, more_deferreds):
            d.addErrback(event_err, p)
            p.deferreds.append(d)

    # replace lists of deferreds with single Deferreds
    return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]


def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
    """Given a list of deferreds, either return the single deferred,
    combine into a DeferredList, or return an already resolved deferred.
    """
    if len(deferreds) > 1:
        return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
    elif len(deferreds) == 1:
        return deferreds[0]
    else:
        return defer.succeed(None)


def _is_invite_via_3pid(event: EventBase) -> bool:
    return (
@@ -21,6 +21,7 @@ from typing import (
    Any,
    Awaitable,
    Callable,
    Collection,
    Dict,
    Iterable,
    List,
@@ -35,9 +36,6 @@ from typing import (
import attr
from prometheus_client import Counter

from twisted.internet import defer
from twisted.internet.defer import Deferred

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
    CodeMessageException,
@@ -56,10 +54,9 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination

@@ -360,10 +357,9 @@ class FederationClient(FederationBase):
    async def _check_sigs_and_hash_and_fetch(
        self,
        origin: str,
        pdus: List[EventBase],
        pdus: Collection[EventBase],
        room_version: RoomVersion,
        outlier: bool = False,
        include_none: bool = False,
    ) -> List[EventBase]:
        """Takes a list of PDUs and checks the signatures and hashes of each
        one. If a PDU fails its signature check then we check if we have it in
@@ -380,57 +376,87 @@ class FederationClient(FederationBase):
            pdu
            room_version
            outlier: Whether the events are outliers or not
            include_none: Whether to include None in the returned list
                for events that have failed their checks

        Returns:
            A list of PDUs that have valid signatures and hashes.
        """
        deferreds = self._check_sigs_and_hashes(room_version, pdus)

        async def handle_check_result(pdu: EventBase, deferred: Deferred):
        # We limit how many PDUs we check at once, as if we try to do hundreds
        # of thousands of PDUs at once we see large memory spikes.

        valid_pdus = []

        async def _execute(pdu: EventBase) -> None:
            valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                pdu=pdu,
                origin=origin,
                outlier=outlier,
                room_version=room_version,
            )

            if valid_pdu:
                valid_pdus.append(valid_pdu)

        await concurrently_execute(_execute, pdus, 10000)

        return valid_pdus

    async def _check_sigs_and_hash_and_fetch_one(
        self,
        pdu: EventBase,
        origin: str,
        room_version: RoomVersion,
        outlier: bool = False,
    ) -> Optional[EventBase]:
        """Takes a PDU and checks its signatures and hashes. If the PDU fails
        its signature check then we check if we have it in the database and if
        not then request if from the originating server of that PDU.

        If then PDU fails its content hash check then it is redacted.

        Args:
            origin
            pdu
            room_version
            outlier: Whether the events are outliers or not
            include_none: Whether to include None in the returned list
                for events that have failed their checks

        Returns:
            The PDU (possibly redacted) if it has valid signatures and hashes.
        """

        res = None
        try:
            res = await self._check_sigs_and_hash(room_version, pdu)
        except SynapseError:
            pass

        if not res:
            # Check local db.
            res = await self.store.get_event(
                pdu.event_id, allow_rejected=True, allow_none=True
            )

        pdu_origin = get_domain_from_id(pdu.sender)
        if not res and pdu_origin != origin:
            try:
                res = await make_deferred_yieldable(deferred)
                res = await self.get_pdu(
                    destinations=[pdu_origin],
                    event_id=pdu.event_id,
                    room_version=room_version,
                    outlier=outlier,
                    timeout=10000,
                )
            except SynapseError:
                res = None
                pass

            if not res:
                # Check local db.
                res = await self.store.get_event(
                    pdu.event_id, allow_rejected=True, allow_none=True
                )
        if not res:
            logger.warning(
                "Failed to find copy of %s with valid signature", pdu.event_id
            )

            pdu_origin = get_domain_from_id(pdu.sender)
            if not res and pdu_origin != origin:
                try:
                    res = await self.get_pdu(
                        destinations=[pdu_origin],
                        event_id=pdu.event_id,
                        room_version=room_version,
                        outlier=outlier,
                        timeout=10000,
                    )
                except SynapseError:
                    pass

            if not res:
                logger.warning(
                    "Failed to find copy of %s with valid signature", pdu.event_id
                )

            return res

        handle = preserve_fn(handle_check_result)
        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]

        valid_pdus = await make_deferred_yieldable(
            defer.gatherResults(deferreds2, consumeErrors=True)
        ).addErrback(unwrapFirstError)

        if include_none:
            return valid_pdus
        else:
            return [p for p in valid_pdus if p]
        return res

    async def get_event_auth(
        self, destination: str, room_id: str, event_id: str
@@ -671,8 +697,6 @@ class FederationClient(FederationBase):
            state = response.state
            auth_chain = response.auth_events

            pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}

            create_event = None
            for e in state:
                if (e.type, e.state_key) == (EventTypes.Create, ""):
@@ -696,14 +720,29 @@ class FederationClient(FederationBase):
                    % (create_room_version,)
                )

            valid_pdus = await self._check_sigs_and_hash_and_fetch(
                destination,
                list(pdus.values()),
                outlier=True,
                room_version=room_version,
            logger.info(
                "Processing from send_join %d events", len(state) + len(auth_chain)
            )

            valid_pdus_map = {p.event_id: p for p in valid_pdus}
            # We now go and check the signatures and hashes for the event. Note
            # that we limit how many events we process at a time to keep the
            # memory overhead from exploding.
            valid_pdus_map: Dict[str, EventBase] = {}

            async def _execute(pdu: EventBase) -> None:
                valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                    pdu=pdu,
                    origin=destination,
                    outlier=True,
                    room_version=room_version,
                )

                if valid_pdu:
                    valid_pdus_map[valid_pdu.event_id] = valid_pdu

            await concurrently_execute(
                _execute, itertools.chain(state, auth_chain), 10000
            )

            # NB: We *need* to copy to ensure that we don't have multiple
            # references being passed on, as that causes... issues.
@@ -15,6 +15,7 @@

import collections
import inspect
import itertools
import logging
from contextlib import contextmanager
from typing import (
@@ -160,8 +161,11 @@ class ObservableDeferred:
        )


T = TypeVar("T")


def concurrently_execute(
    func: Callable, args: Iterable[Any], limit: int
    func: Callable[[T], Any], args: Iterable[T], limit: int
) -> defer.Deferred:
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.
@@ -173,20 +177,27 @@ def concurrently_execute(
        limit: Maximum number of conccurent executions.

    Returns:
        Deferred[list]: Resolved when all function invocations have finished.
        Deferred: Resolved when all function invocations have finished.
    """
    it = iter(args)

    async def _concurrently_execute_inner():
    async def _concurrently_execute_inner(value: T) -> None:
        try:
            while True:
                await maybe_awaitable(func(next(it)))
                await maybe_awaitable(func(value))
                value = next(it)
        except StopIteration:
            pass

    # We use `itertools.islice` to handle the case where the number of args is
    # less than the limit, avoiding needlessly spawning unnecessary background
    # tasks.
    return make_deferred_yieldable(
        defer.gatherResults(
            [run_in_background(_concurrently_execute_inner) for _ in range(limit)],
            [
                run_in_background(_concurrently_execute_inner, value)
                for value in itertools.islice(it, limit)
            ],
            consumeErrors=True,
        )
    ).addErrback(unwrapFirstError)
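A side note on the `itertools.islice` seeding in the rewritten helper above: the previous version always spawned `limit` background workers, each of which immediately hit `StopIteration` when there were fewer arguments than the limit, whereas seeding from `islice(it, limit)` starts no more workers than there are arguments. A quick illustration of the slicing behaviour (plain Python, illustrative values):

```python
import itertools

it = iter(["a", "b", "c"])               # only three arguments
seeds = list(itertools.islice(it, 10000))
print(seeds)     # ['a', 'b', 'c'] -> only three workers would be spawned
print(list(it))  # []              -> nothing left for those workers to pull
```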