Refactor state group lookup to reduce DB hits (#4011)

Currently when fetching state groups from the data store we make two
hits two the database: once for members and once for non-members (unless
request is filtered to one or the other). This adds needless load to the
datbase, so this PR refactors the lookup to make only a single database
hit.
pull/4106/head
Erik Johnston 2018-10-25 17:49:55 +01:00 committed by GitHub
parent e5da60d75d
commit cb53ce9d64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 717 additions and 488 deletions

1
changelog.d/4011.misc Normal file
View File

@ -0,0 +1 @@
Reduce database load when fetching state groups

View File

@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler):
room_end_token = "s%d" % (event.stream_ordering,) room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.store.get_state_for_events, self.store.get_state_for_events,
[event.event_id], None, [event.event_id],
) )
deferred_room_state.addCallback( deferred_room_state.addCallback(
lambda states: states[event.event_id] lambda states: states[event.event_id]
@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted(self, user_id, room_id, pagin_config, def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking): membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event_id], None [member_event_id],
) )
room_state = room_state[member_event_id] room_state = room_state[member_event_id]

View File

@ -35,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID from synapse.types import RoomAlias, UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
@ -80,7 +81,7 @@ class MessageHandler(object):
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[membership_event_id], [key] [membership_event_id], StateFilter.from_types([key])
) )
data = room_state[membership_event_id].get(key) data = room_state[membership_event_id].get(key)
@ -88,7 +89,7 @@ class MessageHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events( def get_state_events(
self, user_id, room_id, types=None, filtered_types=None, self, user_id, room_id, state_filter=StateFilter.all(),
at_token=None, is_guest=False, at_token=None, is_guest=False,
): ):
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
@ -100,13 +101,8 @@ class MessageHandler(object):
Args: Args:
user_id(str): The user requesting state events. user_id(str): The user requesting state events.
room_id(str): The room ID to get all state events from. room_id(str): The room ID to get all state events from.
types(list[(str, str|None)]|None): List of (type, state_key) tuples state_filter (StateFilter): The state filter used to fetch state
which are used to filter the state fetched. If `state_key` is None, from the database.
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
at_token(StreamToken|None): the stream token of the at which we are requesting at_token(StreamToken|None): the stream token of the at which we are requesting
the stats. If the user is not allowed to view the state as of that the stats. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current stream token, we raise a 403 SynapseError. If None, returns the current
@ -139,7 +135,7 @@ class MessageHandler(object):
event = last_events[0] event = last_events[0]
if visible_events: if visible_events:
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[event.event_id], types, filtered_types=filtered_types, [event.event_id], state_filter=state_filter,
) )
room_state = room_state[event.event_id] room_state = room_state[event.event_id]
else: else:
@ -158,12 +154,12 @@ class MessageHandler(object):
if membership == Membership.JOIN: if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids( state_ids = yield self.store.get_filtered_current_state_ids(
room_id, types, filtered_types=filtered_types, room_id, state_filter=state_filter,
) )
room_state = yield self.store.get_events(state_ids.values()) room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[membership_event_id], types, filtered_types=filtered_types, [membership_event_id], state_filter=state_filter,
) )
room_state = room_state[membership_event_id] room_state = room_state[membership_event_id]

View File

@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock from synapse.util.async_helpers import ReadWriteLock
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
@ -255,16 +256,14 @@ class PaginationHandler(object):
if event_filter and event_filter.lazy_load_members(): if event_filter and event_filter.lazy_load_members():
# TODO: remove redundant members # TODO: remove redundant members
types = [ # FIXME: we also care about invite targets etc.
(EventTypes.Member, state_key) state_filter = StateFilter.from_types(
for state_key in set( (EventTypes.Member, event.sender)
event.sender # FIXME: we also care about invite targets etc.
for event in events for event in events
) )
]
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
events[0].event_id, types=types, events[0].event_id, state_filter=state_filter,
) )
if state_ids: if state_ids:

View File

@ -33,6 +33,7 @@ from synapse.api.constants import (
RoomCreationPreset, RoomCreationPreset,
) )
from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils from synapse.util import stringutils
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -489,23 +490,24 @@ class RoomContextHandler(object):
else: else:
last_event_id = event_id last_event_id = event_id
types = None
filtered_types = None
if event_filter and event_filter.lazy_load_members(): if event_filter and event_filter.lazy_load_members():
members = set(ev.sender for ev in itertools.chain( state_filter = StateFilter.from_lazy_load_member_list(
ev.sender
for ev in itertools.chain(
results["events_before"], results["events_before"],
(results["event"],), (results["event"],),
results["events_after"], results["events_after"],
)) )
filtered_types = [EventTypes.Member] )
types = [(EventTypes.Member, member) for member in members] else:
state_filter = StateFilter.all()
# XXX: why do we return the state as of the last event rather than the # XXX: why do we return the state as of the last event rather than the
# first? Shouldn't we be consistent with /sync? # first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687 # https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events( state = yield self.store.get_state_for_events(
[last_event_id], types, filtered_types=filtered_types, [last_event_id], state_filter=state_filter,
) )
results["state"] = list(state[last_event_id].values()) results["state"] = list(state[last_event_id].values())

View File

@ -27,6 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -469,25 +470,20 @@ class SyncHandler(object):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_after_event(self, event, types=None, filtered_types=None): def get_state_after_event(self, event, state_filter=StateFilter.all()):
""" """
Get the room state after the given event Get the room state after the given event
Args: Args:
event(synapse.events.EventBase): event of interest event(synapse.events.EventBase): event of interest
types(list[(str, str|None)]|None): List of (type, state_key) tuples state_filter (StateFilter): The state filter used to fetch state
which are used to filter the state fetched. If `state_key` is None, from the database.
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
event.event_id, types, filtered_types=filtered_types, event.event_id, state_filter=state_filter,
) )
if event.is_state(): if event.is_state():
state_ids = state_ids.copy() state_ids = state_ids.copy()
@ -495,18 +491,14 @@ class SyncHandler(object):
defer.returnValue(state_ids) defer.returnValue(state_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at(self, room_id, stream_position, types=None, filtered_types=None): def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
""" Get the room state at a particular stream position """ Get the room state at a particular stream position
Args: Args:
room_id(str): room for which to get state room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state stream_position(StreamToken): point at which to get state
types(list[(str, str|None)]|None): List of (type, state_key) tuples state_filter (StateFilter): The state filter used to fetch state
which are used to filter the state fetched. If `state_key` is None, from the database.
all events are returned of the given type.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
@ -522,7 +514,7 @@ class SyncHandler(object):
if last_events: if last_events:
last_event = last_events[-1] last_event = last_events[-1]
state = yield self.get_state_after_event( state = yield self.get_state_after_event(
last_event, types, filtered_types=filtered_types, last_event, state_filter=state_filter,
) )
else: else:
@ -563,10 +555,11 @@ class SyncHandler(object):
last_event = last_events[-1] last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
last_event.event_id, [ last_event.event_id,
state_filter=StateFilter.from_types([
(EventTypes.Name, ''), (EventTypes.Name, ''),
(EventTypes.CanonicalAlias, ''), (EventTypes.CanonicalAlias, ''),
] ]),
) )
# this is heavily cached, thus: fast. # this is heavily cached, thus: fast.
@ -717,8 +710,7 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
types = None members_to_fetch = None
filtered_types = None
lazy_load_members = sync_config.filter_collection.lazy_load_members() lazy_load_members = sync_config.filter_collection.lazy_load_members()
include_redundant_members = ( include_redundant_members = (
@ -729,16 +721,21 @@ class SyncHandler(object):
# We only request state for the members needed to display the # We only request state for the members needed to display the
# timeline: # timeline:
types = [ members_to_fetch = set(
(EventTypes.Member, state_key)
for state_key in set(
event.sender # FIXME: we also care about invite targets etc. event.sender # FIXME: we also care about invite targets etc.
for event in batch.events for event in batch.events
) )
]
# only apply the filtering to room members if full_state:
filtered_types = [EventTypes.Member] # always make sure we LL ourselves so we know we're in the room
# (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
# We only need apply this on full state syncs given we disabled
# LL for incr syncs in #3840.
members_to_fetch.add(sync_config.user.to_string())
state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
else:
state_filter = StateFilter.all()
timeline_state = { timeline_state = {
(event.type, event.state_key): event.event_id (event.type, event.state_key): event.event_id
@ -746,28 +743,19 @@ class SyncHandler(object):
} }
if full_state: if full_state:
if lazy_load_members:
# always make sure we LL ourselves so we know we're in the room
# (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
# We only need apply this on full state syncs given we disabled
# LL for incr syncs in #3840.
types.append((EventTypes.Member, sync_config.user.to_string()))
if batch: if batch:
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id, types=types, batch.events[-1].event_id, state_filter=state_filter,
filtered_types=filtered_types,
) )
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types, batch.events[0].event_id, state_filter=state_filter,
filtered_types=filtered_types,
) )
else: else:
current_state_ids = yield self.get_state_at( current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token, types=types, room_id, stream_position=now_token,
filtered_types=filtered_types, state_filter=state_filter,
) )
state_ids = current_state_ids state_ids = current_state_ids
@ -781,8 +769,7 @@ class SyncHandler(object):
) )
elif batch.limited: elif batch.limited:
state_at_timeline_start = yield self.store.get_state_ids_for_event( state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types, batch.events[0].event_id, state_filter=state_filter,
filtered_types=filtered_types,
) )
# for now, we disable LL for gappy syncs - see # for now, we disable LL for gappy syncs - see
@ -797,17 +784,15 @@ class SyncHandler(object):
# members to just be ones which were timeline senders, which then ensures # members to just be ones which were timeline senders, which then ensures
# all of the rest get included in the state block (if we need to know # all of the rest get included in the state block (if we need to know
# about them). # about them).
types = None state_filter = StateFilter.all()
filtered_types = None
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token, types=types, room_id, stream_position=since_token,
filtered_types=filtered_types, state_filter=state_filter,
) )
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id, types=types, batch.events[-1].event_id, state_filter=state_filter,
filtered_types=filtered_types,
) )
state_ids = _calculate_state( state_ids = _calculate_state(
@ -821,7 +806,7 @@ class SyncHandler(object):
else: else:
state_ids = {} state_ids = {}
if lazy_load_members: if lazy_load_members:
if types and batch.events: if members_to_fetch and batch.events:
# We're returning an incremental sync, with no # We're returning an incremental sync, with no
# "gap" since the previous sync, so normally there would be # "gap" since the previous sync, so normally there would be
# no state to return. # no state to return.
@ -831,8 +816,12 @@ class SyncHandler(object):
# timeline here, and then dedupe any redundant ones below. # timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types, batch.events[0].event_id,
filtered_types=None, # we only want members! # we only want members!
state_filter=StateFilter.from_types(
(EventTypes.Member, member)
for member in members_to_fetch
),
) )
if lazy_load_members and not include_redundant_members: if lazy_load_members and not include_redundant_members:

View File

@ -33,6 +33,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
@ -409,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
at_token=at_token, at_token=at_token,
types=[(EventTypes.Member, None)], state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
) )
chunk = [] chunk = []

View File

@ -2089,7 +2089,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
for sg in remaining_state_groups: for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg) logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn( curr_state = self._get_state_groups_from_groups_txn(
txn, [sg], types=None txn, [sg],
) )
curr_state = curr_state[sg] curr_state = curr_state[sg]

File diff suppressed because it is too large Load Diff

View File

@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
) )
event_id_to_state = yield store.get_state_for_events( event_id_to_state = yield store.get_state_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=types, state_filter=StateFilter.from_types(types),
) )
ignore_dict_content = yield store.get_global_account_data_by_type_for_user( ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events):
# need to check membership (as we know the server is in the room). # need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events( event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=( state_filter=StateFilter.from_types(
(EventTypes.RoomHistoryVisibility, ""), types=((EventTypes.RoomHistoryVisibility, ""),),
) )
) )
@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events):
# of the history vis and membership state at those events. # of the history vis and membership state at those events.
event_to_state_ids = yield store.get_state_ids_for_events( event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=( types=(
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None), (EventTypes.Member, None),
),
) )
) )

View File

@ -18,6 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
import tests.unittest import tests.unittest
@ -148,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we get the full state as of the final event # check we get the full state as of the final event
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, None, filtered_types=None e5.event_id,
) )
self.assertIsNotNone(e4) self.assertIsNotNone(e4)
@ -166,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can filter to the m.room.name event (with a '' state key) # check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Name, '')], filtered_types=None e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
) )
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key) # check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Name, None)], filtered_types=None e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
) )
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key) # check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Member, None)], filtered_types=None e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
) )
self.assertStateMapEqual( self.assertStateMapEqual(
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
) )
# check we can use filtered_types to grab a specific room member # check we can grab a specific room member without filtering out the
# without filtering out the other event types # other event types
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, e5.event_id,
[(EventTypes.Member, self.u_alice.to_string())], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
)
) )
self.assertStateMapEqual( self.assertStateMapEqual(
@ -204,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase):
state, state,
) )
# check that types=[], filtered_types=[EventTypes.Member] # check that we can grab everything except members
# doesn't return all members
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [], filtered_types=[EventTypes.Member] e5.event_id, state_filter=StateFilter(
types={EventTypes.Member: set()},
include_others=True,
),
) )
self.assertStateMapEqual( self.assertStateMapEqual(
@ -215,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
####################################################### #######################################################
# _get_some_state_from_cache tests against a full cache # _get_state_for_group_using_cache tests against a full cache
####################################################### #######################################################
room_id = self.room.to_string() room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
group = list(group_ids.keys())[0] group = list(group_ids.keys())[0]
# test _get_some_state_from_cache correctly filters out members with types=[] # test _get_state_for_group_using_cache correctly filters out members
(state_dict, is_all) = yield self.store._get_some_state_from_cache( # with types=[]
self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache, group,
state_filter=StateFilter(
types={EventTypes.Member: set()},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -236,22 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, group,
[], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: set()},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({}, state_dict) self.assertDictEqual({}, state_dict)
# test _get_some_state_from_cache correctly filters in members with wildcard types # test _get_state_for_group_using_cache correctly filters in members
(state_dict, is_all) = yield self.store._get_some_state_from_cache( # with wildcard types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, group,
[(EventTypes.Member, None)], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: None},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -263,11 +278,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, group,
[(EventTypes.Member, None)], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: None},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -280,12 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_state_for_group_using_cache correctly filters in members
(state_dict, is_all) = yield self.store._get_some_state_from_cache( # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, group,
[(EventTypes.Member, e5.state_key)], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: {e5.state_key}},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -297,23 +317,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, group,
[(EventTypes.Member, e5.state_key)], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: {e5.state_key}},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_state_for_group_using_cache correctly filters in members
# and no filtered_types # with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, group,
[(EventTypes.Member, e5.state_key)], state_filter=StateFilter(
filtered_types=None, types={EventTypes.Member: {e5.state_key}},
include_others=False,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -357,42 +381,54 @@ class StateStoreTestCase(tests.unittest.TestCase):
############################################ ############################################
# test that things work with a partial cache # test that things work with a partial cache
# test _get_some_state_from_cache correctly filters out members with types=[] # test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] self.store._state_group_cache, group,
state_filter=StateFilter(
types={EventTypes.Member: set()},
include_others=True,
),
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, group,
[], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: set()},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({}, state_dict) self.assertDictEqual({}, state_dict)
# test _get_some_state_from_cache correctly filters in members wildcard types # test _get_state_for_group_using_cache correctly filters in members
(state_dict, is_all) = yield self.store._get_some_state_from_cache( # wildcard types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, group,
[(EventTypes.Member, None)], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: None},
include_others=True,
),
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, group,
[(EventTypes.Member, None)], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: None},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
@ -404,44 +440,53 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_state_for_group_using_cache correctly filters in members
(state_dict, is_all) = yield self.store._get_some_state_from_cache( # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, group,
[(EventTypes.Member, e5.state_key)], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: {e5.state_key}},
include_others=True,
),
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, group,
[(EventTypes.Member, e5.state_key)], state_filter=StateFilter(
filtered_types=[EventTypes.Member], types={EventTypes.Member: {e5.state_key}},
include_others=True,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_state_for_group_using_cache correctly filters in members
# and no filtered_types # with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, group,
[(EventTypes.Member, e5.state_key)], state_filter=StateFilter(
filtered_types=None, types={EventTypes.Member: {e5.state_key}},
include_others=False,
),
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict) self.assertDictEqual({}, state_dict)
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, group,
[(EventTypes.Member, e5.state_key)], state_filter=StateFilter(
filtered_types=None, types={EventTypes.Member: {e5.state_key}},
include_others=False,
),
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)