diff --git a/changelog.d/12332.misc b/changelog.d/12332.misc new file mode 100644 index 0000000000..9f333e718a --- /dev/null +++ b/changelog.d/12332.misc @@ -0,0 +1 @@ +Avoid trying to calculate the state at outlier events. diff --git a/synapse/visibility.py b/synapse/visibility.py index 49519eb8f5..250f073597 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -1,4 +1,5 @@ # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright (C) The Matrix.org Foundation C.I.C. 2022 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, FrozenSet, List, Optional +from typing import Collection, Dict, FrozenSet, List, Optional, Tuple + +from typing_extensions import Final from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase @@ -40,6 +43,8 @@ MEMBERSHIP_PRIORITY = ( Membership.BAN, ) +_HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "") + async def filter_events_for_client( storage: Storage, @@ -74,7 +79,7 @@ async def filter_events_for_client( # to clients. events = [e for e in events if not e.internal_metadata.is_soft_failed()] - types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) # we exclude outliers at this point, and then handle them separately later event_id_to_state = await storage.state.get_state_for_events( @@ -157,7 +162,7 @@ async def filter_events_for_client( state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. - visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) + visibility_event = state.get(_HISTORY_VIS_KEY, None) if visibility_event: visibility = visibility_event.content.get( "history_visibility", HistoryVisibility.SHARED @@ -293,67 +298,28 @@ async def filter_events_for_server( return True return False - def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool: - history = state.get((EventTypes.RoomHistoryVisibility, ""), None) - if history: - visibility = history.content.get( - "history_visibility", HistoryVisibility.SHARED - ) - if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]: - # We now loop through all state events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in state.values(): - if ev.type != EventTypes.Member: - continue - try: - domain = get_domain_from_id(ev.state_key) - except Exception: - continue + def check_event_is_visible( + visibility: str, memberships: StateMap[EventBase] + ) -> bool: + if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED): + return True - if domain != server_name: - continue + # We now loop through all membership events looking for + # membership states for the requesting server to determine + # if the server is either in the room or has been invited + # into the room. + for ev in memberships.values(): + assert get_domain_from_id(ev.state_key) == server_name - memtype = ev.membership - if memtype == Membership.JOIN: - return True - elif memtype == Membership.INVITE: - if visibility == HistoryVisibility.INVITED: - return True - else: - # server has no users in the room: redact - return False + memtype = ev.membership + if memtype == Membership.JOIN: + return True + elif memtype == Membership.INVITE: + if visibility == HistoryVisibility.INVITED: + return True - return True - - # Lets check to see if all the events have a history visibility - # of "shared" or "world_readable". If that's the case then we don't - # need to check membership (as we know the server is in the room). - event_to_state_ids = await storage.state.get_state_ids_for_events( - frozenset(e.event_id for e in events), - state_filter=StateFilter.from_types( - types=((EventTypes.RoomHistoryVisibility, ""),) - ), - ) - - visibility_ids = set() - for sids in event_to_state_ids.values(): - hist = sids.get((EventTypes.RoomHistoryVisibility, "")) - if hist: - visibility_ids.add(hist) - - # If we failed to find any history visibility events then the default - # is "shared" visibility. - if not visibility_ids: - all_open = True - else: - event_map = await storage.main.get_events(visibility_ids) - all_open = all( - e.content.get("history_visibility") - in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) - for e in event_map.values() - ) + # server has no users in the room: redact + return False if not check_history_visibility_only: erased_senders = await storage.main.are_users_erased(e.sender for e in events) @@ -362,34 +328,100 @@ async def filter_events_for_server( # to no users having been erased. erased_senders = {} - if all_open: - # all the history_visibility state affecting these events is open, so - # we don't need to filter by membership state. We *do* need to check - # for user erasure, though. - if erased_senders: - to_return = [] - for e in events: - if not is_sender_erased(e, erased_senders): - to_return.append(e) - elif redact: - to_return.append(prune_event(e)) + # Let's check to see if all the events have a history visibility + # of "shared" or "world_readable". If that's the case then we don't + # need to check membership (as we know the server is in the room). + event_to_history_vis = await _event_to_history_vis(storage, events) - return to_return + # for any with restricted vis, we also need the memberships + event_to_memberships = await _event_to_memberships( + storage, + [ + e + for e in events + if event_to_history_vis[e.event_id] + not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) + ], + server_name, + ) - # If there are no erased users then we can just return the given list - # of events without having to copy it. - return events + to_return = [] + for e in events: + erased = is_sender_erased(e, erased_senders) + visible = check_event_is_visible( + event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) + ) + if visible and not erased: + to_return.append(e) + elif redact: + to_return.append(prune_event(e)) - # Ok, so we're dealing with events that have non-trivial visibility - # rules, so we need to also get the memberships of the room. + return to_return - # first, for each event we're wanting to return, get the event_ids - # of the history vis and membership state at those events. + +async def _event_to_history_vis( + storage: Storage, events: Collection[EventBase] +) -> Dict[str, str]: + """Get the history visibility at each of the given events + + Returns a map from event id to history_visibility setting + """ + + # outliers get special treatment here. We don't have the state at that point in the + # room (and attempting to look it up will raise an exception), so all we can really + # do is assume that the requesting server is allowed to see the event. That's + # equivalent to there not being a history_visibility event, so we just exclude + # any outliers from the query. + event_to_state_ids = await storage.state.get_state_ids_for_events( + frozenset(e.event_id for e in events if not e.internal_metadata.is_outlier()), + state_filter=StateFilter.from_types(types=(_HISTORY_VIS_KEY,)), + ) + + visibility_ids = { + vis_event_id + for vis_event_id in ( + state_ids.get(_HISTORY_VIS_KEY) for state_ids in event_to_state_ids.values() + ) + if vis_event_id + } + vis_events = await storage.main.get_events(visibility_ids) + + result: Dict[str, str] = {} + for event in events: + vis = HistoryVisibility.SHARED + state_ids = event_to_state_ids.get(event.event_id) + + # if we didn't find any state for this event, it's an outlier, and we assume + # it's open + visibility_id = None + if state_ids: + visibility_id = state_ids.get(_HISTORY_VIS_KEY) + + if visibility_id: + vis_event = vis_events[visibility_id] + vis = vis_event.content.get("history_visibility", HistoryVisibility.SHARED) + assert isinstance(vis, str) + + result[event.event_id] = vis + return result + + +async def _event_to_memberships( + storage: Storage, events: Collection[EventBase], server_name: str +) -> Dict[str, StateMap[EventBase]]: + """Get the remote membership list at each of the given events + + Returns a map from event id to state map, which will contain only membership events + for the given server. + """ + + if not events: + return {} + + # for each event, get the event_ids of the membership state at those events. event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), - state_filter=StateFilter.from_types( - types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) - ), + state_filter=StateFilter.from_types(types=((EventTypes.Member, None),)), ) # We only want to pull out member events that correspond to the @@ -405,10 +437,7 @@ async def filter_events_for_server( for key, event_id in key_to_eid.items() } - def include(typ, state_key): - if typ != EventTypes.Member: - return True - + def include(state_key: str) -> bool: # we avoid using get_domain_from_id here for efficiency. idx = state_key.find(":") if idx == -1: @@ -416,10 +445,14 @@ async def filter_events_for_server( return state_key[idx + 1 :] == server_name event_map = await storage.main.get_events( - [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] + [ + e_id + for e_id, (_, state_key) in event_id_to_state_key.items() + if include(state_key) + ] ) - event_to_state = { + return { e_id: { key: event_map[inner_e_id] for key, inner_e_id in key_to_eid.items() @@ -427,14 +460,3 @@ async def filter_events_for_server( } for e_id, key_to_eid in event_to_state_ids.items() } - - to_return = [] - for e in events: - erased = is_sender_erased(e, erased_senders) - visible = check_event_is_visible(e, event_to_state[e.event_id]) - if visible and not erased: - to_return.append(e) - elif redact: - to_return.append(prune_event(e)) - - return to_return diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 532e3fe9cd..a02fd4f79a 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -17,6 +17,7 @@ from unittest.mock import patch from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.types import JsonDict, create_requester from synapse.visibility import filter_events_for_client, filter_events_for_server @@ -73,6 +74,39 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") + def test_filter_outlier(self) -> None: + # outlier events must be returned, for the good of the collective federation + self.get_success(self._inject_room_member("@resident:remote_hs")) + self.get_success(self._inject_visibility("@resident:remote_hs", "joined")) + + outlier = self.get_success(self._inject_outlier()) + self.assertEqual( + self.get_success( + filter_events_for_server(self.storage, "remote_hs", [outlier]) + ), + [outlier], + ) + + # it should also work when there are other events in the list + evt = self.get_success(self._inject_message("@unerased:local_hs")) + + filtered = self.get_success( + filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) + ) + self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") + self.assertEqual(filtered[0], outlier) + self.assertEqual(filtered[1].event_id, evt.event_id) + self.assertEqual(filtered[1].content, evt.content) + + # ... but other servers should only be able to see the outlier (the other should + # be redacted) + filtered = self.get_success( + filter_events_for_server(self.storage, "other_server", [outlier, evt]) + ) + self.assertEqual(filtered[0], outlier) + self.assertEqual(filtered[1].event_id, evt.event_id) + self.assertNotIn("body", filtered[1].content) + def test_erased_user(self) -> None: # 4 message events, from erased and unerased users, with a membership # change in the middle of them. @@ -187,6 +221,25 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.get_success(self.storage.persistence.persist_event(event, context)) return event + def _inject_outlier(self) -> EventBase: + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.member", + "sender": "@test:user", + "state_key": "@test:user", + "room_id": TEST_ROOM_ID, + "content": {"membership": "join"}, + }, + ) + + event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) + event.internal_metadata.outlier = True + self.get_success( + self.storage.persistence.persist_event(event, EventContext.for_outlier()) + ) + return event + class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): def test_out_of_band_invite_rejection(self):