Convert the roommember database to async/await. (#8070)

pull/8072/head
Patrick Cloke 2020-08-12 12:14:34 -04:00 committed by GitHub
parent 5ecc8b5825
commit fbe930dad2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 116 additions and 242 deletions

1
changelog.d/8070.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -58,7 +58,6 @@ class SQLBaseStore(metaclass=ABCMeta):
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))

View File

@ -256,81 +256,6 @@ class PushRulesWorkerStore(
):
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
return result
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(
self, room_id, state_group, current_state_ids, cache_context, event=None
):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
users_in_room = yield self._get_joined_users_from_context(
room_id,
state_group,
current_state_ids,
on_invalidate=cache_context.invalidate,
event=event,
)
# We ignore app service users for now. This is so that we don't fill
# up the `get_if_users_have_pushers` cache with AS entries that we
# know don't have pushers, nor even read receipts.
local_users_in_room = {
u
for u in users_in_room
if self.hs.is_mine_id(u)
and not self.get_if_app_services_interested_in_user(u)
}
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate
)
user_ids = {
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
}
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate
)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
return rules_by_user
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",

View File

@ -15,11 +15,13 @@
# limitations under the License.
import logging
from typing import Iterable, List, Set
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import (
@ -40,9 +42,12 @@ from synapse.storage.roommember import (
from synapse.types import Collection, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.state import _StateCacheEntry
logger = logging.getLogger(__name__)
@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
def get_users_in_room(self, room_id: str):
return self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
def get_users_in_room_txn(self, txn, room_id):
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
def get_room_summary(self, room_id):
def get_room_summary(self, room_id: str):
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id (str): The room ID to query
room_id: The room ID to query
Returns:
Deferred[dict[str, MemberSummary]:
dict of membership states, pointing to a MemberSummary named tuple.
@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
def _get_user_counts_in_room_txn(self, txn, room_id):
"""
Get the user count in a room by membership.
Args:
room_id (str)
membership (Membership)
Returns:
Deferred[int]
"""
sql = """
SELECT m.membership, count(*) FROM room_memberships as m
INNER JOIN current_state_events as c USING(event_id)
WHERE c.type = 'm.room.member' AND c.room_id = ?
GROUP BY m.membership
"""
txn.execute(sql, (room_id,))
return {row[0]: row[1] for row in txn}
@cached()
def get_invited_rooms_for_local_user(self, user_id):
""" Get all the rooms the *local* user is invited to
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
user_id (str): The user ID.
user_id: The user ID.
Returns:
A deferred list of RoomsForUser.
A awaitable list of RoomsForUser.
"""
return self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@defer.inlineCallbacks
def get_invite_for_local_user_in_room(self, user_id, room_id):
"""Gets the invite for the given *local* user and room
async def get_invite_for_local_user_in_room(
self, user_id: str, room_id: str
) -> Optional[RoomsForUser]:
"""Gets the invite for the given *local* user and room.
Args:
user_id (str)
room_id (str)
user_id: The user ID to find the invite of.
room_id: The room to user was invited to.
Returns:
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
Either a RoomsForUser or None if no invite was found.
"""
invites = yield self.get_invited_rooms_for_local_user(user_id)
invites = await self.get_invited_rooms_for_local_user(user_id)
for invite in invites:
if invite.room_id == room_id:
return invite
return None
@defer.inlineCallbacks
def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this *local* user where the membership for this user
async def get_rooms_for_local_user_where_membership_is(
self, user_id: str, membership_list: List[str]
) -> Optional[List[RoomsForUser]]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
Args:
user_id (str): The user ID.
membership_list (list): A list of synapse.api.constants.Membership
values which the user must be in.
user_id: The user ID.
membership_list: A list of synapse.api.constants.Membership
values which the user must be in.
Returns:
Deferred[list[RoomsForUser]]
The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
return defer.succeed(None)
return None
rooms = yield self.db_pool.runInteraction(
rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
self._get_rooms_for_local_user_where_membership_is_txn,
user_id,
@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
# Now we filter out forgotten rooms
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
def _get_rooms_for_local_user_where_membership_is_txn(
self, txn, user_id, membership_list
):
self, txn, user_id: str, membership_list: List[str]
) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id):
def get_rooms_for_user_with_stream_ordering(self, user_id: str):
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
Args:
user_id (str)
user_id
Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id,
)
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn,
)
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
rooms = await self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
)
return frozenset(r.room_id for r in rooms)
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
@cached(max_entries=500000, cache_context=True, iterable=True)
async def get_users_who_share_room_with_user(
self, user_id: str, cache_context: _CacheContext
) -> Set[str]:
"""Returns the set of users who share a room with `user_id`
"""
room_ids = yield self.get_rooms_for_user(
room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate
)
user_who_share_room = set()
for room_id in room_ids:
user_ids = yield self.get_users_in_room(
user_ids = await self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
user_who_share_room.update(user_ids)
return user_who_share_room
@defer.inlineCallbacks
def get_joined_users_from_context(self, event, context):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._get_joined_users_from_context(
current_state_ids = await context.get_current_state_ids()
return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
return result
@defer.inlineCallbacks
def get_joined_users_from_state(self, room_id, state_entry):
async def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
with Measure(self._clock, "get_joined_users_from_state"):
return (
yield self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
)
return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
)
@cachedInlineCallbacks(
num_args=2, cache_context=True, iterable=True, max_entries=100000
)
def _get_joined_users_from_context(
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
room_id,
state_group,
@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
users_in_room = {}
@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
event_to_memberships = yield self._get_joined_profiles_from_event_ids(
event_to_memberships = await self._get_joined_profiles_from_event_ids(
missing_member_event_ids
)
users_in_room.update((row for row in event_to_memberships.values() if row))
@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
list_name="event_ids",
inlineCallbacks=True,
)
def _get_joined_profiles_from_event_ids(self, event_ids):
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
Args:
event_ids (Iterable[str]): The member event IDs to lookup
event_ids: The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for row in rows
}
@cachedInlineCallbacks(max_entries=10000)
def is_host_joined(self, room_id, host):
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
if "%" in host or "_" in host:
raise Exception("Invalid host name")
@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
rows = yield self.db_pool.execute(
rows = await self.db_pool.execute(
"is_host_joined", None, sql, room_id, like_clause
)
@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cachedInlineCallbacks()
def was_host_joined(self, room_id, host):
"""Check whether the server is or ever was in the room.
Args:
room_id (str)
host (str)
Returns:
Deferred: Resolves to True if the host is/was in the room, otherwise
False.
"""
if "%" in host or "_" in host:
raise Exception("Invalid host name")
sql = """
SELECT user_id FROM room_memberships
WHERE room_id = ?
AND user_id LIKE ?
AND membership = 'join'
LIMIT 1
"""
# We do need to be careful to ensure that host doesn't have any wild cards
# in it, but we checked above for known ones and we'll check below that
# the returned user actually has the correct domain.
like_clause = "%:" + host
rows = yield self.db_pool.execute(
"was_host_joined", None, sql, room_id, like_clause
)
if not rows:
return False
user_id = rows[0][0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")
return True
@defer.inlineCallbacks
def get_joined_hosts(self, room_id, state_entry):
async def get_joined_hosts(self, room_id: str, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
with Measure(self._clock, "get_joined_hosts"):
return (
yield self._get_joined_hosts(
room_id, state_group, state_entry.state, state_entry=state_entry
)
return await self._get_joined_hosts(
room_id, state_group, state_entry.state, state_entry=state_entry
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
# @defer.inlineCallbacks
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self, room_id, state_group, current_state_ids, state_entry
):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
cache = yield self._get_joined_hosts_cache(room_id)
joined_hosts = yield cache.get_destinations(state_entry)
return joined_hosts
cache = await self._get_joined_hosts_cache(room_id)
return await cache.get_destinations(state_entry)
@cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id):
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache(self, room_id)
@cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id):
@cached(num_args=2)
async def did_forget(self, user_id: str, room_id: str) -> bool:
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = txn.fetchall()
return rows[0][0]
count = yield self.db_pool.runInteraction("did_forget_membership", f)
count = await self.db_pool.runInteraction("did_forget_membership", f)
return count == 0
@cached()
def get_forgotten_rooms_for_user(self, user_id):
def get_forgotten_rooms_for_user(self, user_id: str):
"""Gets all rooms the user has forgotten.
Args:
user_id (str)
user_id
Returns:
Deferred[set[str]]
@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@defer.inlineCallbacks
def get_rooms_user_has_been_in(self, user_id):
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
Args:
user_id (str)
user_id: The user ID to get the rooms of.
Returns:
Deferred[set[str]]: Set of room IDs.
Set of room IDs.
"""
room_ids = yield self.db_pool.simple_select_onecol(
room_ids = await self.db_pool.simple_select_onecol(
table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id",
@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
async def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start
)
@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return len(rows)
result = yield self.db_pool.runInteraction(
result = await self.db_pool.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
)
if not result:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
_MEMBERSHIP_PROFILE_UPDATE_NAME
)
return result
@defer.inlineCallbacks
def _background_current_state_membership(self, progress, batch_size):
async def _background_current_state_membership(self, progress, batch_size):
"""Update the new membership column on current_state_events.
This works by iterating over all rooms in alphebetical order.
@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
# string, which will compare before all room IDs correctly.
last_processed_room = progress.get("last_processed_room", "")
row_count, finished = yield self.db_pool.runInteraction(
row_count, finished = await self.db_pool.runInteraction(
"_background_current_state_membership_update",
_background_current_state_membership_txn,
last_processed_room,
)
if finished:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
)
@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
def forget(self, user_id, room_id):
def forget(self, user_id: str, room_id: str):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object):
self._len = 0
@defer.inlineCallbacks
def get_destinations(self, state_entry):
async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
"""Get set of destinations for a state entry
Args:
state_entry(synapse.state._StateCacheEntry)
state_entry
Returns:
The destinations as a set.
"""
if state_entry.state_group == self.state_group:
return frozenset(self.hosts_to_joined_users)
with (yield self.linearizer.queue(())):
with (await self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object):
user_id = state_key
known_joins = self.hosts_to_joined_users.setdefault(host, set())
event = yield self.store.get_event(event_id)
event = await self.store.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object):
if not known_joins:
self.hosts_to_joined_users.pop(host, None)
else:
joined_users = yield self.store.get_joined_users_from_state(
joined_users = await self.store.get_joined_users_from_state(
self.room_id, state_entry
)

View File

@ -1,3 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
@ -10,6 +25,7 @@ from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
@ -173,7 +189,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.homeserver.get_datastore()
store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.