|
|
|
@ -15,11 +15,13 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Iterable, List, Set
|
|
|
|
|
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
|
|
|
|
|
|
|
|
|
from twisted.internet import defer
|
|
|
|
|
|
|
|
|
|
from synapse.api.constants import EventTypes, Membership
|
|
|
|
|
from synapse.events import EventBase
|
|
|
|
|
from synapse.events.snapshot import EventContext
|
|
|
|
|
from synapse.metrics import LaterGauge
|
|
|
|
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
|
|
|
|
from synapse.storage._base import (
|
|
|
|
@ -40,9 +42,12 @@ from synapse.storage.roommember import (
|
|
|
|
|
from synapse.types import Collection, get_domain_from_id
|
|
|
|
|
from synapse.util.async_helpers import Linearizer
|
|
|
|
|
from synapse.util.caches import intern_string
|
|
|
|
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
|
|
|
|
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
|
|
|
|
from synapse.util.metrics import Measure
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from synapse.state import _StateCacheEntry
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@cached(max_entries=100000, iterable=True)
|
|
|
|
|
def get_users_in_room(self, room_id):
|
|
|
|
|
def get_users_in_room(self, room_id: str):
|
|
|
|
|
return self.db_pool.runInteraction(
|
|
|
|
|
"get_users_in_room", self.get_users_in_room_txn, room_id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def get_users_in_room_txn(self, txn, room_id):
|
|
|
|
|
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
|
|
|
|
|
# If we can assume current_state_events.membership is up to date
|
|
|
|
|
# then we can avoid a join, which is a Very Good Thing given how
|
|
|
|
|
# frequently this function gets called.
|
|
|
|
@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
return [r[0] for r in txn]
|
|
|
|
|
|
|
|
|
|
@cached(max_entries=100000)
|
|
|
|
|
def get_room_summary(self, room_id):
|
|
|
|
|
def get_room_summary(self, room_id: str):
|
|
|
|
|
""" Get the details of a room roughly suitable for use by the room
|
|
|
|
|
summary extension to /sync. Useful when lazy loading room members.
|
|
|
|
|
Args:
|
|
|
|
|
room_id (str): The room ID to query
|
|
|
|
|
room_id: The room ID to query
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[dict[str, MemberSummary]:
|
|
|
|
|
dict of membership states, pointing to a MemberSummary named tuple.
|
|
|
|
@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
|
|
|
|
|
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
|
|
|
|
|
|
|
|
|
|
def _get_user_counts_in_room_txn(self, txn, room_id):
|
|
|
|
|
"""
|
|
|
|
|
Get the user count in a room by membership.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
room_id (str)
|
|
|
|
|
membership (Membership)
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[int]
|
|
|
|
|
"""
|
|
|
|
|
sql = """
|
|
|
|
|
SELECT m.membership, count(*) FROM room_memberships as m
|
|
|
|
|
INNER JOIN current_state_events as c USING(event_id)
|
|
|
|
|
WHERE c.type = 'm.room.member' AND c.room_id = ?
|
|
|
|
|
GROUP BY m.membership
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
txn.execute(sql, (room_id,))
|
|
|
|
|
return {row[0]: row[1] for row in txn}
|
|
|
|
|
|
|
|
|
|
@cached()
|
|
|
|
|
def get_invited_rooms_for_local_user(self, user_id):
|
|
|
|
|
""" Get all the rooms the *local* user is invited to
|
|
|
|
|
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
|
|
|
|
|
"""Get all the rooms the *local* user is invited to.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): The user ID.
|
|
|
|
|
user_id: The user ID.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A deferred list of RoomsForUser.
|
|
|
|
|
A awaitable list of RoomsForUser.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
return self.get_rooms_for_local_user_where_membership_is(
|
|
|
|
|
user_id, [Membership.INVITE]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_invite_for_local_user_in_room(self, user_id, room_id):
|
|
|
|
|
"""Gets the invite for the given *local* user and room
|
|
|
|
|
async def get_invite_for_local_user_in_room(
|
|
|
|
|
self, user_id: str, room_id: str
|
|
|
|
|
) -> Optional[RoomsForUser]:
|
|
|
|
|
"""Gets the invite for the given *local* user and room.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str)
|
|
|
|
|
room_id (str)
|
|
|
|
|
user_id: The user ID to find the invite of.
|
|
|
|
|
room_id: The room to user was invited to.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred: Resolves to either a RoomsForUser or None if no invite was
|
|
|
|
|
found.
|
|
|
|
|
Either a RoomsForUser or None if no invite was found.
|
|
|
|
|
"""
|
|
|
|
|
invites = yield self.get_invited_rooms_for_local_user(user_id)
|
|
|
|
|
invites = await self.get_invited_rooms_for_local_user(user_id)
|
|
|
|
|
for invite in invites:
|
|
|
|
|
if invite.room_id == room_id:
|
|
|
|
|
return invite
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
|
|
|
|
|
""" Get all the rooms for this *local* user where the membership for this user
|
|
|
|
|
async def get_rooms_for_local_user_where_membership_is(
|
|
|
|
|
self, user_id: str, membership_list: List[str]
|
|
|
|
|
) -> Optional[List[RoomsForUser]]:
|
|
|
|
|
"""Get all the rooms for this *local* user where the membership for this user
|
|
|
|
|
matches one in the membership list.
|
|
|
|
|
|
|
|
|
|
Filters out forgotten rooms.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): The user ID.
|
|
|
|
|
membership_list (list): A list of synapse.api.constants.Membership
|
|
|
|
|
values which the user must be in.
|
|
|
|
|
user_id: The user ID.
|
|
|
|
|
membership_list: A list of synapse.api.constants.Membership
|
|
|
|
|
values which the user must be in.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[list[RoomsForUser]]
|
|
|
|
|
The RoomsForUser that the user matches the membership types.
|
|
|
|
|
"""
|
|
|
|
|
if not membership_list:
|
|
|
|
|
return defer.succeed(None)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
rooms = yield self.db_pool.runInteraction(
|
|
|
|
|
rooms = await self.db_pool.runInteraction(
|
|
|
|
|
"get_rooms_for_local_user_where_membership_is",
|
|
|
|
|
self._get_rooms_for_local_user_where_membership_is_txn,
|
|
|
|
|
user_id,
|
|
|
|
@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Now we filter out forgotten rooms
|
|
|
|
|
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
|
|
|
|
|
forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
|
|
|
|
|
return [room for room in rooms if room.room_id not in forgotten_rooms]
|
|
|
|
|
|
|
|
|
|
def _get_rooms_for_local_user_where_membership_is_txn(
|
|
|
|
|
self, txn, user_id, membership_list
|
|
|
|
|
):
|
|
|
|
|
self, txn, user_id: str, membership_list: List[str]
|
|
|
|
|
) -> List[RoomsForUser]:
|
|
|
|
|
# Paranoia check.
|
|
|
|
|
if not self.hs.is_mine_id(user_id):
|
|
|
|
|
raise Exception(
|
|
|
|
@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
@cached(max_entries=500000, iterable=True)
|
|
|
|
|
def get_rooms_for_user_with_stream_ordering(self, user_id):
|
|
|
|
|
def get_rooms_for_user_with_stream_ordering(self, user_id: str):
|
|
|
|
|
"""Returns a set of room_ids the user is currently joined to.
|
|
|
|
|
|
|
|
|
|
If a remote user only returns rooms this server is currently
|
|
|
|
|
participating in.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str)
|
|
|
|
|
user_id
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
|
|
|
|
@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
user_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
|
|
|
|
|
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
|
|
|
|
|
# We use `current_state_events` here and not `local_current_membership`
|
|
|
|
|
# as a) this gets called with remote users and b) this only gets called
|
|
|
|
|
# for rooms the server is participating in.
|
|
|
|
@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
_get_users_server_still_shares_room_with_txn,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_rooms_for_user(self, user_id, on_invalidate=None):
|
|
|
|
|
async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
|
|
|
|
|
"""Returns a set of room_ids the user is currently joined to.
|
|
|
|
|
|
|
|
|
|
If a remote user only returns rooms this server is currently
|
|
|
|
|
participating in.
|
|
|
|
|
"""
|
|
|
|
|
rooms = yield self.get_rooms_for_user_with_stream_ordering(
|
|
|
|
|
rooms = await self.get_rooms_for_user_with_stream_ordering(
|
|
|
|
|
user_id, on_invalidate=on_invalidate
|
|
|
|
|
)
|
|
|
|
|
return frozenset(r.room_id for r in rooms)
|
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
|
|
|
|
def get_users_who_share_room_with_user(self, user_id, cache_context):
|
|
|
|
|
@cached(max_entries=500000, cache_context=True, iterable=True)
|
|
|
|
|
async def get_users_who_share_room_with_user(
|
|
|
|
|
self, user_id: str, cache_context: _CacheContext
|
|
|
|
|
) -> Set[str]:
|
|
|
|
|
"""Returns the set of users who share a room with `user_id`
|
|
|
|
|
"""
|
|
|
|
|
room_ids = yield self.get_rooms_for_user(
|
|
|
|
|
room_ids = await self.get_rooms_for_user(
|
|
|
|
|
user_id, on_invalidate=cache_context.invalidate
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
user_who_share_room = set()
|
|
|
|
|
for room_id in room_ids:
|
|
|
|
|
user_ids = yield self.get_users_in_room(
|
|
|
|
|
user_ids = await self.get_users_in_room(
|
|
|
|
|
room_id, on_invalidate=cache_context.invalidate
|
|
|
|
|
)
|
|
|
|
|
user_who_share_room.update(user_ids)
|
|
|
|
|
|
|
|
|
|
return user_who_share_room
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_joined_users_from_context(self, event, context):
|
|
|
|
|
async def get_joined_users_from_context(
|
|
|
|
|
self, event: EventBase, context: EventContext
|
|
|
|
|
):
|
|
|
|
|
state_group = context.state_group
|
|
|
|
|
if not state_group:
|
|
|
|
|
# If state_group is None it means it has yet to be assigned a
|
|
|
|
@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
# To do this we set the state_group to a new object as object() != object()
|
|
|
|
|
state_group = object()
|
|
|
|
|
|
|
|
|
|
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
|
|
|
|
result = yield self._get_joined_users_from_context(
|
|
|
|
|
current_state_ids = await context.get_current_state_ids()
|
|
|
|
|
return await self._get_joined_users_from_context(
|
|
|
|
|
event.room_id, state_group, current_state_ids, event=event, context=context
|
|
|
|
|
)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_joined_users_from_state(self, room_id, state_entry):
|
|
|
|
|
async def get_joined_users_from_state(self, room_id, state_entry):
|
|
|
|
|
state_group = state_entry.state_group
|
|
|
|
|
if not state_group:
|
|
|
|
|
# If state_group is None it means it has yet to be assigned a
|
|
|
|
@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
state_group = object()
|
|
|
|
|
|
|
|
|
|
with Measure(self._clock, "get_joined_users_from_state"):
|
|
|
|
|
return (
|
|
|
|
|
yield self._get_joined_users_from_context(
|
|
|
|
|
room_id, state_group, state_entry.state, context=state_entry
|
|
|
|
|
)
|
|
|
|
|
return await self._get_joined_users_from_context(
|
|
|
|
|
room_id, state_group, state_entry.state, context=state_entry
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks(
|
|
|
|
|
num_args=2, cache_context=True, iterable=True, max_entries=100000
|
|
|
|
|
)
|
|
|
|
|
def _get_joined_users_from_context(
|
|
|
|
|
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
|
|
|
|
|
async def _get_joined_users_from_context(
|
|
|
|
|
self,
|
|
|
|
|
room_id,
|
|
|
|
|
state_group,
|
|
|
|
@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
# We don't use `state_group`, it's there so that we can cache based
|
|
|
|
|
# on it. However, it's important that it's never None, since two current_states
|
|
|
|
|
# with a state_group of None are likely to be different.
|
|
|
|
|
# See bulk_get_push_rules_for_room for how we work around this.
|
|
|
|
|
assert state_group is not None
|
|
|
|
|
|
|
|
|
|
users_in_room = {}
|
|
|
|
@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
missing_member_event_ids.append(event_id)
|
|
|
|
|
|
|
|
|
|
if missing_member_event_ids:
|
|
|
|
|
event_to_memberships = yield self._get_joined_profiles_from_event_ids(
|
|
|
|
|
event_to_memberships = await self._get_joined_profiles_from_event_ids(
|
|
|
|
|
missing_member_event_ids
|
|
|
|
|
)
|
|
|
|
|
users_in_room.update((row for row in event_to_memberships.values() if row))
|
|
|
|
@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
list_name="event_ids",
|
|
|
|
|
inlineCallbacks=True,
|
|
|
|
|
)
|
|
|
|
|
def _get_joined_profiles_from_event_ids(self, event_ids):
|
|
|
|
|
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
|
|
|
|
"""For given set of member event_ids check if they point to a join
|
|
|
|
|
event and if so return the associated user and profile info.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
event_ids (Iterable[str]): The member event IDs to lookup
|
|
|
|
|
event_ids: The member event IDs to lookup
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
|
|
|
|
@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
for row in rows
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks(max_entries=10000)
|
|
|
|
|
def is_host_joined(self, room_id, host):
|
|
|
|
|
@cached(max_entries=10000)
|
|
|
|
|
async def is_host_joined(self, room_id: str, host: str) -> bool:
|
|
|
|
|
if "%" in host or "_" in host:
|
|
|
|
|
raise Exception("Invalid host name")
|
|
|
|
|
|
|
|
|
@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
# the returned user actually has the correct domain.
|
|
|
|
|
like_clause = "%:" + host
|
|
|
|
|
|
|
|
|
|
rows = yield self.db_pool.execute(
|
|
|
|
|
rows = await self.db_pool.execute(
|
|
|
|
|
"is_host_joined", None, sql, room_id, like_clause
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks()
|
|
|
|
|
def was_host_joined(self, room_id, host):
|
|
|
|
|
"""Check whether the server is or ever was in the room.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
room_id (str)
|
|
|
|
|
host (str)
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred: Resolves to True if the host is/was in the room, otherwise
|
|
|
|
|
False.
|
|
|
|
|
"""
|
|
|
|
|
if "%" in host or "_" in host:
|
|
|
|
|
raise Exception("Invalid host name")
|
|
|
|
|
|
|
|
|
|
sql = """
|
|
|
|
|
SELECT user_id FROM room_memberships
|
|
|
|
|
WHERE room_id = ?
|
|
|
|
|
AND user_id LIKE ?
|
|
|
|
|
AND membership = 'join'
|
|
|
|
|
LIMIT 1
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# We do need to be careful to ensure that host doesn't have any wild cards
|
|
|
|
|
# in it, but we checked above for known ones and we'll check below that
|
|
|
|
|
# the returned user actually has the correct domain.
|
|
|
|
|
like_clause = "%:" + host
|
|
|
|
|
|
|
|
|
|
rows = yield self.db_pool.execute(
|
|
|
|
|
"was_host_joined", None, sql, room_id, like_clause
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not rows:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
user_id = rows[0][0]
|
|
|
|
|
if get_domain_from_id(user_id) != host:
|
|
|
|
|
# This can only happen if the host name has something funky in it
|
|
|
|
|
raise Exception("Invalid host name")
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_joined_hosts(self, room_id, state_entry):
|
|
|
|
|
async def get_joined_hosts(self, room_id: str, state_entry):
|
|
|
|
|
state_group = state_entry.state_group
|
|
|
|
|
if not state_group:
|
|
|
|
|
# If state_group is None it means it has yet to be assigned a
|
|
|
|
@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
state_group = object()
|
|
|
|
|
|
|
|
|
|
with Measure(self._clock, "get_joined_hosts"):
|
|
|
|
|
return (
|
|
|
|
|
yield self._get_joined_hosts(
|
|
|
|
|
room_id, state_group, state_entry.state, state_entry=state_entry
|
|
|
|
|
)
|
|
|
|
|
return await self._get_joined_hosts(
|
|
|
|
|
room_id, state_group, state_entry.state, state_entry=state_entry
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
|
|
|
|
|
# @defer.inlineCallbacks
|
|
|
|
|
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
|
|
|
|
|
@cached(num_args=2, max_entries=10000, iterable=True)
|
|
|
|
|
async def _get_joined_hosts(
|
|
|
|
|
self, room_id, state_group, current_state_ids, state_entry
|
|
|
|
|
):
|
|
|
|
|
# We don't use `state_group`, its there so that we can cache based
|
|
|
|
|
# on it. However, its important that its never None, since two current_state's
|
|
|
|
|
# with a state_group of None are likely to be different.
|
|
|
|
|
# See bulk_get_push_rules_for_room for how we work around this.
|
|
|
|
|
assert state_group is not None
|
|
|
|
|
|
|
|
|
|
cache = yield self._get_joined_hosts_cache(room_id)
|
|
|
|
|
joined_hosts = yield cache.get_destinations(state_entry)
|
|
|
|
|
|
|
|
|
|
return joined_hosts
|
|
|
|
|
cache = await self._get_joined_hosts_cache(room_id)
|
|
|
|
|
return await cache.get_destinations(state_entry)
|
|
|
|
|
|
|
|
|
|
@cached(max_entries=10000)
|
|
|
|
|
def _get_joined_hosts_cache(self, room_id):
|
|
|
|
|
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
|
|
|
|
|
return _JoinedHostsCache(self, room_id)
|
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks(num_args=2)
|
|
|
|
|
def did_forget(self, user_id, room_id):
|
|
|
|
|
@cached(num_args=2)
|
|
|
|
|
async def did_forget(self, user_id: str, room_id: str) -> bool:
|
|
|
|
|
"""Returns whether user_id has elected to discard history for room_id.
|
|
|
|
|
|
|
|
|
|
Returns False if they have since re-joined."""
|
|
|
|
@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
rows = txn.fetchall()
|
|
|
|
|
return rows[0][0]
|
|
|
|
|
|
|
|
|
|
count = yield self.db_pool.runInteraction("did_forget_membership", f)
|
|
|
|
|
count = await self.db_pool.runInteraction("did_forget_membership", f)
|
|
|
|
|
return count == 0
|
|
|
|
|
|
|
|
|
|
@cached()
|
|
|
|
|
def get_forgotten_rooms_for_user(self, user_id):
|
|
|
|
|
def get_forgotten_rooms_for_user(self, user_id: str):
|
|
|
|
|
"""Gets all rooms the user has forgotten.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str)
|
|
|
|
|
user_id
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[set[str]]
|
|
|
|
@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|
|
|
|
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_rooms_user_has_been_in(self, user_id):
|
|
|
|
|
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
|
|
|
|
|
"""Get all rooms that the user has ever been in.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str)
|
|
|
|
|
user_id: The user ID to get the rooms of.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[set[str]]: Set of room IDs.
|
|
|
|
|
Set of room IDs.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
room_ids = yield self.db_pool.simple_select_onecol(
|
|
|
|
|
room_ids = await self.db_pool.simple_select_onecol(
|
|
|
|
|
table="room_memberships",
|
|
|
|
|
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
|
|
|
|
|
retcol="room_id",
|
|
|
|
@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
|
|
|
|
where_clause="forgotten = 1",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def _background_add_membership_profile(self, progress, batch_size):
|
|
|
|
|
async def _background_add_membership_profile(self, progress, batch_size):
|
|
|
|
|
target_min_stream_id = progress.get(
|
|
|
|
|
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
|
|
|
|
)
|
|
|
|
@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
|
|
|
|
|
|
|
|
|
return len(rows)
|
|
|
|
|
|
|
|
|
|
result = yield self.db_pool.runInteraction(
|
|
|
|
|
result = await self.db_pool.runInteraction(
|
|
|
|
|
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not result:
|
|
|
|
|
yield self.db_pool.updates._end_background_update(
|
|
|
|
|
await self.db_pool.updates._end_background_update(
|
|
|
|
|
_MEMBERSHIP_PROFILE_UPDATE_NAME
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def _background_current_state_membership(self, progress, batch_size):
|
|
|
|
|
async def _background_current_state_membership(self, progress, batch_size):
|
|
|
|
|
"""Update the new membership column on current_state_events.
|
|
|
|
|
|
|
|
|
|
This works by iterating over all rooms in alphebetical order.
|
|
|
|
@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
|
|
|
|
# string, which will compare before all room IDs correctly.
|
|
|
|
|
last_processed_room = progress.get("last_processed_room", "")
|
|
|
|
|
|
|
|
|
|
row_count, finished = yield self.db_pool.runInteraction(
|
|
|
|
|
row_count, finished = await self.db_pool.runInteraction(
|
|
|
|
|
"_background_current_state_membership_update",
|
|
|
|
|
_background_current_state_membership_txn,
|
|
|
|
|
last_processed_room,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if finished:
|
|
|
|
|
yield self.db_pool.updates._end_background_update(
|
|
|
|
|
await self.db_pool.updates._end_background_update(
|
|
|
|
|
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
|
|
|
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
|
|
|
|
super(RoomMemberStore, self).__init__(database, db_conn, hs)
|
|
|
|
|
|
|
|
|
|
def forget(self, user_id, room_id):
|
|
|
|
|
def forget(self, user_id: str, room_id: str):
|
|
|
|
|
"""Indicate that user_id wishes to discard history for room_id."""
|
|
|
|
|
|
|
|
|
|
def f(txn):
|
|
|
|
@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object):
|
|
|
|
|
|
|
|
|
|
self._len = 0
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_destinations(self, state_entry):
|
|
|
|
|
async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
|
|
|
|
|
"""Get set of destinations for a state entry
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
state_entry(synapse.state._StateCacheEntry)
|
|
|
|
|
state_entry
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
The destinations as a set.
|
|
|
|
|
"""
|
|
|
|
|
if state_entry.state_group == self.state_group:
|
|
|
|
|
return frozenset(self.hosts_to_joined_users)
|
|
|
|
|
|
|
|
|
|
with (yield self.linearizer.queue(())):
|
|
|
|
|
with (await self.linearizer.queue(())):
|
|
|
|
|
if state_entry.state_group == self.state_group:
|
|
|
|
|
pass
|
|
|
|
|
elif state_entry.prev_group == self.state_group:
|
|
|
|
@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object):
|
|
|
|
|
user_id = state_key
|
|
|
|
|
known_joins = self.hosts_to_joined_users.setdefault(host, set())
|
|
|
|
|
|
|
|
|
|
event = yield self.store.get_event(event_id)
|
|
|
|
|
event = await self.store.get_event(event_id)
|
|
|
|
|
if event.membership == Membership.JOIN:
|
|
|
|
|
known_joins.add(user_id)
|
|
|
|
|
else:
|
|
|
|
@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object):
|
|
|
|
|
if not known_joins:
|
|
|
|
|
self.hosts_to_joined_users.pop(host, None)
|
|
|
|
|
else:
|
|
|
|
|
joined_users = yield self.store.get_joined_users_from_state(
|
|
|
|
|
joined_users = await self.store.get_joined_users_from_state(
|
|
|
|
|
self.room_id, state_entry
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|