Merge pull request #1049 from matrix-org/erikj/presence_users_in_room
Use state handler instead of get_users_in_room/get_joined_hostspull/1051/head
commit
55fc17cf4b
|
@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache
|
|||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.types import get_domain_from_id
|
||||
import synapse.metrics
|
||||
|
||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
||||
|
@ -63,6 +64,7 @@ class FederationClient(FederationBase):
|
|||
self._clock.looping_call(
|
||||
self._clear_tried_cache, 60 * 1000,
|
||||
)
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
def _clear_tried_cache(self):
|
||||
"""Clear pdu_destination_tried cache"""
|
||||
|
@ -811,7 +813,8 @@ class FederationClient(FederationBase):
|
|||
if len(signed_events) >= limit:
|
||||
defer.returnValue(signed_events)
|
||||
|
||||
servers = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
servers = set(get_domain_from_id(u) for u in users)
|
||||
|
||||
servers = set(servers)
|
||||
servers.discard(self.server_name)
|
||||
|
|
|
@ -19,7 +19,7 @@ from ._base import BaseHandler
|
|||
|
||||
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.types import RoomAlias, UserID
|
||||
from synapse.types import RoomAlias, UserID, get_domain_from_id
|
||||
|
||||
import logging
|
||||
import string
|
||||
|
@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler):
|
|||
# TODO(erikj): Add transactions.
|
||||
# TODO(erikj): Check if there is a current association.
|
||||
if not servers:
|
||||
servers = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
servers = set(get_domain_from_id(u) for u in users)
|
||||
|
||||
if not servers:
|
||||
raise SynapseError(400, "Failed to get server list")
|
||||
|
@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler):
|
|||
Codes.NOT_FOUND
|
||||
)
|
||||
|
||||
extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
extra_servers = set(get_domain_from_id(u) for u in users)
|
||||
servers = set(extra_servers) | set(servers)
|
||||
|
||||
# If this server is in the list of servers, return it first.
|
||||
|
|
|
@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
self.notifier = hs.get_notifier()
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler):
|
|||
# Send down presence.
|
||||
if event.state_key == auth_user_id:
|
||||
# Send down presence for everyone in the room.
|
||||
users = yield self.store.get_users_in_room(event.room_id)
|
||||
users = yield self.state.get_current_user_in_room(event.room_id)
|
||||
states = yield presence_handler.get_states(
|
||||
users,
|
||||
as_event=True,
|
||||
|
|
|
@ -88,6 +88,8 @@ class PresenceHandler(object):
|
|||
self.notifier = hs.get_notifier()
|
||||
self.federation = hs.get_replication_layer()
|
||||
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
self.federation.register_edu_handler(
|
||||
"m.presence", self.incoming_presence
|
||||
)
|
||||
|
@ -532,7 +534,9 @@ class PresenceHandler(object):
|
|||
if not local_states:
|
||||
continue
|
||||
|
||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
hosts = set(get_domain_from_id(u) for u in users)
|
||||
|
||||
for host in hosts:
|
||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||
|
||||
|
@ -725,13 +729,13 @@ class PresenceHandler(object):
|
|||
# don't need to send to local clients here, as that is done as part
|
||||
# of the event stream/sync.
|
||||
# TODO: Only send to servers not already in the room.
|
||||
user_ids = yield self.state.get_current_user_in_room(room_id)
|
||||
if self.is_mine(user):
|
||||
state = yield self.current_state_for_user(user.to_string())
|
||||
|
||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
hosts = set(get_domain_from_id(u) for u in user_ids)
|
||||
self._push_to_remotes({host: (state,) for host in hosts})
|
||||
else:
|
||||
user_ids = yield self.store.get_users_in_room(room_id)
|
||||
user_ids = filter(self.is_mine_id, user_ids)
|
||||
|
||||
states = yield self.current_state_for_users(user_ids)
|
||||
|
@ -955,6 +959,7 @@ class PresenceEventSource(object):
|
|||
self.get_presence_handler = hs.get_presence_handler
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -1017,7 +1022,7 @@ class PresenceEventSource(object):
|
|||
|
||||
user_ids_to_check = set()
|
||||
for room_id in room_ids:
|
||||
users = yield self.store.get_users_in_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
user_ids_to_check.update(users)
|
||||
|
||||
user_ids_to_check.update(friends)
|
||||
|
|
|
@ -18,6 +18,7 @@ from ._base import BaseHandler
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler):
|
|||
"m.receipt", self._received_remote_receipt
|
||||
)
|
||||
self.clock = self.hs.get_clock()
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def received_client_receipt(self, room_id, receipt_type, user_id,
|
||||
|
@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler):
|
|||
event_ids = receipt["event_ids"]
|
||||
data = receipt["data"]
|
||||
|
||||
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
remotedomains = set(get_domain_from_id(u) for u in users)
|
||||
remotedomains = remotedomains.copy()
|
||||
remotedomains.discard(self.server_name)
|
||||
|
||||
|
|
|
@ -142,6 +142,7 @@ class SyncHandler(object):
|
|||
self.event_sources = hs.get_event_sources()
|
||||
self.clock = hs.get_clock()
|
||||
self.response_cache = ResponseCache(hs)
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
|
||||
full_state=False):
|
||||
|
@ -670,7 +671,7 @@ class SyncHandler(object):
|
|||
|
||||
extra_users_ids = set(newly_joined_users)
|
||||
for room_id in newly_joined_rooms:
|
||||
users = yield self.store.get_users_in_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
extra_users_ids.update(users)
|
||||
extra_users_ids.discard(user.to_string())
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from synapse.util.logcontext import (
|
|||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||
)
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.types import UserID
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -42,6 +42,7 @@ class TypingHandler(object):
|
|||
self.auth = hs.get_auth()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.notifier = hs.get_notifier()
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
|
@ -166,7 +167,8 @@ class TypingHandler(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _push_update(self, room_id, user_id, typing):
|
||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
domains = set(get_domain_from_id(u) for u in users)
|
||||
|
||||
deferreds = []
|
||||
for domain in domains:
|
||||
|
@ -199,7 +201,8 @@ class TypingHandler(object):
|
|||
# Check that the string is a valid user id
|
||||
UserID.from_string(user_id)
|
||||
|
||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
domains = set(get_domain_from_id(u) for u in users)
|
||||
|
||||
if self.server_name in domains:
|
||||
self._push_update_local(
|
||||
|
|
|
@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
|
|||
)
|
||||
|
||||
room_members = yield self.store.get_joined_users_from_context(
|
||||
event.room_id, context,
|
||||
event.room_id, context.state_group, context.current_state_ids
|
||||
)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||
|
|
|
@ -123,6 +123,11 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
|
||||
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
|
||||
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
|
||||
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
|
||||
_get_joined_users_from_context = (
|
||||
RoomMemberStore.__dict__["_get_joined_users_from_context"]
|
||||
)
|
||||
|
||||
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
|
||||
get_room_events_stream_for_rooms = (
|
||||
DataStore.get_room_events_stream_for_rooms.__func__
|
||||
|
@ -216,7 +221,6 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
self._get_current_state_for_key.invalidate_all()
|
||||
self.get_rooms_for_user.invalidate_all()
|
||||
self.get_users_in_room.invalidate((event.room_id,))
|
||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
||||
|
||||
self._invalidate_get_event_cache(event.event_id)
|
||||
|
||||
|
@ -240,7 +244,6 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
|
||||
if event.type == EventTypes.Member:
|
||||
self.get_rooms_for_user.invalidate((event.state_key,))
|
||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
||||
self.get_users_in_room.invalidate((event.room_id,))
|
||||
self._membership_stream_cache.entity_has_changed(
|
||||
event.state_key, event.internal_metadata.stream_ordering
|
||||
|
|
|
@ -124,6 +124,15 @@ class StateHandler(object):
|
|||
|
||||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_user_in_room(self, room_id):
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
joined_users = yield self.store.get_joined_users_from_context(
|
||||
room_id, group, state_ids
|
||||
)
|
||||
defer.returnValue(joined_users)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def compute_event_context(self, event, old_state=None):
|
||||
""" Fills out the context with the `current state` of the graph. The
|
||||
|
|
|
@ -393,7 +393,6 @@ class EventsStore(SQLBaseStore):
|
|||
txn.call_after(self._get_current_state_for_key.invalidate_all)
|
||||
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||
|
||||
# Add an entry to the current_state_resets table to record the point
|
||||
# where we clobbered the current state
|
||||
|
|
|
@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore):
|
|||
|
||||
for event in events:
|
||||
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||
txn.call_after(
|
||||
self._membership_stream_cache.entity_has_changed,
|
||||
|
@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
@cachedInlineCallbacks(max_entries=5000)
|
||||
def get_joined_hosts_for_room(self, room_id):
|
||||
user_ids = yield self.get_users_in_room(room_id)
|
||||
defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
|
||||
|
||||
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
||||
where_clause = "c.room_id = ?"
|
||||
where_values = [room_id]
|
||||
|
@ -360,8 +354,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||
desc="who_forgot"
|
||||
)
|
||||
|
||||
def get_joined_users_from_context(self, room_id, context):
|
||||
state_group = context.state_group
|
||||
def get_joined_users_from_context(self, room_id, state_group, state_ids):
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
|
@ -370,7 +363,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||
state_group = object()
|
||||
|
||||
return self._get_joined_users_from_context(
|
||||
room_id, state_group, context.current_state_ids
|
||||
room_id, state_group, state_ids
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||
|
|
|
@ -62,6 +62,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
self.on_new_event = mock_notifier.on_new_event
|
||||
|
||||
self.auth = Mock(spec=[])
|
||||
self.state_handler = Mock()
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
"test",
|
||||
|
@ -75,6 +76,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
]),
|
||||
state_handler=self.state_handler,
|
||||
handlers=None,
|
||||
notifier=mock_notifier,
|
||||
resource_for_client=Mock(),
|
||||
|
@ -113,6 +115,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
return set(member.domain for member in self.room_members)
|
||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||
|
||||
def get_current_user_in_room(room_id):
|
||||
return set(str(u) for u in self.room_members)
|
||||
self.state_handler.get_current_user_in_room = get_current_user_in_room
|
||||
|
||||
self.auth.check_joined_room = check_joined_room
|
||||
|
||||
# Some local users to test with
|
||||
|
|
|
@ -78,44 +78,3 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
|||
)
|
||||
)]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_room_hosts(self):
|
||||
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
||||
|
||||
self.assertEquals(
|
||||
{"test"},
|
||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
||||
)
|
||||
|
||||
# Should still have just one host after second join from it
|
||||
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
|
||||
|
||||
self.assertEquals(
|
||||
{"test"},
|
||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
||||
)
|
||||
|
||||
# Should now have two hosts after join from other host
|
||||
yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
|
||||
|
||||
self.assertEquals(
|
||||
{"test", "elsewhere"},
|
||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
||||
)
|
||||
|
||||
# Should still have both hosts
|
||||
yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
|
||||
|
||||
self.assertEquals(
|
||||
{"test", "elsewhere"},
|
||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
||||
)
|
||||
|
||||
# Should have only one host after other leaves
|
||||
yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
|
||||
|
||||
self.assertEquals(
|
||||
{"test"},
|
||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue