Faster joins: Refactor handling of servers in room (#14954)

Ensure that the list of servers in a partial state room always contains
the server we joined off.

Also refactor `get_partial_state_servers_at_join` to return `None` when
the given room is no longer partial stated, to explicitly indicate when
the room has partial state. Otherwise it's not clear whether an empty
list means that the room has full state, or the room is partial stated,
but the server we joined off told us that there are no servers in the
room.

Signed-off-by: Sean Quah <seanq@matrix.org>
pull/14984/head
Sean Quah 2023-02-03 15:39:59 +00:00 committed by GitHub
parent 8e9fc28c6a
commit 0a686d1d13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 77 additions and 37 deletions

1
changelog.d/14954.misc Normal file
View File

@ -0,0 +1 @@
Faster room joins: Refactor internal handling of servers in room to never store an empty list.

View File

@ -19,6 +19,7 @@ import itertools
import logging
from typing import (
TYPE_CHECKING,
AbstractSet,
Awaitable,
Callable,
Collection,
@ -110,8 +111,9 @@ class SendJoinResult:
# True if 'state' elides non-critical membership events
partial_state: bool
# if 'partial_state' is set, a list of the servers in the room (otherwise empty)
servers_in_room: List[str]
# If 'partial_state' is set, a set of the servers in the room (otherwise empty).
# Always contains the server we joined off.
servers_in_room: AbstractSet[str]
class FederationClient(FederationBase):
@ -1152,15 +1154,24 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,)
)
if response.members_omitted and not response.servers_in_room:
raise InvalidResponseError(
"members_omitted was set, but no servers were listed in the room"
)
servers_in_room = None
if response.servers_in_room is not None:
servers_in_room = set(response.servers_in_room)
if response.members_omitted and not partial_state:
raise InvalidResponseError(
"members_omitted was set, but we asked for full state"
)
if response.members_omitted:
if not servers_in_room:
raise InvalidResponseError(
"members_omitted was set, but no servers were listed in the room"
)
if not partial_state:
raise InvalidResponseError(
"members_omitted was set, but we asked for full state"
)
# `servers_in_room` is supposed to be a complete list.
# Fix things up in case the remote homeserver is badly behaved.
servers_in_room.add(destination)
return SendJoinResult(
event=event,
@ -1168,7 +1179,7 @@ class FederationClient(FederationBase):
auth_chain=signed_auth,
origin=destination,
partial_state=response.members_omitted,
servers_in_room=response.servers_in_room or [],
servers_in_room=servers_in_room or frozenset(),
)
# MSC3083 defines additional error codes for room joins.

View File

@ -447,7 +447,7 @@ class FederationSender(AbstractFederationSender):
)
)
if len(partial_state_destinations) > 0:
if partial_state_destinations is not None:
destinations = partial_state_destinations
if destinations is None:

View File

@ -859,6 +859,7 @@ class DeviceHandler(DeviceWorkerHandler):
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
room_id
)
assert known_hosts_at_join is not None
potentially_changed_hosts.difference_update(known_hosts_at_join)
potentially_changed_hosts.discard(self.server_name)

View File

@ -20,7 +20,17 @@ import itertools
import logging
from enum import Enum
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (
TYPE_CHECKING,
AbstractSet,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import attr
from prometheus_client import Histogram
@ -169,7 +179,7 @@ class FederationHandler:
# A dictionary mapping room IDs to (initial destination, other destinations)
# tuples.
self._partial_state_syncs_maybe_needing_restart: Dict[
str, Tuple[Optional[str], StrCollection]
str, Tuple[Optional[str], AbstractSet[str]]
] = {}
# A lock guarding the partial state flag for rooms.
# When the lock is held for a given room, no other concurrent code may
@ -1720,7 +1730,7 @@ class FederationHandler:
def _start_partial_state_room_sync(
self,
initial_destination: Optional[str],
other_destinations: StrCollection,
other_destinations: AbstractSet[str],
room_id: str,
) -> None:
"""Starts the background process to resync the state of a partial state room,
@ -1802,7 +1812,7 @@ class FederationHandler:
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],
other_destinations: StrCollection,
other_destinations: AbstractSet[str],
room_id: str,
) -> None:
"""Background process to resync the state of a partial-state room
@ -1939,7 +1949,7 @@ class FederationHandler:
def _prioritise_destinations_for_partial_state_resync(
initial_destination: Optional[str],
other_destinations: StrCollection,
other_destinations: AbstractSet[str],
room_id: str,
) -> StrCollection:
"""Work out the order in which we should ask servers to resync events.

View File

@ -569,10 +569,11 @@ class StateStorageController:
is arbitrary for rooms with partial state.
"""
# We have to read this list first to mitigate races with un-partial stating.
# This will be empty for rooms with full state.
hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
room_id
)
if hosts_at_join is None:
hosts_at_join = frozenset()
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)

View File

@ -18,6 +18,7 @@ from abc import abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Awaitable,
Collection,
@ -25,7 +26,6 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
@ -109,7 +109,7 @@ class RoomSortOrder(Enum):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PartialStateResyncInfo:
joined_via: Optional[str]
servers_in_room: List[str] = attr.ib(factory=list)
servers_in_room: Set[str] = attr.ib(factory=set)
class RoomWorkerStore(CacheInvalidationWorkerStore):
@ -1193,21 +1193,35 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
@cached(iterable=True)
async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
"""Gets the list of servers in a partial state room at the time we joined it.
async def get_partial_state_servers_at_join(
self, room_id: str
) -> Optional[AbstractSet[str]]:
"""Gets the set of servers in a partial state room at the time we joined it.
Returns:
The `servers_in_room` list from the `/send_join` response for partial state
rooms. May not be accurate or complete, as it comes from a remote
homeserver.
An empty list for full state rooms.
`None` for full state rooms.
"""
return await self.db_pool.simple_select_onecol(
"partial_state_rooms_servers",
keyvalues={"room_id": room_id},
retcol="server_name",
desc="get_partial_state_servers_at_join",
servers_in_room = await self._get_partial_state_servers_at_join(room_id)
if len(servers_in_room) == 0:
return None
return servers_in_room
@cached(iterable=True)
async def _get_partial_state_servers_at_join(
self, room_id: str
) -> AbstractSet[str]:
return frozenset(
await self.db_pool.simple_select_onecol(
"partial_state_rooms_servers",
keyvalues={"room_id": room_id},
retcol="server_name",
desc="get_partial_state_servers_at_join",
)
)
async def get_partial_state_room_resync_info(
@ -1252,7 +1266,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# partial-joined between the two SELECTs, but this is unlikely to happen
# in practice.)
continue
entry.servers_in_room.append(server_name)
entry.servers_in_room.add(server_name)
return room_servers
@ -1942,7 +1956,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
async def store_partial_state_room(
self,
room_id: str,
servers: Collection[str],
servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@ -1957,11 +1971,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
room_id: the ID of the room
servers: other servers known to be in the room
servers: other servers known to be in the room. must include `joined_via`.
device_lists_stream_id: the device_lists stream ID at the time when we first
joined the room.
joined_via: the server name we requested a partial join from.
"""
assert joined_via in servers
await self.db_pool.runInteraction(
"store_partial_state_room",
self._store_partial_state_room_txn,
@ -1975,7 +1991,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
txn: LoggingTransaction,
room_id: str,
servers: Collection[str],
servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@ -1998,7 +2014,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
txn, self.get_partial_state_servers_at_join, (room_id,)
txn, self._get_partial_state_servers_at_join, (room_id,)
)
async def write_partial_state_rooms_join_event_id(
@ -2409,7 +2425,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
txn, self.get_partial_state_servers_at_join, (room_id,)
txn, self._get_partial_state_servers_at_join, (room_id,)
)
DatabasePool.simple_insert_txn(

View File

@ -656,7 +656,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
EVENT_INVITATION_MEMBERSHIP,
],
partial_state=True,
servers_in_room=["example.com"],
servers_in_room={"example.com"},
)
)
)

View File

@ -171,7 +171,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
state=[create_event],
auth_chain=[create_event],
partial_state=False,
servers_in_room=[],
servers_in_room=frozenset(),
)
)
)