Convert device handler to async/await (#7871)

pull/7886/head
Patrick Cloke 2020-07-17 07:09:25 -04:00 committed by GitHub
parent 00e57b755c
commit 6b3ac3b8cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 162 additions and 166 deletions

1
changelog.d/7871.misc Normal file
View File

@ -0,0 +1 @@
Convert device handler to async/await.

View File

@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional
from twisted.internet import defer
from typing import Any, Dict, List, Optional
from synapse.api import errors
from synapse.api.constants import EventTypes
@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
"""
Retrieve the given user's devices
Args:
user_id (str):
user_id: The user ID to query for devices.
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
info on each device
"""
set_tag("user_id", user_id)
device_map = yield self.store.get_devices_by_user(user_id)
device_map = await self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values())
for device in devices:
@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
user_id: The user to get the device from
device_id: The device to fetch.
Returns:
defer.Deferred: dict[str, X]: info on the device
info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
device = await self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
set_tag("device", device)
@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler):
return device
@measure_func("device.get_user_ids_changed")
@trace
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
@measure_func("device.get_user_ids_changed")
async def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler):
set_tag("user_id", user_id)
set_tag("from_token", from_token)
now_room_key = yield self.store.get_room_events_max_id()
now_room_key = await self.store.get_room_events_max_id()
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed for users that we share
# rooms with.
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler):
# Always tell the user about their own devices
tracked_users.add(user_id)
changed = yield self.store.get_users_whose_devices_changed(
changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
member_events = await self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)
@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler):
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
current_state_ids = yield self.store.get_current_state_ids(room_id)
current_state_ids = await self.store.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler):
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
event_ids = await self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler):
continue
# mapping from event_id -> state_dict
prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler):
return result
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = yield self.store.get_e2e_cross_signing_key(
async def on_federation_query_user_devices(self, user_id):
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
)
@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks
def check_device_registered(
async def check_device_registered(
self, user_id, device_id, initial_device_display_name=None
):
"""
@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler):
str: device id (generated if none was supplied)
"""
if device_id is not None:
new_device = yield self.store.store_device(
new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
await self.notify_device_update(user_id, [device_id])
return device_id
# if the device id is not specified, we'll autogen one, but loop a few
@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler):
attempts = 0
while attempts < 5:
device_id = stringutils.random_string(10).upper()
new_device = yield self.store.store_device(
new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
await self.notify_device_update(user_id, [device_id])
return device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@trace
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
async def delete_device(self, user_id: str, device_id: str) -> None:
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
user_id: The user to delete the device from.
device_id: The device to delete.
"""
try:
yield self.store.delete_device(user_id, device_id)
await self.store.delete_device(user_id, device_id)
except errors.StoreError as e:
if e.code == 404:
# no match
@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
yield defer.ensureDeferred(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
await self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
yield self.notify_device_update(user_id, [device_id])
await self.notify_device_update(user_id, [device_id])
@trace
@defer.inlineCallbacks
def delete_all_devices_for_user(self, user_id, except_device_id=None):
async def delete_all_devices_for_user(
self, user_id: str, except_device_id: Optional[str] = None
) -> None:
"""Delete all of the user's devices
Args:
user_id (str):
except_device_id (str|None): optional device id which should not
be deleted
Returns:
defer.Deferred:
user_id: The user to remove all devices from
except_device_id: optional device id which should not be deleted
"""
device_map = yield self.store.get_devices_by_user(user_id)
device_map = await self.store.get_devices_by_user(user_id)
device_ids = list(device_map)
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
yield self.delete_devices(user_id, device_ids)
await self.delete_devices(user_id, device_ids)
@defer.inlineCallbacks
def delete_devices(self, user_id, device_ids):
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
""" Delete several devices
Args:
user_id (str):
device_ids (List[str]): The list of device IDs to delete
Returns:
defer.Deferred:
user_id: The user to delete devices from.
device_ids: The list of device IDs to delete
"""
try:
yield self.store.delete_devices(user_id, device_ids)
await self.store.delete_devices(user_id, device_ids)
except errors.StoreError as e:
if e.code == 404:
# no match
@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
yield defer.ensureDeferred(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
await self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(
await self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
yield self.notify_device_update(user_id, device_ids)
await self.notify_device_update(user_id, device_ids)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
user_id: The user to update devices of.
device_id: The device to update.
content: body of update request
"""
# Reject a new displayname which is too long.
@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler):
)
try:
yield self.store.update_device(
await self.store.update_device(
user_id, device_id, new_display_name=new_display_name
)
yield self.notify_device_update(user_id, [device_id])
await self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
@ -443,12 +417,11 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_ids):
async def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
@ -459,7 +432,7 @@ class DeviceHandler(DeviceWorkerHandler):
set_tag("target_hosts", hosts)
position = yield self.store.add_device_change_to_streams(
position = await self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
@ -468,11 +441,11 @@ class DeviceHandler(DeviceWorkerHandler):
"Notifying about update %r/%r, ID: %r", user_id, device_id, position
)
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
yield self.notifier.on_new_event(
self.notifier.on_new_event(
"device_list_key", position, users=[user_id], rooms=room_ids
)
@ -484,29 +457,29 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender.send_device_messages(host)
log_kv({"message": "sent device update to host", "host": host})
@defer.inlineCallbacks
def notify_user_signature_update(self, from_user_id, user_ids):
async def notify_user_signature_update(
self, from_user_id: str, user_ids: List[str]
) -> None:
"""Notify a user that they have made new signatures of other users.
Args:
from_user_id (str): the user who made the signature
user_ids (list[str]): the users IDs that have new signatures
from_user_id: the user who made the signature
user_ids: the users IDs that have new signatures
"""
position = yield self.store.add_user_signature_change_to_streams(
position = await self.store.add_user_signature_change_to_streams(
from_user_id, user_ids
)
self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
async def user_left_room(self, user, room_id):
user_id = user.to_string()
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips):
@ -549,8 +522,7 @@ class DeviceListUpdater(object):
)
@trace
@defer.inlineCallbacks
def incoming_device_list_update(self, origin, edu_content):
async def incoming_device_list_update(self, origin, edu_content):
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
@ -583,7 +555,7 @@ class DeviceListUpdater(object):
)
return
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
@ -608,14 +580,13 @@ class DeviceListUpdater(object):
(device_id, stream_id, prev_ids, edu_content)
)
yield self._handle_device_updates(user_id)
await self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update")
@defer.inlineCallbacks
def _handle_device_updates(self, user_id):
async def _handle_device_updates(self, user_id):
"Actually handle pending updates."
with (yield self._remote_edu_linearizer.queue(user_id)):
with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
@ -632,7 +603,7 @@ class DeviceListUpdater(object):
# Given a list of updates we check if we need to resync. This
# happens if we've missed updates.
resync = yield self._need_to_do_resync(user_id, pending_updates)
resync = await self._need_to_do_resync(user_id, pending_updates)
if logger.isEnabledFor(logging.INFO):
logger.info(
@ -643,16 +614,16 @@ class DeviceListUpdater(object):
)
if resync:
yield self.user_device_resync(user_id)
await self.user_device_resync(user_id)
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
await self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id
)
yield self.device_handler.notify_device_update(
await self.device_handler.notify_device_update(
user_id, [device_id for device_id, _, _, _ in pending_updates]
)
@ -660,14 +631,13 @@ class DeviceListUpdater(object):
stream_id for _, stream_id, _, _ in pending_updates
)
@defer.inlineCallbacks
def _need_to_do_resync(self, user_id, updates):
async def _need_to_do_resync(self, user_id, updates):
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set())
extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
logger.debug("Current extremity for %r: %r", user_id, extremity)
@ -692,8 +662,7 @@ class DeviceListUpdater(object):
return False
@trace
@defer.inlineCallbacks
def _maybe_retry_device_resync(self):
async def _maybe_retry_device_resync(self):
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
@ -705,12 +674,12 @@ class DeviceListUpdater(object):
# we don't send too many requests.
self._resync_retry_in_progress = True
# Get all of the users that need resyncing.
need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
need_resync = await self.store.get_user_ids_requiring_device_list_resync()
# Iterate over the set of user IDs.
for user_id in need_resync:
try:
# Try to resync the current user's devices list.
result = yield self.user_device_resync(
result = await self.user_device_resync(
user_id=user_id, mark_failed_as_stale=False,
)
@ -734,16 +703,17 @@ class DeviceListUpdater(object):
# Allow future calls to retry resyncinc out of sync device lists.
self._resync_retry_in_progress = False
@defer.inlineCallbacks
def user_device_resync(self, user_id, mark_failed_as_stale=True):
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[dict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id (str): The user's id whose device_list will be updated.
mark_failed_as_stale (bool): Whether to mark the user's device list as stale
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
Deferred[dict]: a dict with device info as under the "devices" in the result of this
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
@ -752,12 +722,12 @@ class DeviceListUpdater(object):
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
try:
result = yield self.federation.query_user_devices(origin, user_id)
result = await self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
await self.store.mark_remote_user_device_cache_as_stale(user_id)
return
except (RequestSendFailed, HttpResponseException) as e:
@ -768,7 +738,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
await self.store.mark_remote_user_device_cache_as_stale(user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
@ -792,7 +762,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
await self.store.mark_remote_user_device_cache_as_stale(user_id)
return
log_kv({"result": result})
@ -833,25 +803,24 @@ class DeviceListUpdater(object):
stream_id,
)
yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys.
cross_signing_device_ids = yield self.process_cross_signing_key_update(
cross_signing_device_ids = await self.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
)
device_ids = device_ids + cross_signing_device_ids
yield self.device_handler.notify_device_update(user_id, device_ids)
await self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
self._seen_updates[user_id] = {stream_id}
defer.returnValue(result)
return result
@defer.inlineCallbacks
def process_cross_signing_key_update(
async def process_cross_signing_key_update(
self,
user_id: str,
master_key: Optional[Dict[str, Any]],
@ -872,14 +841,14 @@ class DeviceListUpdater(object):
device_ids = []
if master_key:
yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
if self_signing_key:
yield self.store.set_e2e_cross_signing_key(
await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
_, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)

View File

@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from twisted.internet import defer
from twisted.internet.defer import Deferred, fail, succeed
from twisted.python import failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@ -79,6 +81,28 @@ class Distributor(object):
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
def maybeAwaitableDeferred(f, *args, **kw):
"""
Invoke a function that may or may not return a Deferred or an Awaitable.
This is a modified version of twisted.internet.defer.maybeDeferred.
"""
try:
result = f(*args, **kw)
except Exception:
return fail(failure.Failure(captureVars=Deferred.debug))
if isinstance(result, Deferred):
return result
# Handle the additional case of an awaitable being returned.
elif inspect.isawaitable(result):
return defer.ensureDeferred(result)
elif isinstance(result, failure.Failure):
return fail(result)
else:
return succeed(result)
class Signal(object):
"""A Signal is a dispatch point that stores a list of callables as
observers of it.
@ -122,7 +146,7 @@ class Signal(object):
),
)
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
deferreds = [run_in_background(do, o) for o in self.observers]

View File

@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
res = self.handler.get_device(user1, "abc")
self.pump()
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
self.get_failure(
self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a
@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
res = self.handler.update_device("user_id", "unknown_device_id", update)
self.pump()
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
self.get_failure(
self.handler.update_device("user_id", "unknown_device_id", update),
synapse.api.errors.NotFoundError,
)
def _record_users(self):

View File

@ -334,10 +334,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
res = None
try:
yield self.hs.get_device_handler().check_device_registered(
user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
yield defer.ensureDeferred(
self.hs.get_device_handler().check_device_registered(
user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
)
)
except errors.SynapseError as e:
res = e.code

View File

@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.homeserver.get_datastore()
store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock(
return_value={
"user_id": remote_user_id,
"stream_id": 1,
"devices": [],
"master_key": {
return_value=succeed(
{
"user_id": remote_user_id,
"usage": ["master"],
"keys": {"ed25519:" + remote_master_key: remote_master_key},
},
"self_signing_key": {
"user_id": remote_user_id,
"usage": ["self_signing"],
"keys": {
"ed25519:" + remote_self_signing_key: remote_self_signing_key
"stream_id": 1,
"devices": [],
"master_key": {
"user_id": remote_user_id,
"usage": ["master"],
"keys": {"ed25519:" + remote_master_key: remote_master_key},
},
},
}
"self_signing_key": {
"user_id": remote_user_id,
"usage": ["self_signing"],
"keys": {
"ed25519:"
+ remote_self_signing_key: remote_self_signing_key
},
},
}
)
)
# Resync the device list.