Speed up persisting a large number of outliers (#16649)

Recalculating the roots tuple every iteration could be very expensive, so instead let's do a topological sort.
pull/16648/merge
Erik Johnston 2023-11-16 14:25:35 +00:00 committed by GitHub
parent fef08cbee8
commit 1b238e8837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 134 additions and 12 deletions

1
changelog.d/16649.misc Normal file
View File

@ -0,0 +1 @@
Speed up persisting a large number of outliers.

View File

@ -88,7 +88,7 @@ from synapse.types import (
) )
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.iterutils import batch_iter, partition from synapse.util.iterutils import batch_iter, partition, sorted_topologically_batched
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
@ -1669,14 +1669,13 @@ class FederationEventHandler:
# XXX: it might be possible to kick this process off in parallel with fetching # XXX: it might be possible to kick this process off in parallel with fetching
# the events. # the events.
while event_map:
# build a list of events whose auth events are not in the queue.
roots = tuple(
ev
for ev in event_map.values()
if not any(aid in event_map for aid in ev.auth_event_ids())
)
# We need to persist an event's auth events before the event.
auth_graph = {
ev: [event_map[e_id] for e_id in ev.auth_event_ids() if e_id in event_map]
for ev in event_map.values()
}
for roots in sorted_topologically_batched(event_map.values(), auth_graph):
if not roots: if not roots:
# if *none* of the remaining events are ready, that means # if *none* of the remaining events are ready, that means
# we have a loop. This either means a bug in our logic, or that # we have a loop. This either means a bug in our logic, or that
@ -1698,9 +1697,6 @@ class FederationEventHandler:
await self._auth_and_persist_outliers_inner(room_id, roots) await self._auth_and_persist_outliers_inner(room_id, roots)
for ev in roots:
del event_map[ev.event_id]
async def _auth_and_persist_outliers_inner( async def _auth_and_persist_outliers_inner(
self, room_id: str, fetched_events: Collection[EventBase] self, room_id: str, fetched_events: Collection[EventBase]
) -> None: ) -> None:

View File

@ -135,3 +135,54 @@ def sorted_topologically(
degree_map[edge] -= 1 degree_map[edge] -= 1
if degree_map[edge] == 0: if degree_map[edge] == 0:
heapq.heappush(zero_degree, edge) heapq.heappush(zero_degree, edge)
def sorted_topologically_batched(
    nodes: Iterable[T],
    graph: Mapping[T, Collection[T]],
) -> Generator[Collection[T], None, None]:
    r"""Walk the graph topologically, yielding batches of nodes such that every
    node referenced by a member of a batch has been yielded in an earlier batch.

    For example, given the following graph (each node references the node(s)
    drawn directly above it):

         A
        / \
       B   C
        \ /
         D

    This function will return: `[[A], [B, C], [D]]`.

    This function is useful for e.g. batch persisting events in an auth chain,
    where we can only persist an event if all its auth events have already been
    persisted.

    The order of nodes inside each batch is deterministic: the first batch
    follows the iteration order of `nodes`, and later batches follow the order
    in which nodes become ready, which is driven by the iteration order of
    `graph`.

    Args:
        nodes: The nodes to sort. Only these nodes are ever yielded.
        graph: Maps a node to the nodes it references, i.e. the nodes that must
            be yielded before it. Nodes without an entry are treated as having
            no references. Entries for, and references to, nodes that are not
            in `nodes` are ignored. Duplicate references are counted once.

    Yields:
        Non-empty lists of nodes. Nodes that are part of a reference cycle, or
        that (transitively) reference one, are never yielded; the generator
        simply stops once no further node can be made ready.
    """
    # Number of not-yet-yielded nodes (within `nodes`) each node still waits on.
    degree_map = {node: 0 for node in nodes}

    # Maps a node to the nodes that reference it. The inner dict is used as an
    # insertion-ordered set, so that the order of each batch does not depend on
    # hash iteration order (which differs between processes for e.g. `str`).
    dependents: Dict[T, Dict[T, None]] = {}

    for node, edges in graph.items():
        if node not in degree_map:
            continue

        # De-duplicate so that a repeated reference is only counted once.
        for edge in set(edges):
            if edge in degree_map:
                degree_map[node] += 1
                dependents.setdefault(edge, {})[node] = None

    zero_degree = [node for node, degree in degree_map.items() if degree == 0]

    while zero_degree:
        new_zero_degree = []
        for node in zero_degree:
            # Yielding `node` unblocks everything that was waiting on it.
            for dependent in dependents.get(node, ()):
                degree_map[dependent] -= 1
                if degree_map[dependent] == 0:
                    new_zero_degree.append(dependent)

        yield zero_degree
        zero_degree = new_zero_degree

View File

@ -13,7 +13,11 @@
# limitations under the License. # limitations under the License.
from typing import Dict, Iterable, List, Sequence from typing import Dict, Iterable, List, Sequence
from synapse.util.iterutils import chunk_seq, sorted_topologically from synapse.util.iterutils import (
chunk_seq,
sorted_topologically,
sorted_topologically_batched,
)
from tests.unittest import TestCase from tests.unittest import TestCase
@ -107,3 +111,73 @@ class SortTopologically(TestCase):
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
class SortTopologicallyBatched(TestCase):
    """Test cases for `sorted_topologically_batched`."""

    def test_empty(self) -> None:
        """Test that an empty graph works correctly."""
        edges: Dict[int, List[int]] = {}
        batches = list(sorted_topologically_batched([], edges))
        self.assertEqual(batches, [])

    def test_handle_empty_graph(self) -> None:
        """Test that a node without an entry in the graph is treated as having no edges."""
        edges: Dict[int, List[int]] = {}

        # Neither node has any edges, so both are expected in one batch.
        batches = list(sorted_topologically_batched([1, 2], edges))
        self.assertEqual(batches, [[1, 2]])

    def test_disconnected(self) -> None:
        """Test that a graph with no edges works."""
        edges: Dict[int, List[int]] = {1: [], 2: []}

        # Neither node has any edges, so both are expected in one batch.
        batches = list(sorted_topologically_batched([1, 2], edges))
        self.assertEqual(batches, [[1, 2]])

    def test_linear(self) -> None:
        """Test that a simple `4 -> 3 -> 2 -> 1` graph works."""
        edges: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}

        batches = list(sorted_topologically_batched([4, 3, 2, 1], edges))
        self.assertEqual(batches, [[1], [2], [3], [4]])

    def test_subset(self) -> None:
        """Test that only sorting a subset of the graph works."""
        edges: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}

        batches = list(sorted_topologically_batched([4, 3], edges))
        self.assertEqual(batches, [[3], [4]])

    def test_fork(self) -> None:
        """Test that a forked graph works."""
        edges: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}

        # 2 and 3 both reference only 1, so they are expected to share the
        # middle batch; 4 references both and must come last.
        batches = list(sorted_topologically_batched([4, 3, 2, 1], edges))
        self.assertEqual(batches, [[1], [2, 3], [4]])

    def test_duplicates(self) -> None:
        """Test that a graph with duplicate edges works."""
        edges: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}

        batches = list(sorted_topologically_batched([4, 3, 2, 1], edges))
        self.assertEqual(batches, [[1], [2], [3], [4]])

    def test_multiple_paths(self) -> None:
        """Test that a graph with multiple paths between two nodes works."""
        edges: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}

        batches = list(sorted_topologically_batched([4, 3, 2, 1], edges))
        self.assertEqual(batches, [[1], [2], [3], [4]])