Merge pull request #2970 from matrix-org/matthew/filter_members

Implement the lazy_load_members room state filter parameter
pull/3610/head
Matthew Hodgson 2018-07-26 00:03:01 +01:00 committed by GitHub
commit 1bcd0490c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 532 additions and 62 deletions

1
changelog.d/2970.feature Normal file
View File

@ -0,0 +1 @@
add support for the lazy_loaded_members filter as per MSC1227

View File

@ -113,7 +113,10 @@ ROOM_EVENT_FILTER_SCHEMA = {
}, },
"contains_url": { "contains_url": {
"type": "boolean" "type": "boolean"
} },
"lazy_load_members": {
"type": "boolean"
},
} }
} }
@ -261,6 +264,9 @@ class FilterCollection(object):
def ephemeral_limit(self): def ephemeral_limit(self):
return self._room_ephemeral_filter.limit() return self._room_ephemeral_filter.limit()
def lazy_load_members(self):
return self._room_state_filter.lazy_load_members()
def filter_presence(self, events): def filter_presence(self, events):
return self._presence_filter.filter(events) return self._presence_filter.filter(events)
@ -417,6 +423,9 @@ class Filter(object):
def limit(self): def limit(self):
return self.filter_json.get("limit", 10) return self.filter_json.get("limit", 10)
def lazy_load_members(self):
return self.filter_json.get("lazy_load_members", False)
def _matches_wildcard(actual_value, filter_value): def _matches_wildcard(actual_value, filter_value):
if filter_value.endswith("*"): if filter_value.endswith("*"):

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 - 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -416,29 +417,44 @@ class SyncHandler(object):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_after_event(self, event): def get_state_after_event(self, event, types=None, filtered_types=None):
""" """
Get the room state after the given event Get the room state after the given event
Args: Args:
event(synapse.events.EventBase): event of interest event(synapse.events.EventBase): event of interest
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state_ids = yield self.store.get_state_ids_for_event(event.event_id) state_ids = yield self.store.get_state_ids_for_event(
event.event_id, types, filtered_types=filtered_types,
)
if event.is_state(): if event.is_state():
state_ids = state_ids.copy() state_ids = state_ids.copy()
state_ids[(event.type, event.state_key)] = event.event_id state_ids[(event.type, event.state_key)] = event.event_id
defer.returnValue(state_ids) defer.returnValue(state_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at(self, room_id, stream_position): def get_state_at(self, room_id, stream_position, types=None, filtered_types=None):
""" Get the room state at a particular stream position """ Get the room state at a particular stream position
Args: Args:
room_id(str): room for which to get state room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state stream_position(StreamToken): point at which to get state
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
@ -453,7 +469,9 @@ class SyncHandler(object):
if last_events: if last_events:
last_event = last_events[-1] last_event = last_events[-1]
state = yield self.get_state_after_event(last_event) state = yield self.get_state_after_event(
last_event, types, filtered_types=filtered_types,
)
else: else:
# no events in this room - so presumably no state # no events in this room - so presumably no state
@ -485,18 +503,42 @@ class SyncHandler(object):
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
types = None
lazy_load_members = sync_config.filter_collection.lazy_load_members()
filtered_types = None
if lazy_load_members:
# We only request state for the members needed to display the
# timeline:
types = [
(EventTypes.Member, state_key)
for state_key in set(
event.sender # FIXME: we also care about invite targets etc.
for event in batch.events
)
]
# only apply the filtering to room members
filtered_types = [EventTypes.Member]
if full_state: if full_state:
if batch: if batch:
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id batch.events[-1].event_id, types=types,
filtered_types=filtered_types,
) )
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id batch.events[0].event_id, types=types,
filtered_types=filtered_types,
) )
else: else:
current_state_ids = yield self.get_state_at( current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token room_id, stream_position=now_token, types=types,
filtered_types=filtered_types,
) )
state_ids = current_state_ids state_ids = current_state_ids
@ -511,18 +553,22 @@ class SyncHandler(object):
timeline_start=state_ids, timeline_start=state_ids,
previous={}, previous={},
current=current_state_ids, current=current_state_ids,
lazy_load_members=lazy_load_members,
) )
elif batch.limited: elif batch.limited:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token room_id, stream_position=since_token, types=types,
filtered_types=filtered_types,
) )
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id batch.events[-1].event_id, types=types,
filtered_types=filtered_types,
) )
state_at_timeline_start = yield self.store.get_state_ids_for_event( state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id batch.events[0].event_id, types=types,
filtered_types=filtered_types,
) )
timeline_state = { timeline_state = {
@ -530,14 +576,35 @@ class SyncHandler(object):
for event in batch.events if event.is_state() for event in batch.events if event.is_state()
} }
# TODO: optionally filter out redundant membership events at this
# point, to stop repeatedly sending members in every /sync as if
# the client isn't tracking them.
# When implemented, this should filter using event_ids (not mxids).
# In practice, limited syncs are
# relatively rare so it's not a total disaster to send redundant
# members down at this point. Redundant members are ones which
# repeatedly get sent down /sync because we don't know if the client
# is caching them or not.
state_ids = _calculate_state( state_ids = _calculate_state(
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state_at_timeline_start, timeline_start=state_at_timeline_start,
previous=state_at_previous_sync, previous=state_at_previous_sync,
current=current_state_ids, current=current_state_ids,
lazy_load_members=lazy_load_members,
) )
else: else:
state_ids = {} state_ids = {}
if lazy_load_members:
# TODO: filter out redundant members based on their mxids (not their
# event_ids) at this point. We know we can do it based on mxid as this
# is an non-gappy incremental sync.
if types:
state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types,
filtered_types=filtered_types,
)
state = {} state = {}
if state_ids: if state_ids:
@ -1448,7 +1515,9 @@ def _action_has_highlight(actions):
return False return False
def _calculate_state(timeline_contains, timeline_start, previous, current): def _calculate_state(
timeline_contains, timeline_start, previous, current, lazy_load_members,
):
"""Works out what state to include in a sync response. """Works out what state to include in a sync response.
Args: Args:
@ -1457,6 +1526,9 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
previous (dict): state at the end of the previous sync (or empty dict previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync) if this is an initial sync)
current (dict): state at the end of the timeline current (dict): state at the end of the timeline
lazy_load_members (bool): whether to return members from timeline_start
or not. assumes that timeline_start has already been filtered to
include only the members the client needs to know about.
Returns: Returns:
dict dict
@ -1472,9 +1544,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
} }
c_ids = set(e for e in current.values()) c_ids = set(e for e in current.values())
tc_ids = set(e for e in timeline_contains.values())
p_ids = set(e for e in previous.values())
ts_ids = set(e for e in timeline_start.values()) ts_ids = set(e for e in timeline_start.values())
p_ids = set(e for e in previous.values())
tc_ids = set(e for e in timeline_contains.values())
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
# as we may not have sent them to the client before. We find these membership
# events by filtering them out of timeline_start, which has already been filtered
# to only include membership events for the senders in the timeline.
# In practice, we can do this by removing them from the p_ids list,
# which is the list of relevant state we know we have already sent to the client.
# see https://github.com/matrix-org/synapse/pull/2970
# /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
if lazy_load_members:
p_ids.difference_update(
e for t, e in timeline_start.iteritems()
if t[0] == EventTypes.Member
)
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids

View File

@ -186,7 +186,17 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types): def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id) """Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups(list[int]): list of state group IDs to query
types (Iterable[str, str|None]|None): list of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. If None, all types are returned.
Returns:
dictionary state_group -> (dict of (type, state_key) -> event id)
""" """
results = {} results = {}
@ -200,8 +210,11 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def _get_state_groups_from_groups_txn(self, txn, groups, types=None): def _get_state_groups_from_groups_txn(
self, txn, groups, types=None,
):
results = {group: {} for group in groups} results = {group: {} for group in groups}
if types is not None: if types is not None:
types = list(set(types)) # deduplicate types list types = list(set(types)) # deduplicate types list
@ -239,7 +252,7 @@ class StateGroupWorkerStore(SQLBaseStore):
# Turns out that postgres doesn't like doing a list of OR's and # Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific # is about 1000x slower, so we just issue a query for each specific
# type seperately. # type seperately.
if types: if types is not None:
clause_to_args = [ clause_to_args = [
( (
"AND type = ? AND state_key = ?", "AND type = ? AND state_key = ?",
@ -278,6 +291,7 @@ class StateGroupWorkerStore(SQLBaseStore):
else: else:
where_clauses.append("(type = ? AND state_key = ?)") where_clauses.append("(type = ? AND state_key = ?)")
where_args.extend([typ[0], typ[1]]) where_args.extend([typ[0], typ[1]])
where_clause = "AND (%s)" % (" OR ".join(where_clauses)) where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else: else:
where_clause = "" where_clause = ""
@ -332,16 +346,20 @@ class StateGroupWorkerStore(SQLBaseStore):
return results return results
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_events(self, event_ids, types): def get_state_for_events(self, event_ids, types, filtered_types=None):
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. The state dicts will only have the type/state_keys dicts for each event. The state dicts will only have the type/state_keys
that are in the `types` list. that are in the `types` list.
Args: Args:
event_ids (list) event_ids (list[string])
types (list): List of (type, state_key) tuples which are used to types (list[(str, str|None)]|None): List of (type, state_key) tuples
filter the state fetched. `state_key` may be None, which matches which are used to filter the state fetched. If `state_key` is None,
any `state_key` all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
deferred: A list of dicts corresponding to the event_ids given. deferred: A list of dicts corresponding to the event_ids given.
@ -352,7 +370,7 @@ class StateGroupWorkerStore(SQLBaseStore):
) )
groups = set(itervalues(event_to_groups)) groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
state_event_map = yield self.get_events( state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@ -371,15 +389,19 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types=None): def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
""" """
Get the state dicts corresponding to a list of events Get the state dicts corresponding to a list of events
Args: Args:
event_ids(list(str)): events whose state should be returned event_ids(list(str)): events whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which which are used to filter the state fetched. If `state_key` is None,
matches any key all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A deferred dict from event_id -> (type, state_key) -> state_event A deferred dict from event_id -> (type, state_key) -> state_event
@ -389,7 +411,7 @@ class StateGroupWorkerStore(SQLBaseStore):
) )
groups = set(itervalues(event_to_groups)) groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
event_to_state = { event_to_state = {
event_id: group_to_state[group] event_id: group_to_state[group]
@ -399,37 +421,45 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None): def get_state_for_event(self, event_id, types=None, filtered_types=None):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
Args: Args:
event_id(str): event whose state should be returned event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which which are used to filter the state fetched. If `state_key` is None,
matches any key all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A deferred dict from (type, state_key) -> state_event A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types, filtered_types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, types=None): def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
Args: Args:
event_id(str): event whose state should be returned event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which which are used to filter the state fetched. If `state_key` is None,
matches any key all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A deferred dict from (type, state_key) -> state_event A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_ids_for_events([event_id], types) state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@cached(max_entries=50000) @cached(max_entries=50000)
@ -460,56 +490,73 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types): def _get_some_state_from_cache(self, group, types, filtered_types=None):
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). Args:
`missing_types` is the list of types that aren't in the cache for that group(int): The state group to lookup
group. `got_all` is a bool indicating if we successfully retrieved all types(list[str, str|None]): List of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns 2-tuple (`state_dict`, `got_all`).
`got_all` is a bool indicating if we successfully retrieved all
requests state from the cache, if False we need to query the DB for the requests state from the cache, if False we need to query the DB for the
missing state. missing state.
Args:
group: The state group to lookup
types (list): List of 2-tuples of the form (`type`, `state_key`),
where a `state_key` of `None` matches all state_keys for the
`type`.
""" """
is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {} type_to_key = {}
missing_types = set()
# tracks whether any of ourrequested types are missing from the cache
missing_types = False
for typ, state_key in types: for typ, state_key in types:
key = (typ, state_key) key = (typ, state_key)
if state_key is None:
if (
state_key is None or
(filtered_types is not None and typ not in filtered_types)
):
type_to_key[typ] = None type_to_key[typ] = None
missing_types.add(key) # we mark the type as missing from the cache because
# when the cache was populated it might have been done with a
# restricted set of state_keys, so the wildcard will not work
# and the cache may be incomplete.
missing_types = True
else: else:
if type_to_key.get(typ, object()) is not None: if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key) type_to_key.setdefault(typ, set()).add(state_key)
if key not in state_dict_ids and key not in known_absent: if key not in state_dict_ids and key not in known_absent:
missing_types.add(key) missing_types = True
sentinel = object() sentinel = object()
def include(typ, state_key): def include(typ, state_key):
valid_state_keys = type_to_key.get(typ, sentinel) valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel: if valid_state_keys is sentinel:
return False return filtered_types is not None and typ not in filtered_types
if valid_state_keys is None: if valid_state_keys is None:
return True return True
if state_key in valid_state_keys: if state_key in valid_state_keys:
return True return True
return False return False
got_all = is_all or not missing_types got_all = is_all
if not got_all:
# the cache is incomplete. We may still have got all the results we need, if
# we don't have any wildcards in the match list.
if not missing_types and filtered_types is None:
got_all = True
return { return {
k: v for k, v in iteritems(state_dict_ids) k: v for k, v in iteritems(state_dict_ids)
if include(k[0], k[1]) if include(k[0], k[1])
}, missing_types, got_all }, got_all
def _get_all_state_from_cache(self, group): def _get_all_state_from_cache(self, group):
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
@ -526,7 +573,7 @@ class StateGroupWorkerStore(SQLBaseStore):
return state_dict_ids, is_all return state_dict_ids, is_all
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None): def _get_state_for_groups(self, groups, types=None, filtered_types=None):
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
@ -540,6 +587,9 @@ class StateGroupWorkerStore(SQLBaseStore):
Otherwise, each entry should be a `(type, state_key)` tuple to Otherwise, each entry should be a `(type, state_key)` tuple to
include in the response. A `state_key` of None is a wildcard include in the response. A `state_key` of None is a wildcard
meaning that we require all state with that type. meaning that we require all state with that type.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]] Deferred[dict[int, dict[(type, state_key), EventBase]]]
@ -551,8 +601,8 @@ class StateGroupWorkerStore(SQLBaseStore):
missing_groups = [] missing_groups = []
if types is not None: if types is not None:
for group in set(groups): for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache( state_dict_ids, got_all = self._get_some_state_from_cache(
group, types, group, types, filtered_types
) )
results[group] = state_dict_ids results[group] = state_dict_ids
@ -579,13 +629,13 @@ class StateGroupWorkerStore(SQLBaseStore):
# cache. Hence, if we are doing a wildcard lookup, populate the # cache. Hence, if we are doing a wildcard lookup, populate the
# cache fully so that we can do an efficient lookup next time. # cache fully so that we can do an efficient lookup next time.
if types and any(k is None for (t, k) in types): if filtered_types or (types and any(k is None for (t, k) in types)):
types_to_fetch = None types_to_fetch = None
else: else:
types_to_fetch = types types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types_to_fetch, missing_groups, types_to_fetch
) )
for group, group_state_dict in iteritems(group_to_state_dict): for group, group_state_dict in iteritems(group_to_state_dict):
@ -595,7 +645,10 @@ class StateGroupWorkerStore(SQLBaseStore):
if types: if types:
for k, v in iteritems(group_state_dict): for k, v in iteritems(group_state_dict):
(typ, _) = k (typ, _) = k
if k in types or (typ, None) in types: if (
(k in types or (typ, None) in types) or
(filtered_types and typ not in filtered_types)
):
state_dict[k] = v state_dict[k] = v
else: else:
state_dict.update(group_state_dict) state_dict.update(group_state_dict)

319
tests/storage/test_state.py Normal file
View File

@ -0,0 +1,319 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
import tests.unittest
import tests.utils
logger = logging.getLogger(__name__)
class StateStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(StateStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
self.room = RoomID.from_string("!abc123:test")
yield self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
is_public=True
)
@defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.new({
"type": typ,
"sender": sender.to_string(),
"state_key": state_key,
"room_id": room.to_string(),
"content": content,
})
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
def assertStateMapEqual(self, s1, s2):
for t in s1:
# just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
@defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, '', {},
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, '', {
"name": "test room"
},
)
e3 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {
"membership": Membership.JOIN
},
)
e4 = yield self.inject_state_event(
self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
"membership": Membership.JOIN
},
)
e5 = yield self.inject_state_event(
self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
"membership": Membership.LEAVE
},
)
# check we get the full state as of the final event
state = yield self.store.get_state_for_event(
e5.event_id, None, filtered_types=None
)
self.assertIsNotNone(e4)
self.assertStateMapEqual({
(e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2,
(e3.type, e3.state_key): e3,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5,
}, state)
# check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Name, '')], filtered_types=None
)
self.assertStateMapEqual({
(e2.type, e2.state_key): e2,
}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Name, None)], filtered_types=None
)
self.assertStateMapEqual({
(e2.type, e2.state_key): e2,
}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Member, None)], filtered_types=None
)
self.assertStateMapEqual({
(e3.type, e3.state_key): e3,
(e5.type, e5.state_key): e5,
}, state)
# check we can use filter_types to grab a specific room member
# without filtering out the other event types
state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Member, self.u_alice.to_string())],
filtered_types=[EventTypes.Member],
)
self.assertStateMapEqual({
(e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2,
(e3.type, e3.state_key): e3,
}, state)
# check that types=[], filtered_types=[EventTypes.Member]
# doesn't return all members
state = yield self.store.get_state_for_event(
e5.event_id, [], filtered_types=[EventTypes.Member],
)
self.assertStateMapEqual({
(e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2,
}, state)
#######################################################
# _get_some_state_from_cache tests against a full cache
#######################################################
room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
group = group_ids.keys()[0]
# test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
self.assertDictEqual(state_dict_ids, {
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
})
state_dict_ids.pop((e2.type, e2.state_key))
self.store._state_group_cache.invalidate(group)
self.store._state_group_cache.update(
sequence=self.store._state_group_cache.sequence,
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
fetched_keys=(
(e1.type, e1.state_key),
(e3.type, e3.state_key),
(e5.type, e5.state_key),
)
)
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
self.assertEqual(is_all, False)
self.assertEqual(known_absent, set([
(e1.type, e1.state_key),
(e3.type, e3.state_key),
(e5.type, e5.state_key),
]))
self.assertDictEqual(state_dict_ids, {
(e1.type, e1.state_key): e1.event_id,
(e3.type, e3.state_key): e3.event_id,
(e5.type, e5.state_key): e5.event_id,
})
############################################
# test that things work with a partial cache
# test _get_some_state_from_cache correctly filters out members with types=[]
room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e5.type, e5.state_key): e5.event_id,
}, state_dict)