diff --git a/changelog.d/3958.misc b/changelog.d/3958.misc new file mode 100644 index 0000000000..5931d06dcf --- /dev/null +++ b/changelog.d/3958.misc @@ -0,0 +1 @@ +Fix docstrings and add tests for state store methods diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 4b971efdba..3f4cbd61c4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -255,7 +255,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) @defer.inlineCallbacks - def get_state_groups_ids(self, room_id, event_ids): + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ if not event_ids: defer.returnValue({}) @@ -270,7 +280,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): - """Get the state IDs for the given state group + """Get the event IDs of all the state in the given state group Args: state_group (int) @@ -286,7 +296,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - The return value is a dict mapping group names to lists of events. + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. """ if not event_ids: defer.returnValue({}) @@ -324,7 +336,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): member events (if True), or to exclude member events (if False) Returns: - dictionary state_group -> (dict of (type, state_key) -> event id) + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ results = {} @@ -732,8 +746,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types is not None: non_member_types = [t for t in types if t[0] != EventTypes.Member] @@ -788,8 +802,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types: types = frozenset(types) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b910965932..b9c5b39d59 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -74,6 +74,45 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(len(s1), len(s2)) + @defer.inlineCallbacks + def test_get_state_groups_ids(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_map = list(state_group_map.values())[0] + self.assertDictEqual( + state_map, + { + (EventTypes.Create, ''): e1.event_id, + (EventTypes.Name, ''): e2.event_id, + }, + ) + + @defer.inlineCallbacks + def test_get_state_groups(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups( + self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_list = list(state_group_map.values())[0] + + self.assertEqual( + {ev.event_id for ev in state_list}, + {e1.event_id, e2.event_id}, + ) + @defer.inlineCallbacks def test_get_state_for_event(self):