Add `Keyring.verify_events_for_server` and reduce memory usage (#10018)

Also add support for giving a callback to generate the JSON object to
verify. This should reduce memory usage, as we no longer have the event
in memory in dict form (which has a large memory footprint) for extend
periods of time.
pull/10029/head
Erik Johnston 2021-05-20 16:25:11 +01:00 committed by GitHub
parent 64887f06fc
commit 1c6a19002c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 22 deletions

1
changelog.d/10018.misc Normal file
View File

@ -0,0 +1 @@
Reduce memory usage when verifying signatures on large numbers of events at once.

View File

@ -17,7 +17,7 @@ import abc
import logging
import urllib
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
import attr
from signedjson.key import (
@ -42,6 +42,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.key import TrustedKeyServer
from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
@ -69,7 +71,11 @@ class VerifyJsonRequest:
Attributes:
server_name: The name of the server to verify against.
json_object: The JSON object to verify.
get_json_object: A callback to fetch the JSON object to verify.
A callback is used to allow deferring the creation of the JSON
object to verify until needed, e.g. for events we can defer
creating the redacted copy. This reduces the memory usage when
there are large numbers of in flight requests.
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
@ -88,14 +94,50 @@ class VerifyJsonRequest:
"""
server_name = attr.ib(type=str)
json_object = attr.ib(type=JsonDict)
get_json_object = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib(type=str)
key_ids = attr.ib(init=False, type=List[str])
key_ids = attr.ib(type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name)
@staticmethod
def from_json_object(
server_name: str,
json_object: JsonDict,
minimum_valid_until_ms: int,
request_name: str,
):
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
object for the given server.
"""
key_ids = signature_ids(json_object, server_name)
return VerifyJsonRequest(
server_name,
lambda: json_object,
minimum_valid_until_ms,
request_name=request_name,
key_ids=key_ids,
)
@staticmethod
def from_event(
server_name: str,
event: EventBase,
minimum_valid_until_ms: int,
):
"""Create a VerifyJsonRequest to verify all signatures on an event
object for the given server.
"""
key_ids = list(event.signatures.get(server_name, []))
return VerifyJsonRequest(
server_name,
# We defer creating the redacted json object, as it uses a lot more
# memory than the Event object itself.
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
minimum_valid_until_ms,
request_name=event.event_id,
key_ids=key_ids,
)
class KeyLookupError(ValueError):
@ -147,8 +189,13 @@ class Keyring:
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
requests = (req,)
request = VerifyJsonRequest.from_json_object(
server_name,
json_object,
validity_time,
request_name,
)
requests = (request,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
def verify_json_objects_for_server(
@ -175,10 +222,41 @@ class Keyring:
logcontext.
"""
return self._verify_objects(
VerifyJsonRequest(server_name, json_object, validity_time, request_name)
VerifyJsonRequest.from_json_object(
server_name, json_object, validity_time, request_name
)
for server_name, json_object, validity_time, request_name in server_and_json
)
def verify_events_for_server(
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
) -> List[defer.Deferred]:
"""Bulk verification of signatures on events.
Args:
server_and_events:
Iterable of `(server_name, event, validity_time)` tuples.
`server_name` is which server we are verifying the signature for
on the event.
`event` is the event that we'll verify the signatures of for
the given `server_name`.
`validity_time` is a timestamp at which the signing key must be
valid.
Returns:
List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each event's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
return self._verify_objects(
VerifyJsonRequest.from_event(server_name, event, validity_time)
for server_name, event, validity_time in server_and_events
)
def _verify_objects(
self, verify_requests: Iterable[VerifyJsonRequest]
) -> List[defer.Deferred]:
@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
with PreserveLoggingContext():
_, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.json_object
json_object = verify_request.get_json_object()
try:
verify_signed_json(json_object, server_name, verify_key)

View File

@ -137,11 +137,7 @@ class FederationBase:
return deferreds
class PduToCheckSig(
namedtuple(
"PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
)
):
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
more_deferreds = keyring.verify_json_objects_for_server(
more_deferreds = keyring.verify_events_for_server(
[
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
]
@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
more_deferreds = keyring.verify_json_objects_for_server(
more_deferreds = keyring.verify_events_for_server(
[
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id
]