Process device list updates asynchronously (#12365)

pull/12455/head
Erik Johnston 2022-04-12 16:50:40 +01:00 committed by GitHub
parent 4bdbebccb9
commit aa28110264
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 40 additions and 119 deletions

View File

@ -0,0 +1 @@
Enable processing of device list updates asynchronously.

View File

@ -680,14 +680,6 @@ class ServerConfig(Config):
config.get("use_account_validity_in_account_status") or False config.get("use_account_validity_in_account_status") or False
) )
# This is a temporary option that enables fully using the new
# `device_lists_changes_in_room` without the backwards compat code. This
# is primarily for testing. If enabled the server should *not* be
# downgraded, as it may lead to missing device list updates.
self.use_new_device_lists_changes_in_room = (
config.get("use_new_device_lists_changes_in_room") or False
)
self.rooms_to_exclude_from_sync: List[str] = ( self.rooms_to_exclude_from_sync: List[str] = (
config.get("exclude_rooms_from_sync") or [] config.get("exclude_rooms_from_sync") or []
) )

View File

@ -291,12 +291,6 @@ class DeviceHandler(DeviceWorkerHandler):
# On start up check if there are any updates pending. # On start up check if there are any updates pending.
hs.get_reactor().callWhenRunning(self._handle_new_device_update_async) hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
# Used to decide if we calculate outbound pokes up front or not. By
# default we do to allow safely downgrading Synapse.
self.use_new_device_lists_changes_in_room = (
hs.config.server.use_new_device_lists_changes_in_room
)
def _check_device_name_length(self, name: Optional[str]) -> None: def _check_device_name_length(self, name: Optional[str]) -> None:
""" """
Checks whether a device name is longer than the maximum allowed length. Checks whether a device name is longer than the maximum allowed length.
@ -490,23 +484,9 @@ class DeviceHandler(DeviceWorkerHandler):
room_ids = await self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
hosts: Optional[Set[str]] = None
if not self.use_new_device_lists_changes_in_room:
hosts = set()
if self.hs.is_mine_id(user_id):
for room_id in room_ids:
joined_users = await self.store.get_users_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in joined_users)
set_tag("target_hosts", hosts)
hosts.discard(self.server_name)
position = await self.store.add_device_change_to_streams( position = await self.store.add_device_change_to_streams(
user_id, user_id,
device_ids, device_ids,
hosts=hosts,
room_ids=room_ids, room_ids=room_ids,
) )
@ -528,14 +508,6 @@ class DeviceHandler(DeviceWorkerHandler):
# We may need to do some processing asynchronously. # We may need to do some processing asynchronously.
self._handle_new_device_update_async() self._handle_new_device_update_async()
if hosts:
logger.info(
"Sending device list update notif for %r to: %r", user_id, hosts
)
for host in hosts:
self.federation_sender.send_device_messages(host, immediate=False)
log_kv({"message": "sent device update to host", "host": host})
async def notify_user_signature_update( async def notify_user_signature_update(
self, from_user_id: str, user_ids: List[str] self, from_user_id: str, user_ids: List[str]
) -> None: ) -> None:

View File

@ -1582,7 +1582,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self, self,
user_id: str, user_id: str,
device_ids: Collection[str], device_ids: Collection[str],
hosts: Optional[Collection[str]],
room_ids: Collection[str], room_ids: Collection[str],
) -> Optional[int]: ) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts """Persist that a user's devices have been updated, and which hosts
@ -1592,9 +1591,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed. user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will device_ids: The IDs of any changed devices. If empty, this function will
return None. return None.
hosts: The remote destinations that should be notified of the change. If
None then the set of hosts have *not* been calculated, and will be
calculated later by a background task.
room_ids: The rooms that the user is in room_ids: The rooms that the user is in
Returns: Returns:
@ -1606,14 +1602,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map() context = get_active_span_text_map()
def add_device_changes_txn( def add_device_changes_txn(txn, stream_ids):
txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
):
self._add_device_change_to_stream_txn( self._add_device_change_to_stream_txn(
txn, txn,
user_id, user_id,
device_ids, device_ids,
stream_ids_for_device_change, stream_ids,
) )
self._add_device_outbound_room_poke_txn( self._add_device_outbound_room_poke_txn(
@ -1621,43 +1615,17 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id, user_id,
device_ids, device_ids,
room_ids, room_ids,
stream_ids_for_device_change, stream_ids,
context,
hosts_have_been_calculated=hosts is not None,
)
# If the set of hosts to send to has not been calculated yet (and so
# `hosts` is None) or there are no `hosts` to send to, then skip
# trying to persist them to the DB.
if not hosts:
return
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id,
device_ids,
hosts,
stream_ids_for_outbound_pokes,
context, context,
) )
# `device_lists_stream` wants a stream ID per device update. async with self._device_list_id_gen.get_next_mult(
num_stream_ids = len(device_ids) len(device_ids)
) as stream_ids:
if hosts:
# `device_lists_outbound_pokes` wants a different stream ID for
# each row, which is a row per host per device update.
num_stream_ids += len(hosts) * len(device_ids)
async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
stream_ids_for_device_change = stream_ids[: len(device_ids)]
stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_device_change_to_stream", "add_device_change_to_stream",
add_device_changes_txn, add_device_changes_txn,
stream_ids_for_device_change, stream_ids,
stream_ids_for_outbound_pokes,
) )
return stream_ids[-1] return stream_ids[-1]
@ -1752,19 +1720,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
room_ids: Collection[str], room_ids: Collection[str],
stream_ids: List[str], stream_ids: List[str],
context: Dict[str, str], context: Dict[str, str],
hosts_have_been_calculated: bool,
) -> None: ) -> None:
"""Record the user in the room has updated their device. """Record the user in the room has updated their device."""
Args:
hosts_have_been_calculated: True if `device_lists_outbound_pokes`
has been updated already with the updates.
"""
# We only need to convert to outbound pokes if they are our user.
converted_to_destinations = (
hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
)
encoded_context = json_encoder.encode(context) encoded_context = json_encoder.encode(context)
@ -1789,7 +1746,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id, device_id,
room_id, room_id,
stream_id, stream_id,
converted_to_destinations, False,
encoded_context, encoded_context,
) )
for room_id in room_ids for room_id in room_ids

View File

@ -66,9 +66,9 @@ Changes in SCHEMA_VERSION = 69:
SCHEMA_COMPAT_VERSION = ( SCHEMA_COMPAT_VERSION = (
# we now have `state_key` columns in both `events` and `state_events`, so # We now assume that `device_lists_changes_in_room` has been filled out for
# now incompatible with synapses wth SCHEMA_VERSION < 66. # recent device_list_updates.
66 69
) )
"""Limit on how far the synapse codebase can be rolled back without breaking db compat """Limit on how far the synapse codebase can be rolled back without breaking db compat

View File

@ -14,7 +14,6 @@
from typing import Optional from typing import Optional
from unittest.mock import Mock from unittest.mock import Mock
from parameterized import parameterized_class
from signedjson import key, sign from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey from signedjson.types import BaseKey, SigningKey
@ -155,12 +154,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
) )
@parameterized_class(
[
{"enable_room_poke_code_path": False},
{"enable_room_poke_code_path": True},
]
)
class FederationSenderDevicesTestCases(HomeserverTestCase): class FederationSenderDevicesTestCases(HomeserverTestCase):
servlets = [ servlets = [
admin.register_servlets, admin.register_servlets,
@ -175,7 +168,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def default_config(self): def default_config(self):
c = super().default_config() c = super().default_config()
c["send_federation"] = True c["send_federation"] = True
c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path
return c return c
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):

View File

@ -21,6 +21,29 @@ class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def add_device_change(self, user_id, device_ids, host):
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
"""
for device_id in device_ids:
stream_id = self.get_success(
self.store.add_device_change_to_streams(
"user_id", [device_id], ["!some:room"]
)
)
self.get_success(
self.store.add_device_list_outbound_pokes(
user_id=user_id,
device_id=device_id,
room_id="!some:room",
stream_id=stream_id,
hosts=[host],
context={},
)
)
def test_store_new_device(self): def test_store_new_device(self):
self.get_success( self.get_success(
self.store.store_device("user_id", "device_id", "display_name") self.store.store_device("user_id", "device_id", "display_name")
@ -95,11 +118,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
device_ids = ["device_id1", "device_id2"] device_ids = ["device_id1", "device_id2"]
# Add two device updates with sequential `stream_id`s # Add two device updates with sequential `stream_id`s
self.get_success( self.add_device_change("user_id", device_ids, "somehost")
self.store.add_device_change_to_streams(
"user_id", device_ids, ["somehost"], ["!some:room"]
)
)
# Get all device updates ever meant for this remote # Get all device updates ever meant for this remote
now_stream_id, device_updates = self.get_success( now_stream_id, device_updates = self.get_success(
@ -123,11 +142,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
"device_id4", "device_id4",
"device_id5", "device_id5",
] ]
self.get_success( self.add_device_change("user_id", device_ids, "somehost")
self.store.add_device_change_to_streams(
"user_id", device_ids, ["somehost"], ["!some:room"]
)
)
# Get device updates meant for this remote # Get device updates meant for this remote
next_stream_id, device_updates = self.get_success( next_stream_id, device_updates = self.get_success(
@ -147,11 +162,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add some more device updates to ensure it still resumes properly # Add some more device updates to ensure it still resumes properly
device_ids = ["device_id6", "device_id7"] device_ids = ["device_id6", "device_id7"]
self.get_success( self.add_device_change("user_id", device_ids, "somehost")
self.store.add_device_change_to_streams(
"user_id", device_ids, ["somehost"], ["!some:room"]
)
)
# Get the next batch of device updates # Get the next batch of device updates
next_stream_id, device_updates = self.get_success( next_stream_id, device_updates = self.get_success(
@ -224,11 +235,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
"fakeSelfSigning", "fakeSelfSigning",
] ]
self.get_success( self.add_device_change("@user_id:test", device_ids, "somehost")
self.store.add_device_change_to_streams(
"@user_id:test", device_ids, ["somehost"], ["!some:room"]
)
)
# Get device updates meant for this remote # Get device updates meant for this remote
next_stream_id, device_updates = self.get_success( next_stream_id, device_updates = self.get_success(