From eb609c65d0794dd49efcd924bdc8743fd4253a93 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 18 Feb 2022 14:54:31 +0000 Subject: [PATCH] Fix bug in `StateFilter.return_expanded()` and add some tests. (#12016) --- changelog.d/12016.misc | 1 + synapse/storage/state.py | 8 ++- tests/storage/test_state.py | 109 ++++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12016.misc diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc new file mode 100644 index 0000000000..8856ef46a9 --- /dev/null +++ b/changelog.d/12016.misc @@ -0,0 +1 @@ +Fix bug in `StateFilter.return_expanded()` and add some tests. \ No newline at end of file diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 913448f0f9..e79ecf64a0 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -204,13 +204,16 @@ class StateFilter: if get_all_members: # We want to return everything. return StateFilter.all() - else: + elif EventTypes.Member in self.types: # We want to return all non-members, but only particular # memberships return StateFilter( types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), include_others=True, ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. @@ -528,6 +531,9 @@ class StateFilter: _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 70d52b088c..28c767ecfd 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase): StateFilter.none(), StateFilter.all(), ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self): + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + )