Implement MSC3706: partial state in `/send_join` response (#11967)

* Make `get_auth_chain_ids` return a Set

It has a set internally, and a set is often useful where it gets used, so let's
avoid converting to an intermediate list.

* Minor refactors in `on_send_join_request`

A little bit of non-functional groundwork

* Implement MSC3706: partial state in /send_join response
pull/11969/head
Richard van der Hoff 2022-02-12 10:44:16 +00:00 committed by GitHub
parent b2b971f28a
commit 63c46349c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 262 additions and 21 deletions

View File

@ -0,0 +1 @@
Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.

View File

@ -61,3 +61,6 @@ class ExperimentalConfig(Config):
self.msc2409_to_device_messages_enabled: bool = experimental.get( self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False "msc2409_to_device_messages_enabled", False
) )
# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

View File

@ -20,6 +20,7 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Collection,
Dict, Dict,
Iterable, Iterable,
List, List,
@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet, ReplicationGetQueryRestServlet,
) )
from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_decoder, unwrapFirstError from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -571,7 +572,7 @@ class FederationServer(FederationBase):
) -> JsonDict: ) -> JsonDict:
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}
async def _on_context_state_request_compute( async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str] self, room_id: str, event_id: Optional[str]
@ -645,27 +646,61 @@ class FederationServer(FederationBase):
return {"event": ret_pdu.get_pdu_json(time_now)} return {"event": ret_pdu.get_pdu_json(time_now)}
async def on_send_join_request( async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str self,
origin: str,
content: JsonDict,
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
event, context = await self._on_send_membership_event( event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id origin, content, Membership.JOIN, room_id
) )
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
auth_chain = await self.store.get_auth_chain(room_id, state_ids)
state = await self.store.get_events(state_ids)
state_event_ids: Collection[str]
servers_in_room: Optional[Collection[str]]
if caller_supports_partial_state:
state_event_ids = _get_event_ids_for_partial_state_join(
event, prev_state_ids
)
servers_in_room = await self.state.get_hosts_in_room_at_events(
room_id, event_ids=event.prev_event_ids()
)
else:
state_event_ids = prev_state_ids.values()
servers_in_room = None
auth_chain_event_ids = await self.store.get_auth_chain_ids(
room_id, state_event_ids
)
# if the caller has opted in, we can omit any auth_chain events which are
# already in state_event_ids
if caller_supports_partial_state:
auth_chain_event_ids.difference_update(state_event_ids)
auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
state_events = await self.store.get_events_as_list(state_event_ids)
# we try to do all the async stuff before this point, so that time_now is as
# accurate as possible.
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
event_json = event.get_pdu_json() event_json = event.get_pdu_json(time_now)
return { resp = {
# TODO Remove the unstable prefix when servers have updated. # TODO Remove the unstable prefix when servers have updated.
"org.matrix.msc3083.v2.event": event_json, "org.matrix.msc3083.v2.event": event_json,
"event": event_json, "event": event_json,
"state": [p.get_pdu_json(time_now) for p in state.values()], "state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
} }
if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
return resp
async def on_make_leave_request( async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str self, origin: str, room_id: str, user_id: str
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -1339,3 +1374,39 @@ class FederationHandlerRegistry:
# error. # error.
logger.warning("No handler registered for query type %s", query_type) logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,)) raise NotFoundError("No handler for Query type '%s'" % (query_type,))
def _get_event_ids_for_partial_state_join(
    join_event: EventBase,
    prev_state_ids: StateMap[str],
) -> Collection[str]:
    """Work out which state events to return for a partial-state send_join.

    The result contains the ID of every non-membership state event in
    `prev_state_ids`, plus the joining user's own previous membership event
    when there is one.

    Args:
        join_event: the join event which is being send_joined
        prev_state_ids: map from (event type, state key) to event ID, giving
            the room state before the join

    Returns:
        the IDs of the events to be returned
    """
    member_type = EventTypes.Member

    # everything which is not a membership event gets returned
    wanted = set()
    for (ev_type, _state_key), ev_id in prev_state_ids.items():
        if ev_type == member_type:
            continue
        wanted.add(ev_id)

    # the joining user's existing membership is going to be an auth event for
    # the new join, so we may as well hand that back too
    own_membership_id = prev_state_ids.get((member_type, join_event.state_key))
    if own_membership_id is not None:
        wanted.add(own_membership_id)

    # TODO: return a few more members:
    #   - those with invites
    #   - those that are kicked? / banned

    return wanted

View File

@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX PREFIX = FEDERATION_V2_PREFIX
def __init__(
    self,
    hs: "HomeServer",
    authenticator: Authenticator,
    ratelimiter: FederationRateLimiter,
    server_name: str,
):
    # All constructor arguments are forwarded unchanged to the base servlet.
    super().__init__(hs, authenticator, ratelimiter, server_name)
    # Read the experimental MSC3706 flag from the homeserver config once, at
    # construction time, and keep it on the instance.
    self._msc3706_enabled = hs.config.experimental.msc3706_enabled
async def on_PUT( async def on_PUT(
self, self,
origin: str, origin: str,
@ -422,7 +432,15 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
# TODO(paul): assert that event_id parsed from path actually # TODO(paul): assert that event_id parsed from path actually
# match those given in content # match those given in content
result = await self.handler.on_send_join_request(origin, content, room_id)
partial_state = False
if self._msc3706_enabled:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state
)
return 200, result return 200, result

View File

@ -121,7 +121,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id: str, room_id: str,
event_ids: Collection[str], event_ids: Collection[str],
include_given: bool = False, include_given: bool = False,
) -> List[str]: ) -> Set[str]:
"""Get auth events for given event_ids. The events *must* be state events. """Get auth events for given event_ids. The events *must* be state events.
Args: Args:
@ -130,7 +130,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
include_given: include the given events in result include_given: include the given events in result
Returns: Returns:
list of event_ids set of event_ids
""" """
# Check if we have indexed the room so we can use the chain cover # Check if we have indexed the room so we can use the chain cover
@ -159,7 +159,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
def _get_auth_chain_ids_using_cover_index_txn( def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]: ) -> Set[str]:
"""Calculates the auth chain IDs using the chain index.""" """Calculates the auth chain IDs using the chain index."""
# First we look up the chain ID/sequence numbers for the given events. # First we look up the chain ID/sequence numbers for the given events.
@ -272,11 +272,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (chain_id, max_no)) txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn) results.update(r for r, in txn)
return list(results) return results
def _get_auth_chain_ids_txn( def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]: ) -> Set[str]:
"""Calculates the auth chain IDs. """Calculates the auth chain IDs.
This is used when we don't have a cover index for the room. This is used when we don't have a cover index for the room.
@ -331,7 +331,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = new_front front = new_front
results.update(front) results.update(front)
return list(results) return results
async def get_auth_chain_difference( async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]] self, room_id: str, state_sets: List[Set[str]]

View File

@ -16,12 +16,21 @@ import logging
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.unittest import override_config
class FederationServerTests(unittest.FederatingHomeserverTestCase): class FederationServerTests(unittest.FederatingHomeserverTestCase):
@ -152,6 +161,145 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
    """Tests for the federation `/send_join` endpoint.

    Covers both the regular (full state) response and the reduced response
    returned when MSC3706 partial-state support is enabled and requested.
    """

    servlets = [
        admin.register_servlets,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Create a room with two local members for the tests to join."""
        super().prepare(reactor, clock, hs)

        # create the room
        creator_user_id = self.register_user("kermit", "test")
        tok = self.login("kermit", "test")
        self._room_id = self.helper.create_room_as(
            room_creator=creator_user_id, tok=tok
        )

        # a second member on the origin HS
        second_member_user_id = self.register_user("fozzie", "bear")
        tok2 = self.login("fozzie", "bear")
        self.helper.join(self._room_id, second_member_user_id, tok=tok2)

    def _make_join(self, user_id: str) -> JsonDict:
        """Send a signed `/make_join` request for `user_id` and return the body."""
        channel = self.make_signed_federation_request(
            "GET",
            f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
            f"?ver={DEFAULT_ROOM_VERSION}",
        )
        self.assertEqual(channel.code, 200, channel.json_body)
        return channel.json_body

    def _make_signed_join_event(self, user_id: str) -> JsonDict:
        """Fetch a join event template via `/make_join` and sign it.

        The event is hashed and signed as `OTHER_SERVER_NAME`, ready to be
        used as the content of a `/send_join` request.
        """
        join_event_dict = self._make_join(user_id)["event"]
        add_hashes_and_signatures(
            KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
            join_event_dict,
            signature_name=self.OTHER_SERVER_NAME,
            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
        )
        return join_event_dict

    def test_send_join(self) -> None:
        """happy-path test of send_join"""
        joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
        join_event_dict = self._make_signed_join_event(joining_user)

        channel = self.make_signed_federation_request(
            "PUT",
            f"/_matrix/federation/v2/send_join/{self._room_id}/x",
            content=join_event_dict,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # we should get complete room state back
        returned_state = [
            (ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
        ]
        self.assertCountEqual(
            returned_state,
            [
                ("m.room.create", ""),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
                ("m.room.history_visibility", ""),
                ("m.room.member", "@kermit:test"),
                ("m.room.member", "@fozzie:test"),
                # nb: *not* the joining user
            ],
        )

        # also check the auth chain
        returned_auth_chain_events = [
            (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
        ]
        self.assertCountEqual(
            returned_auth_chain_events,
            [
                ("m.room.create", ""),
                ("m.room.member", "@kermit:test"),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
            ],
        )

        # the room should show that the new user is a member
        r = self.get_success(
            self.hs.get_state_handler().get_current_state(self._room_id)
        )
        self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

    @override_config({"experimental_features": {"msc3706_enabled": True}})
    def test_send_join_partial_state(self) -> None:
        """When MSC3706 support is enabled, /send_join should return partial state"""
        joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
        join_event_dict = self._make_signed_join_event(joining_user)

        channel = self.make_signed_federation_request(
            "PUT",
            f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
            content=join_event_dict,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        # expect a reduced room state
        returned_state = [
            (ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
        ]
        self.assertCountEqual(
            returned_state,
            [
                ("m.room.create", ""),
                ("m.room.power_levels", ""),
                ("m.room.join_rules", ""),
                ("m.room.history_visibility", ""),
            ],
        )

        # the auth chain should not include anything already in "state"
        returned_auth_chain_events = [
            (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
        ]
        self.assertCountEqual(
            returned_auth_chain_events,
            [
                ("m.room.member", "@kermit:test"),
            ],
        )

        # the room should show that the new user is a member
        r = self.get_success(
            self.hs.get_state_handler().get_current_state(self._room_id)
        )
        self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
def _create_acl_event(content): def _create_acl_event(content):
return make_event_from_dict( return make_event_from_dict(
{ {

View File

@ -260,16 +260,16 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
self.assertEqual(auth_chain_ids, ["k"]) self.assertEqual(auth_chain_ids, {"k"})
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
self.assertEqual(auth_chain_ids, ["j"]) self.assertEqual(auth_chain_ids, {"j"})
# j and k have no parents. # j and k have no parents.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
self.assertEqual(auth_chain_ids, []) self.assertEqual(auth_chain_ids, set())
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
self.assertEqual(auth_chain_ids, []) self.assertEqual(auth_chain_ids, set())
# More complex input sequences. # More complex input sequences.
auth_chain_ids = self.get_success( auth_chain_ids = self.get_success(