diff --git a/CHANGES.md b/CHANGES.md index c8aa5d177f..7927714a36 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,20 @@ +Synapse 1.7.1 (2019-12-18) +========================== + +This release includes several security fixes as well as a fix to a bug exposed by the security fixes. Administrators are encouraged to upgrade as soon as possible. + +Security updates +---------------- + +- Fix a bug which could cause room events to be incorrectly authorized using events from a different room. ([\#6501](https://github.com/matrix-org/synapse/issues/6501), [\#6503](https://github.com/matrix-org/synapse/issues/6503), [\#6521](https://github.com/matrix-org/synapse/issues/6521), [\#6524](https://github.com/matrix-org/synapse/issues/6524), [\#6530](https://github.com/matrix-org/synapse/issues/6530), [\#6531](https://github.com/matrix-org/synapse/issues/6531)) +- Fix a bug causing responses to the `/context` client endpoint to not use the pruned version of the event. ([\#6553](https://github.com/matrix-org/synapse/issues/6553)) +- Fix a cause of state resets in room versions 2 onwards. ([\#6556](https://github.com/matrix-org/synapse/issues/6556), [\#6560](https://github.com/matrix-org/synapse/issues/6560)) + +Bugfixes +-------- + +- Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs. ([\#6526](https://github.com/matrix-org/synapse/issues/6526), [\#6527](https://github.com/matrix-org/synapse/issues/6527)) + Synapse 1.7.0 (2019-12-13) ========================== @@ -88,7 +105,7 @@ Internal Changes - Add a test scenario to make sure room history purges don't break `/messages` in the future. ([\#6392](https://github.com/matrix-org/synapse/issues/6392)) - Clarifications for the email configuration settings. ([\#6423](https://github.com/matrix-org/synapse/issues/6423)) - Add more tests to the blacklist when running in worker mode. ([\#6429](https://github.com/matrix-org/synapse/issues/6429)) -- Refactor data store layer to support multiple databases in the future. ([\#6454](https://github.com/matrix-org/synapse/issues/6454), [\#6464](https://github.com/matrix-org/synapse/issues/6464), [\#6469](https://github.com/matrix-org/synapse/issues/6469), [\#6487](https://github.com/matrix-org/synapse/issues/6487)) +- Refactor data store layer to support multiple databases in the future. ([\#6454](https://github.com/matrix-org/synapse/issues/6454), [\#6464](https://github.com/matrix-org/synapse/issues/6464), [\#6469](https://github.com/matrix-org/synapse/issues/6469), [\#6487](https://github.com/matrix-org/synapse/issues/6487)) - Port synapse.rest.client.v1 to async/await. ([\#6482](https://github.com/matrix-org/synapse/issues/6482)) - Port synapse.rest.client.v2_alpha to async/await. ([\#6483](https://github.com/matrix-org/synapse/issues/6483)) - Port SyncHandler to async/await. ([\#6484](https://github.com/matrix-org/synapse/issues/6484)) diff --git a/debian/changelog b/debian/changelog index bd43feb321..e400619eb9 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.7.1) stable; urgency=medium + + * New synapse release 1.7.1. + + -- Synapse Packaging team Wed, 18 Dec 2019 09:37:59 +0000 + matrix-synapse-py3 (1.7.0) stable; urgency=medium * New synapse release 1.7.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index d3cf7b3d7b..e951bab593 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ try: except ImportError: pass -__version__ = "1.7.0" +__version__ = "1.7.1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5d0b7d2801..9fd52a8c77 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Tuple from six import itervalues @@ -25,13 +26,7 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import ( - EventTypes, - JoinRules, - LimitBlockingTypes, - Membership, - UserTypes, -) +from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes from synapse.api.errors import ( AuthError, Codes, @@ -513,71 +508,43 @@ class Auth(object): """ return self.store.is_server_admin(user) - @defer.inlineCallbacks - def compute_auth_events(self, event, current_state_ids, for_verification=False): + def compute_auth_events( + self, + event, + current_state_ids: Dict[Tuple[str, str], str], + for_verification: bool = False, + ): + """Given an event and current state return the list of event IDs used + to auth an event. + + If `for_verification` is False then only return auth events that + should be added to the event's `auth_events`. + + Returns: + defer.Deferred(list[str]): List of event IDs. + """ + if event.type == EventTypes.Create: - return [] + return defer.succeed([]) + + # Currently we ignore the `for_verification` flag even though there are + # some situations where we can drop particular auth events when adding + # to the event's `auth_events` (e.g. joins pointing to previous joins + # when room is publically joinable). Dropping event IDs has the + # advantage that the auth chain for the room grows slower, but we use + # the auth chain in state resolution v2 to order events, which means + # care must be taken if dropping events to ensure that it doesn't + # introduce undesirable "state reset" behaviour. + # + # All of which sounds a bit tricky so we don't bother for now. auth_ids = [] + for etype, state_key in event_auth.auth_types_for_event(event): + auth_ev_id = current_state_ids.get((etype, state_key)) + if auth_ev_id: + auth_ids.append(auth_ev_id) - key = (EventTypes.PowerLevels, "") - power_level_event_id = current_state_ids.get(key) - - if power_level_event_id: - auth_ids.append(power_level_event_id) - - key = (EventTypes.JoinRules, "") - join_rule_event_id = current_state_ids.get(key) - - key = (EventTypes.Member, event.sender) - member_event_id = current_state_ids.get(key) - - key = (EventTypes.Create, "") - create_event_id = current_state_ids.get(key) - if create_event_id: - auth_ids.append(create_event_id) - - if join_rule_event_id: - join_rule_event = yield self.store.get_event(join_rule_event_id) - join_rule = join_rule_event.content.get("join_rule") - is_public = join_rule == JoinRules.PUBLIC if join_rule else False - else: - is_public = False - - if event.type == EventTypes.Member: - e_type = event.content["membership"] - if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event_id: - auth_ids.append(join_rule_event_id) - - if e_type == Membership.JOIN: - if member_event_id and not is_public: - auth_ids.append(member_event_id) - else: - if member_event_id: - auth_ids.append(member_event_id) - - if for_verification: - key = (EventTypes.Member, event.state_key) - existing_event_id = current_state_ids.get(key) - if existing_event_id: - auth_ids.append(existing_event_id) - - if e_type == Membership.INVITE: - if "third_party_invite" in event.content: - key = ( - EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"], - ) - third_party_invite_id = current_state_ids.get(key) - if third_party_invite_id: - auth_ids.append(third_party_invite_id) - elif member_event_id: - member_event = yield self.store.get_event(member_event_id) - if member_event.content["membership"] == Membership.JOIN: - auth_ids.append(member_event.event_id) - - return auth_ids + return defer.succeed(auth_ids) @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): diff --git a/synapse/event_auth.py b/synapse/event_auth.py index ec3243b27b..350ed9351f 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Set, Tuple from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -48,6 +49,18 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) + room_id = event.room_id + + # I'm not really expecting to get auth events in the wrong room, but let's + # sanity-check it + for auth_event in auth_events.values(): + if auth_event.room_id != room_id: + raise Exception( + "During auth for event %s in room %s, found event %s in the state " + "which is in room %s" + % (event.event_id, room_id, auth_event.event_id, auth_event.room_id) + ) + if do_sig_check: sender_domain = get_domain_from_id(event.sender) @@ -625,7 +638,7 @@ def get_public_keys(invite_event): return public_keys -def auth_types_for_event(event): +def auth_types_for_event(event) -> Set[Tuple[str]]: """Given an event, return a list of (EventType, StateKey) that may be needed to auth the event. The returned list may be a superset of what would actually be required depending on the full state of the room. @@ -634,20 +647,20 @@ def auth_types_for_event(event): actually auth the event. """ if event.type == EventTypes.Create: - return [] + return set() - auth_types = [ + auth_types = { (EventTypes.PowerLevels, ""), (EventTypes.Member, event.sender), (EventTypes.Create, ""), - ] + } if event.type == EventTypes.Member: membership = event.content["membership"] if membership in [Membership.JOIN, Membership.INVITE]: - auth_types.append((EventTypes.JoinRules, "")) + auth_types.add((EventTypes.JoinRules, "")) - auth_types.append((EventTypes.Member, event.state_key)) + auth_types.add((EventTypes.Member, event.state_key)) if membership == Membership.INVITE: if "third_party_invite" in event.content: @@ -655,6 +668,6 @@ def auth_types_for_event(event): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"], ) - auth_types.append(key) + auth_types.add(key) return auth_types diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 709449c9e3..d396e6564f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,8 +18,6 @@ import copy import itertools import logging -from six.moves import range - from prometheus_client import Counter from twisted.internet import defer @@ -39,7 +37,7 @@ from synapse.api.room_versions import ( ) from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache @@ -310,19 +308,12 @@ class FederationClient(FederationBase): return signed_pdu @defer.inlineCallbacks - @log_function - def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the room state at a given event from a remote homeserver. - - Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + def get_room_state_ids(self, destination: str, room_id: str, event_id: str): + """Calls the /state_ids endpoint to fetch the state at a particular point + in the room, and the auth events for the given event Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids) """ result = yield self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id @@ -331,86 +322,12 @@ class FederationClient(FederationBase): state_event_ids = result["pdu_ids"] auth_event_ids = result.get("auth_chain_ids", []) - fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( - destination, room_id, set(state_event_ids + auth_event_ids) - ) + if not isinstance(state_event_ids, list) or not isinstance( + auth_event_ids, list + ): + raise Exception("invalid response from /state_ids") - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, - failed_to_fetch, - ) - - event_map = {ev.event_id: ev for ev in fetched_events} - - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] - - auth_chain.sort(key=lambda e: e.depth) - - return pdus, auth_chain - - @defer.inlineCallbacks - def get_events_from_store_or_dest(self, destination, room_id, event_ids): - """Fetch events from a remote destination, checking if we already have them. - - Args: - destination (str) - room_id (str) - event_ids (list) - - Returns: - Deferred: A deferred resolving to a 2-tuple where the first is a list of - events and the second is a list of event ids that we failed to fetch. - """ - seen_events = yield self.store.get_events(event_ids, allow_rejected=True) - signed_events = list(seen_events.values()) - - failed_to_fetch = set() - - missing_events = set(event_ids) - for k in seen_events: - missing_events.discard(k) - - if not missing_events: - return signed_events, failed_to_fetch - - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - event_ids, - ) - - room_version = yield self.store.get_room_version(room_id) - - batch_size = 20 - missing_events = list(missing_events) - for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i : i + batch_size]) - - deferreds = [ - run_in_background( - self.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = yield make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) - for success, result in res: - if success and result: - signed_events.append(result) - batch.discard(result.event_id) - - # We removed all events we successfully fetched from `batch` - failed_to_fetch.update(batch) - - return signed_events, failed_to_fetch + return state_event_ids, auth_event_ids @defer.inlineCallbacks @log_function diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bc26921768..abe02907b9 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -64,8 +64,7 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.types import UserID, get_domain_from_id -from synapse.util import unwrapFirstError -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server @@ -240,7 +239,6 @@ class FederationHandler(BaseHandler): return None state = None - auth_chain = [] # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): @@ -346,7 +344,6 @@ class FederationHandler(BaseHandler): # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. - auth_chains = set() event_map = {event_id: pdu} try: # Get the state of the events we know about @@ -370,38 +367,14 @@ class FederationHandler(BaseHandler): p, ) - room_version = yield self.store.get_room_version(room_id) - with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - ( - remote_state, - got_auth_chain, - ) = yield self.federation_client.get_state_for_room( - origin, room_id, p + (remote_state, _,) = yield self._get_state_for_room( + origin, room_id, p, include_event_in_state=True ) - # we want the state *after* p; get_state_for_room returns the - # state *before* p. - remote_event = yield self.federation_client.get_pdu( - [origin], p, room_version, outlier=True - ) - - if remote_event is None: - raise Exception( - "Unable to get missing prev_event %s" % (p,) - ) - - if remote_event.is_state(): - remote_state.append(remote_event) - - # XXX hrm I'm not convinced that duplicate events will compare - # for equality, so I'm not sure this does what the author - # hoped. - auth_chains.update(got_auth_chain) - remote_state_map = { (x.type, x.state_key): x.event_id for x in remote_state } @@ -410,7 +383,9 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x + room_version = yield self.store.get_room_version(room_id) state_map = yield resolve_events_with_store( + room_id, room_version, state_maps, event_map, @@ -430,7 +405,6 @@ class FederationHandler(BaseHandler): event_map.update(evs) state = [event_map[e] for e in six.itervalues(state_map)] - auth_chain = list(auth_chains) except Exception: logger.warning( "[%s %s] Error attempting to resolve state at missing " @@ -446,9 +420,7 @@ class FederationHandler(BaseHandler): affected=event_id, ) - yield self._process_received_pdu( - origin, pdu, state=state, auth_chain=auth_chain - ) + yield self._process_received_pdu(origin, pdu, state=state) @defer.inlineCallbacks def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): @@ -584,50 +556,150 @@ class FederationHandler(BaseHandler): raise @defer.inlineCallbacks - def _process_received_pdu(self, origin, event, state, auth_chain): + @log_function + def _get_state_for_room( + self, destination, room_id, event_id, include_event_in_state + ): + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination (str): The remote homeserver to query for the state. + room_id (str): The id of the room we're interested in. + event_id (str): The id of the event we want the state at. + include_event_in_state: if true, the event itself will be included in the + returned state event list. + + Returns: + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. + """ + ( + state_event_ids, + auth_event_ids, + ) = yield self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + desired_events = set(state_event_ids + auth_event_ids) + + if include_event_in_state: + desired_events.add(event_id) + + event_map = yield self._get_events_from_store_or_dest( + destination, room_id, desired_events + ) + + failed_to_fetch = desired_events - event_map.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state/auth events for %s: %s", + room_id, + failed_to_fetch, + ) + + remote_state = [ + event_map[e_id] for e_id in state_event_ids if e_id in event_map + ] + + if include_event_in_state: + remote_event = event_map.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + auth_chain.sort(key=lambda e: e.depth) + + return remote_state, auth_chain + + @defer.inlineCallbacks + def _get_events_from_store_or_dest(self, destination, room_id, event_ids): + """Fetch events from a remote destination, checking if we already have them. + + Args: + destination (str) + room_id (str) + event_ids (Iterable[str]) + + Persists any events we don't already have as outliers. + + If we fail to fetch any of the events, a warning will be logged, and the event + will be omitted from the result. Likewise, any events which turn out not to + be in the given room. + + Returns: + Deferred[dict[str, EventBase]]: A deferred resolving to a map + from event_id to event + """ + fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) + + missing_events = set(event_ids) - fetched_events.keys() + + if missing_events: + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + room_id, + ) + + yield self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (yield self.store.get_events(missing_events, allow_rejected=True)) + ) + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = list( + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ) + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned auth/state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + del fetched_events[bad_event_id] + + return fetched_events + + @defer.inlineCallbacks + def _process_received_pdu(self, origin, event, state): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. + + Args: + origin: server sending the event + + event: event to be persisted + + state: Normally None, but if we are handling a gap in the graph + (ie, we are missing one or more prev_events), the resolved state at the + event """ room_id = event.room_id event_id = event.event_id logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) - event_ids = set() - if state: - event_ids |= {e.event_id for e in state} - if auth_chain: - event_ids |= {e.event_id for e in auth_chain} - - seen_ids = yield self.store.have_seen_events(event_ids) - - if state and auth_chain is not None: - # If we have any state or auth_chain given to us by the replication - # layer, then we should handle them (if we haven't before.) - - event_infos = [] - - for e in itertools.chain(auth_chain, state): - if e.event_id in seen_ids: - continue - e.internal_metadata.outlier = True - auth_ids = e.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - event_infos.append(_NewEventInfo(event=e, auth_events=auth)) - seen_ids.add(e.event_id) - - logger.info( - "[%s %s] persisting newly-received auth/state events %s", - room_id, - event_id, - [e.event.event_id for e in event_infos], - ) - yield self._handle_new_events(origin, event_infos) - try: context = yield self._handle_new_event(origin, event, state=state) except AuthError as e: @@ -683,8 +755,6 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - room_version = yield self.store.get_room_version(room_id) - events = yield self.federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -713,6 +783,9 @@ class FederationHandler(BaseHandler): event_ids = set(e.event_id for e in events) + # build a list of events whose prev_events weren't in the batch. + # (XXX: this will include events whose prev_events we already have; that doesn't + # sound right?) edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) @@ -723,7 +796,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self.federation_client.get_state_for_room( + state, auth = yield self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) @@ -741,95 +814,11 @@ class FederationHandler(BaseHandler): auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) - missing_auth = required_auth - set(auth_events) - failed_to_fetch = set() - # Try and fetch any missing auth events from both DB and remote servers. - # We repeatedly do this until we stop finding new auth events. - while missing_auth - failed_to_fetch: - logger.info("Missing auth for backfill: %r", missing_auth) - ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) - auth_events.update(ret_events) - - required_auth.update( - a_id for event in ret_events.values() for a_id in event.auth_event_ids() - ) - missing_auth = required_auth - set(auth_events) - - if missing_auth - failed_to_fetch: - logger.info( - "Fetching missing auth for backfill: %r", - missing_auth - failed_to_fetch, - ) - - results = yield make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.federation_client.get_pdu, - [dest], - event_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - failed_to_fetch - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results if a}) - required_auth.update( - a_id - for event in results - if event - for a_id in event.auth_event_ids() - ) - missing_auth = required_auth - set(auth_events) - - failed_to_fetch = missing_auth - set(auth_events) - - seen_events = yield self.store.have_seen_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - # We now have a chunk of events plus associated state and auth chain to - # persist. We do the persistence in two steps: - # 1. Auth events and state get persisted as outliers, plus the - # backward extremities get persisted (as non-outliers). - # 2. The rest of the events in the chunk get persisted one by one, as - # each one depends on the previous event for its state. - # - # The important thing is that events in the chunk get persisted as - # non-outliers, including when those events are also in the state or - # auth chain. Caution must therefore be taken to ensure that they are - # not accidentally marked as outliers. - - # Step 1a: persist auth events that *don't* appear in the chunk ev_infos = [] - for a in auth_events.values(): - # We only want to persist auth events as outliers that we haven't - # seen and aren't about to persist as part of the backfilled chunk. - if a.event_id in seen_events or a.event_id in event_map: - continue - a.internal_metadata.outlier = True - ev_infos.append( - _NewEventInfo( - event=a, - auth_events={ - ( - auth_events[a_id].type, - auth_events[a_id].state_key, - ): auth_events[a_id] - for a_id in a.auth_event_ids() - if a_id in auth_events - }, - ) - ) - - # Step 1b: persist the events in the chunk we fetched state for (i.e. - # the backwards extremities) as non-outliers. + # Step 1: persist the events in the chunk we fetched state for (i.e. + # the backwards extremities), with custom auth events and state for e_id in events_to_state: # For paranoia we ensure that these events are marked as # non-outliers @@ -1071,6 +1060,57 @@ class FederationHandler(BaseHandler): return False + @defer.inlineCallbacks + def _get_events_and_persist( + self, destination: str, room_id: str, events: Iterable[str] + ): + """Fetch the given events from a server, and persist them as outliers. + + Logs a warning if we can't find the given event. + """ + + room_version = yield self.store.get_room_version(room_id) + + event_infos = [] + + async def get_event(event_id: str): + with nested_logging_context(event_id): + try: + event = await self.federation_client.get_pdu( + [destination], event_id, room_version, outlier=True, + ) + if event is None: + logger.warning( + "Server %s didn't return event %s", destination, event_id, + ) + return + + # recursively fetch the auth events for this event + auth_events = await self._get_events_from_store_or_dest( + destination, room_id, event.auth_event_ids() + ) + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = auth_events.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + + event_infos.append(_NewEventInfo(event, None, auth)) + + except Exception as e: + logger.warning( + "Error fetching missing state/auth event %s: %s %s", + event_id, + type(e), + e, + ) + + yield concurrently_execute(get_event, events, 5) + + yield self._handle_new_events( + destination, event_infos, + ) + def _sanity_check_event(self, ev): """ Do some early sanity checks of a received event diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 22768e97ff..60b8bbc7a5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -907,7 +907,10 @@ class RoomContextHandler(object): results["events_before"] = yield filter_evts(results["events_before"]) results["events_after"] = yield filter_evts(results["events_after"]) - results["event"] = event + # filter_evts can return a pruned event in case the user is allowed to see that + # there's something there but not see the content, so use the event that's in + # `filtered` rather than the event we retrieved from the datastore. + results["event"] = filtered[0] if results["events_after"]: last_event_id = results["events_after"][-1].event_id @@ -938,7 +941,7 @@ class RoomContextHandler(object): if event_filter: state_events = event_filter.filter(state_events) - results["state"] = state_events + results["state"] = yield filter_evts(state_events) # We use a dummy token here as we only care about the room portion of # the token, which we replace. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 139beef8ed..0e75e94c6f 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Iterable, Optional +from typing import Dict, Iterable, List, Optional, Tuple from six import iteritems, itervalues @@ -416,6 +416,7 @@ class StateHandler(object): with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + event.room_id, room_version, state_set_ids, event_map=state_map, @@ -461,7 +462,7 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging) + room_id (str): room we are resolving for (used for logging and sanity checks) room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group @@ -517,6 +518,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + room_id, room_version, list(itervalues(state_groups_ids)), event_map=event_map, @@ -588,36 +590,44 @@ def _make_state_cache_entry(new_state, state_groups_ids): ) -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", +): """ Args: - room_version(str): Version of the room + room_id: the room we are working in - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_version: Version of the room + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing events will be requested via state_map_factory. - If None, all events will be fetched via state_map_factory. + If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: a place to fetch events from - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events + room_id, state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store + room_id, room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a2f92d9ff9..b2f9865f39 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,6 +15,7 @@ import hashlib import logging +from typing import Callable, Dict, List, Optional, Tuple from six import iteritems, iterkeys, itervalues @@ -24,6 +25,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase logger = logging.getLogger(__name__) @@ -32,13 +34,20 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks -def resolve_events_with_store(state_sets, event_map, state_map_factory): +def resolve_events_with_store( + room_id: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_map_factory: Callable, +): """ Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_id: the room we are working in + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,11 +55,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called + state_map_factory: will be called with a list of event_ids that are needed, and should return with a Deferred of dict of event_id to event. - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -76,6 +85,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if event_map is not None: state_map.update(event_map) + # everything in the state map should be in the right room + for event in state_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. # @@ -95,6 +112,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): ) state_map_new = yield state_map_factory(new_needed_events) + for event in state_map_new.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + state_map.update(state_map_new) return _resolve_with_state( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index b327c86f40..cb77ed5b78 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,29 +16,40 @@ import heapq import itertools import logging +from typing import Dict, List, Optional, Tuple from six import iteritems, itervalues from twisted.internet import defer +import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.events import EventBase logger = logging.getLogger(__name__) @defer.inlineCallbacks -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "synapse.state.StateResolutionStore", +): """Resolves the state using the v2 state resolution algorithm Args: - room_version (str): The room version + room_id: the room we are working in - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_version: The room version + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,9 +57,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -84,6 +95,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) event_map.update(events) + # everything in the event map should be in the right room + for event in event_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -94,13 +113,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, event_map, state_res_store, full_conflicted_set + room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( + room_id, room_version, sorted_power_events, unconflicted_state, @@ -121,13 +141,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store + room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, state_res_store + room_id, + room_version, + leftover_events, + resolved_state, + event_map, + state_res_store, ) logger.debug("resolved") @@ -141,11 +166,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto @defer.inlineCallbacks -def _get_power_level_for_sender(event_id, event_map, state_res_store): +def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -153,11 +179,11 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(event_id, event_map, state_res_store) + event = yield _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): pl = aev break @@ -165,7 +191,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.Create, ""): if aev.content.get("creator") == event.sender: return 100 @@ -279,7 +305,7 @@ def _is_power_event(event): @defer.inlineCallbacks def _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -287,6 +313,7 @@ def _add_event_and_auth_chain_to_graph( Args: graph (dict[str, set[str]]): A map from event ID to the events auth event IDs + room_id (str): the room we are working in event_id (str): Event to add to the graph event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -298,7 +325,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(eid, event_map, state_res_store) + event = yield _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -308,11 +335,14 @@ def _add_event_and_auth_chain_to_graph( @defer.inlineCallbacks -def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): +def _reverse_topological_power_sort( + room_id, event_ids, event_map, state_res_store, auth_diff +): """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: + room_id (str): the room we are working in event_ids (list[str]): The events to sort event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -325,12 +355,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} for event_id in graph: - pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + pl = yield _get_power_level_for_sender( + room_id, event_id, event_map, state_res_store + ) event_to_pl[event_id] = pl def _get_power_order(event_id): @@ -348,12 +380,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ @defer.inlineCallbacks def _iterative_auth_checks( - room_version, event_ids, base_state, event_map, state_res_store + room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: + room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to base_state (dict[tuple[str, str], str]): The set of state to start with @@ -370,7 +403,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event(room_id, aid, event_map, state_res_store) if ev.rejected_reason is None: auth_events[(ev.type, ev.state_key)] = ev @@ -378,7 +411,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(ev_id, event_map, state_res_store) + ev = yield _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -400,11 +433,14 @@ def _iterative_auth_checks( @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): +def _mainline_sort( + room_id, event_ids, resolved_power_event_id, event_map, state_res_store +): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: + room_id (str): room we're working in event_ids (list[str]): Events to sort resolved_power_event_id (str): The final resolved power level event ID event_map (dict[str,FrozenEvent]) @@ -417,11 +453,11 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor pl = resolved_power_event_id while pl: mainline.append(pl) - pl_ev = yield _get_event(pl, event_map, state_res_store) + pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event(room_id, aid, event_map, state_res_store) if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break @@ -457,6 +493,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor Deferred[int] """ + room_id = event.room_id + # We do an iterative search, replacing `event with the power level in its # auth events (if any) while event: @@ -468,7 +506,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): event = aev break @@ -478,11 +516,12 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor @defer.inlineCallbacks -def _get_event(event_id, event_map, state_res_store): +def _get_event(room_id, event_id, event_map, state_res_store): """Helper function to look up event in event_map, falling back to looking it up in the store Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -493,7 +532,14 @@ def _get_event(event_id, event_map, state_res_store): if event_id not in event_map: events = yield state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) - return event_map[event_id] + event = event_map[event_id] + assert event is not None + if event.room_id != room_id: + raise Exception( + "In state res for room %s, event %s is in %s" + % (room_id, event_id, event.room_id) + ) + return event def lexicographical_topological_sort(graph, key): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 5c4de2e69f..04b6abdc24 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit): Args: func (func): Function to execute, should return a deferred or coroutine. - args (list): List of arguments to pass to func, each invocation of func - gets a signle argument. + args (Iterable): List of arguments to pass to func, each invocation of func + gets a single argument. limit (int): Maximum number of conccurent executions. Returns: diff --git a/synapse/visibility.py b/synapse/visibility.py index dffe943b28..100dc47a8a 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -52,7 +52,8 @@ def filter_events_for_client( apply_retention_policies=True, ): """ - Check which events a user is allowed to see + Check which events a user is allowed to see. If the user can see the event but its + sender asked for their data to be erased, prune the content of the event. Args: storage diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 1ca7fa742f..e3af280ba6 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,6 +29,7 @@ import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client.v2_alpha import account from synapse.util.stringutils import random_string from tests import unittest @@ -1597,3 +1598,129 @@ class LabelsTestCase(unittest.HomeserverTestCase): ) return event_id + + +class ContextTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + account.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + self.other_user_id = self.register_user("user2", "password") + self.other_tok = self.login("user2", "password") + + self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) + + def test_erased_sender(self): + """Test that an erasure request results in the requester's events being hidden + from any new member of the room. + """ + + # Send a bunch of events in the room. + + self.helper.send(self.room_id, "message 1", tok=self.tok) + self.helper.send(self.room_id, "message 2", tok=self.tok) + event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] + self.helper.send(self.room_id, "message 4", tok=self.tok) + self.helper.send(self.room_id, "message 5", tok=self.tok) + + # Check that we can still see the messages before the erasure request. + + request, channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertEqual( + events_before[0].get("content", {}).get("body"), + "message 2", + events_before[0], + ) + self.assertEqual( + events_before[1].get("content", {}).get("body"), + "message 1", + events_before[1], + ) + + self.assertEqual( + channel.json_body["event"].get("content", {}).get("body"), + "message 3", + channel.json_body["event"], + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertEqual( + events_after[0].get("content", {}).get("body"), + "message 4", + events_after[0], + ) + self.assertEqual( + events_after[1].get("content", {}).get("body"), + "message 5", + events_after[1], + ) + + # Deactivate the first account and erase the user's data. + + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + ) + + # Invite another user in the room. This is needed because messages will be + # pruned only if the user wasn't a member of the room when the messages were + # sent. + + invited_user_id = self.register_user("user3", "password") + invited_tok = self.login("user3", "password") + + self.helper.invite( + self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok + ) + self.helper.join(self.room_id, invited_user_id, tok=invited_tok) + + # Check that a user that joined the room after the erasure request can't see + # the messages anymore. + + request, channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=invited_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) + self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) + + self.assertDictEqual( + channel.json_body["event"].get("content"), {}, channel.json_body["event"] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) + self.assertEqual(events_after[1].get("content"), {}, events_after[1]) diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8d3845c870..0f341d3ac3 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -58,6 +58,7 @@ class FakeEvent(object): self.type = type self.state_key = state_key self.content = content + self.room_id = ROOM_ID def to_event(self, auth_events, prev_events): """Given the auth_events and prev_events, convert to a Frozen Event @@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], event_map=event_map, @@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], event_map=None,