Faster joins: Support for calling `/federation/v1/state` (#12013)

This is an endpoint that we have server-side support for, but no client-side support. It's going to be useful for resyncing partial-stated rooms, so let's introduce it.
pull/12058/head
Richard van der Hoff 2022-02-22 12:17:10 +00:00 committed by GitHub
parent 066171643b
commit 7273011f60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 377 additions and 17 deletions

1
changelog.d/12013.misc Normal file
View File

@ -0,0 +1 @@
Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.

View File

@ -47,6 +47,11 @@ class FederationBase:
) -> EventBase:
"""Checks that event is correctly signed by the sending server.
Also checks the content hash, and redacts the event if there is a mismatch.
Also runs the event through the spam checker; if it fails, redacts the event
and flags it as soft-failed.
Args:
room_version: The room version of the PDU
pdu: the event to be checked
@ -55,7 +60,10 @@ class FederationBase:
* the original event if the checks pass
* a redacted version of the event (if the signature
matched but the hash did not)
* throws a SynapseError if the signature check failed."""
Raises:
SynapseError if the signature check failed.
"""
try:
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
except SynapseError as e:

View File

@ -419,26 +419,90 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
async def get_room_state(
self,
destination: str,
room_id: str,
event_id: str,
room_version: RoomVersion,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Calls the /state endpoint to fetch the state at a particular point
in the room.
Any invalid events (those with incorrect or unverifiable signatures or hashes)
are filtered out from the response, and any duplicate events are removed.
(Size limits and other event-format checks are *not* performed.)
Note that the result is not ordered, so callers must be careful to process
the events in an order that handles dependencies.
Returns:
a tuple of (state events, auth events)
"""
result = await self.transport_layer.get_room_state(
room_version,
destination,
room_id,
event_id,
)
state_events = result.state
auth_events = result.auth_events
# we may as well filter out any duplicates from the response, to save
# processing them multiple times. (In particular, events may be present in
# `auth_events` as well as `state`, which is redundant).
#
# We don't rely on the sort order of the events, so we can just stick them
# in a dict.
state_event_map = {event.event_id: event for event in state_events}
auth_event_map = {
event.event_id: event
for event in auth_events
if event.event_id not in state_event_map
}
logger.info(
"Processing from /state: %d state events, %d auth events",
len(state_event_map),
len(auth_event_map),
)
valid_auth_events = await self._check_sigs_and_hash_and_fetch(
destination, auth_event_map.values(), room_version
)
valid_state_events = await self._check_sigs_and_hash_and_fetch(
destination, state_event_map.values(), room_version
)
return valid_state_events, valid_auth_events
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: Collection[EventBase],
room_version: RoomVersion,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
"""Checks the signatures and hashes of a list of events.
If a PDU fails its signature check then we check if we have it in
the database, and if not then request it from the sender's server (if that
is different from `origin`). If that still fails, the event is omitted from
the returned list.
If a PDU fails its content hash check then it is redacted.
The given list of PDUs are not modified, instead the function returns
Also runs each event through the spam checker; if it fails, redacts the event
and flags it as soft-failed.
The given list of PDUs are not modified; instead the function returns
a new list.
Args:
origin
pdu
room_version
origin: The server that sent us these events
pdus: The events to be checked
room_version: the version of the room these events are in
Returns:
A list of PDUs that have valid signatures and hashes.
@ -469,11 +533,16 @@ class FederationClient(FederationBase):
origin: str,
room_version: RoomVersion,
) -> Optional[EventBase]:
"""Takes a PDU and checks its signatures and hashes. If the PDU fails
its signature check then we check if we have it in the database and if
not then request if from the originating server of that PDU.
"""Takes a PDU and checks its signatures and hashes.
If then PDU fails its content hash check then it is redacted.
If the PDU fails its signature check then we check if we have it in the
database; if not, we then request it from sender's server (if that is not the
same as `origin`). If that still fails, we return None.
If the PDU fails its content hash check, it is redacted.
Also runs the event through the spam checker; if it fails, redacts the event
and flags it as soft-failed.
Args:
origin

View File

@ -65,13 +65,12 @@ class TransportLayerClient:
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
"""Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
"""Requests the IDs of all state for a given room at the given event.
Args:
destination: The host name of the remote homeserver we want
to get the state from.
context: The name of the context we want the state of
room_id: the room we want the state of
event_id: The event we want the context at.
Returns:
@ -87,6 +86,29 @@ class TransportLayerClient:
try_trailing_slash_on_400=True,
)
async def get_room_state(
self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
) -> "StateRequestResponse":
"""Requests the full state for a given room at the given event.
Args:
room_version: the version of the room (required to build the event objects)
destination: The host name of the remote homeserver we want
to get the state from.
room_id: the room we want the state of
event_id: The event we want the context at.
Returns:
Results in a dict received from the remote homeserver.
"""
path = _create_v1_path("/state/%s", room_id)
return await self.client.get_json(
destination,
path=path,
args={"event_id": event_id},
parser=_StateParser(room_version),
)
async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
@ -1284,6 +1306,14 @@ class SendJoinResponse:
servers_in_room: Optional[List[str]] = None
@attr.s(slots=True, auto_attribs=True)
class StateRequestResponse:
"""The parsed response of a `/state` request."""
auth_events: List[EventBase]
state: List[EventBase]
@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
@ -1411,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
self._response.event_dict, self._room_version
)
return self._response
class _StateParser(ByteParser[StateRequestResponse]):
"""A parser for the response to `/state` requests.
Args:
room_version: The version of the room.
"""
CONTENT_TYPE = "application/json"
def __init__(self, room_version: RoomVersion):
self._response = StateRequestResponse([], [])
self._room_version = room_version
self._coros = [
ijson.items_coro(
_event_list_parser(room_version, self._response.state),
"pdus.item",
use_float=True,
),
ijson.items_coro(
_event_list_parser(room_version, self._response.auth_events),
"auth_chain.item",
use_float=True,
),
]
def write(self, data: bytes) -> int:
for c in self._coros:
c.send(data)
return len(data)
def finish(self) -> StateRequestResponse:
return self._response

View File

@ -958,6 +958,7 @@ class MatrixFederationHttpClient:
)
return body
@overload
async def get_json(
self,
destination: str,
@ -967,7 +968,38 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]:
...
@overload
async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = ...,
retry_on_dns_fail: bool = ...,
timeout: Optional[int] = ...,
ignore_backoff: bool = ...,
try_trailing_slash_on_400: bool = ...,
parser: ByteParser[T] = ...,
max_response_size: Optional[int] = ...,
) -> T:
...
async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
max_response_size: Optional[int] = None,
):
"""GETs some json from the given host homeserver and path
Args:
@ -992,6 +1024,13 @@ class MatrixFederationHttpClient:
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
max_response_size: The maximum size to read from the response. If None,
uses the default.
Returns:
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
if parser is None:
parser = JsonParser()
body = await _handle_response(
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
self.reactor,
_sec_timeout,
request,
response,
start_ms,
parser=parser,
max_response_size=max_response_size,
)
return body

View File

@ -0,0 +1,149 @@
# Copyright 2022 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest import mock
import twisted.web.client
from twisted.internet import defer
from twisted.internet.protocol import Protocol
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import RoomVersions
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import FederatingHomeserverTestCase
class FederationClientTest(FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
super().prepare(reactor, clock, homeserver)
# mock out the Agent used by the federation client, which is easier than
# catching the HTTPS connection and do the TLS stuff.
self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True)
homeserver.get_federation_http_client().agent = self._mock_agent
def test_get_room_state(self):
creator = f"@creator:{self.OTHER_SERVER_NAME}"
test_room_id = "!room_id"
# mock up some events to use in the response.
# In real life, these would have things in `prev_events` and `auth_events`, but that's
# a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
create_event_dict = self.add_hashes_and_signatures(
{
"room_id": test_room_id,
"type": "m.room.create",
"state_key": "",
"sender": creator,
"content": {"creator": creator},
"prev_events": [],
"auth_events": [],
"origin_server_ts": 500,
}
)
member_event_dict = self.add_hashes_and_signatures(
{
"room_id": test_room_id,
"type": "m.room.member",
"sender": creator,
"state_key": creator,
"content": {"membership": "join"},
"prev_events": [],
"auth_events": [],
"origin_server_ts": 600,
}
)
pl_event_dict = self.add_hashes_and_signatures(
{
"room_id": test_room_id,
"type": "m.room.power_levels",
"sender": creator,
"state_key": "",
"content": {},
"prev_events": [],
"auth_events": [],
"origin_server_ts": 700,
}
)
# mock up the response, and have the agent return it
self._mock_agent.request.return_value = defer.succeed(
_mock_response(
{
"pdus": [
create_event_dict,
member_event_dict,
pl_event_dict,
],
"auth_chain": [
create_event_dict,
member_event_dict,
],
}
)
)
# now fire off the request
state_resp, auth_resp = self.get_success(
self.hs.get_federation_client().get_room_state(
"yet_another_server",
test_room_id,
"event_id",
RoomVersions.V9,
)
)
# check the right call got made to the agent
self._mock_agent.request.assert_called_once_with(
b"GET",
b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
headers=mock.ANY,
bodyProducer=None,
)
# ... and that the response is correct.
# the auth_resp should be empty because all the events are also in state
self.assertEqual(auth_resp, [])
# all of the events should be returned in state_resp, though not necessarily
# in the same order. We just check the type on the assumption that if the type
# is right, so is the rest of the event.
self.assertCountEqual(
[e.type for e in state_resp],
["m.room.create", "m.room.member", "m.room.power_levels"],
)
def _mock_response(resp: JsonDict):
body = json.dumps(resp).encode("utf-8")
def deliver_body(p: Protocol):
p.dataReceived(body)
p.connectionLost(Failure(twisted.web.client.ResponseDone()))
response = mock.Mock(
code=200,
phrase=b"OK",
headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
length=len(body),
deliverBody=deliver_body,
)
mock.seal(response)
return response

View File

@ -51,7 +51,10 @@ from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
@ -839,6 +842,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
client_ip=client_ip,
)
def add_hashes_and_signatures(
self,
event_dict: JsonDict,
room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
) -> JsonDict:
"""Adds hashes and signatures to the given event dict
Returns:
The modified event dict, for convenience
"""
add_hashes_and_signatures(
room_version,
event_dict,
signature_name=self.OTHER_SERVER_NAME,
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
)
return event_dict
def _auth_header_for_request(
origin: str,