Avoid creating events with huge numbers of prev_events

In most cases, we limit the number of prev_events for a given event to 10
events. This fixes a particular code path which created events with huge
numbers of prev_events.
pull/3113/head
Richard van der Hoff 2018-04-16 18:41:37 +01:00
parent 512633ef44
commit 639480e14a
4 changed files with 162 additions and 54 deletions

View File

@ -37,7 +37,6 @@ from ._base import BaseHandler
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
import logging import logging
import random
import simplejson import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -433,7 +432,7 @@ class EventCreationHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None, def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None): prev_events_and_hashes=None):
""" """
Given a dict from a client, create a new event. Given a dict from a client, create a new event.
@ -447,7 +446,13 @@ class EventCreationHandler(object):
event_dict (dict): An entire event event_dict (dict): An entire event
token_id (str) token_id (str)
txn_id (str) txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
the forward extremities to use as the prev_events for the
new event. For each event, a tuple of (event_id, hashes, depth)
where *hashes* is a map from algorithm to hash.
If None, they will be requested from the database.
Returns: Returns:
Tuple of created event (FrozenEvent), Context Tuple of created event (FrozenEvent), Context
@ -485,7 +490,7 @@ class EventCreationHandler(object):
event, context = yield self.create_new_client_event( event, context = yield self.create_new_client_event(
builder=builder, builder=builder,
requester=requester, requester=requester,
prev_event_ids=prev_event_ids, prev_events_and_hashes=prev_events_and_hashes,
) )
defer.returnValue((event, context)) defer.returnValue((event, context))
@ -588,39 +593,44 @@ class EventCreationHandler(object):
@measure_func("create_new_client_event") @measure_func("create_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
def create_new_client_event(self, builder, requester=None, prev_event_ids=None): def create_new_client_event(self, builder, requester=None,
if prev_event_ids: prev_events_and_hashes=None):
prev_events = yield self.store.add_event_hashes(prev_event_ids) """Create a new event for a local client
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
depth = prev_max_depth + 1 Args:
else: builder (EventBuilder):
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id, requester (synapse.types.Requester|None):
prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
the forward extremities to use as the prev_events for the
new event. For each event, a tuple of (event_id, hashes, depth)
where *hashes* is a map from algorithm to hash.
If None, they will be requested from the database.
Returns:
Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
"""
if prev_events_and_hashes is not None:
assert len(prev_events_and_hashes) <= 10, \
"Attempting to create an event with %i prev_events" % (
len(prev_events_and_hashes),
) )
else:
prev_events_and_hashes = \
yield self.store.get_prev_events_for_room(builder.room_id)
# We want to limit the max number of prev events we point to in our if prev_events_and_hashes:
# new event depth = max([d for _, _, d in prev_events_and_hashes]) + 1
if len(latest_ret) > 10: else:
# Sort by reverse depth, so we point to the most recent. depth = 1
latest_ret.sort(key=lambda a: -a[2])
new_latest_ret = latest_ret[:5]
# We also randomly point to some of the older events, to make prev_events = [
# sure that we don't completely ignore the older events. (event_id, prev_hashes)
if latest_ret[5:]: for event_id, prev_hashes, _ in prev_events_and_hashes
sample_size = min(5, len(latest_ret[5:])) ]
new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
latest_ret = new_latest_ret
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth

View File

@ -149,7 +149,7 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _local_membership_update( def _local_membership_update(
self, requester, target, room_id, membership, self, requester, target, room_id, membership,
prev_event_ids, prev_events_and_hashes,
txn_id=None, txn_id=None,
ratelimit=True, ratelimit=True,
content=None, content=None,
@ -175,7 +175,7 @@ class RoomMemberHandler(object):
}, },
token_id=requester.access_token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
prev_event_ids=prev_event_ids, prev_events_and_hashes=prev_events_and_hashes,
) )
# Check if this event matches the previous membership event for the user. # Check if this event matches the previous membership event for the user.
@ -314,7 +314,12 @@ class RoomMemberHandler(object):
403, "Invites have been disabled on this server", 403, "Invites have been disabled on this server",
) )
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) prev_events_and_hashes = yield self.store.get_prev_events_for_room(
room_id,
)
latest_event_ids = (
event_id for (event_id, _, _) in prev_events_and_hashes
)
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids, room_id, latest_event_ids=latest_event_ids,
) )
@ -403,7 +408,7 @@ class RoomMemberHandler(object):
membership=effective_membership_state, membership=effective_membership_state,
txn_id=txn_id, txn_id=txn_id,
ratelimit=ratelimit, ratelimit=ratelimit,
prev_event_ids=latest_event_ids, prev_events_and_hashes=prev_events_and_hashes,
content=content, content=content,
) )
defer.returnValue(res) defer.returnValue(res)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
from twisted.internet import defer from twisted.internet import defer
@ -133,7 +134,47 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
retcol="event_id", retcol="event_id",
) )
@defer.inlineCallbacks
def get_prev_events_for_room(self, room_id):
"""
Gets a subset of the current forward extremities in the given room.
Limits the result to 10 extremities, so that we can avoid creating
events which refer to hundreds of prev_events.
Args:
room_id (str): room_id
Returns:
Deferred[list[(str, dict[str, str], int)]]
for each event, a tuple of (event_id, hashes, depth)
where *hashes* is a map from algorithm to hash.
"""
res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
if len(res) > 10:
# Sort by reverse depth, so we point to the most recent.
res.sort(key=lambda a: -a[2])
# we use half of the limit for the actual most recent events, and
# the other half to randomly point to some of the older events, to
# make sure that we don't completely ignore the older events.
res = res[0:5] + random.sample(res[5:], 5)
defer.returnValue(res)
def get_latest_event_ids_and_hashes_in_room(self, room_id): def get_latest_event_ids_and_hashes_in_room(self, room_id):
"""
Gets the current forward extremities in the given room
Args:
room_id (str): room_id
Returns:
Deferred[list[(str, dict[str, str], int)]]
for each event, a tuple of (event_id, hashes, depth)
where *hashes* is a map from algorithm to hash.
"""
return self.runInteraction( return self.runInteraction(
"get_latest_event_ids_and_hashes_in_room", "get_latest_event_ids_and_hashes_in_room",
self._get_latest_event_ids_and_hashes_in_room, self._get_latest_event_ids_and_hashes_in_room,
@ -182,22 +223,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
room_id, room_id,
) )
@defer.inlineCallbacks
def get_max_depth_of_events(self, event_ids):
sql = (
"SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
) % (",".join(["?"] * len(event_ids)),)
rows = yield self._execute(
"get_max_depth_of_events", None,
sql, *event_ids
)
if rows:
defer.returnValue(rows[0][0])
else:
defer.returnValue(1)
def _get_min_depth_interaction(self, txn, room_id): def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn( min_depth = self._simple_select_one_onecol_txn(
txn, txn,

View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.unittest
import tests.utils
class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_get_prev_events_for_room(self):
room_id = '@ROOM:local'
# add a bunch of events and hashes to act as forward extremities
def insert_event(txn, i):
event_id = '$event_%i:local' % i
txn.execute((
"INSERT INTO events ("
" room_id, event_id, type, depth, topological_ordering,"
" content, processed, outlier) "
"VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
), (room_id, event_id, i, i, True, False))
txn.execute((
'INSERT INTO event_forward_extremities (room_id, event_id) '
'VALUES (?, ?)'
), (room_id, event_id))
txn.execute((
'INSERT INTO event_reference_hashes '
'(event_id, algorithm, hash) '
"VALUES (?, 'sha256', ?)"
), (event_id, 'ffff'))
for i in range(0, 11):
yield self.store.runInteraction("insert", insert_event, i)
# this should get the last five and five others
r = yield self.store.get_prev_events_for_room(room_id)
self.assertEqual(10, len(r))
for i in range(0, 5):
el = r[i]
depth = el[2]
self.assertEqual(10 - i, depth)
for i in range(5, 5):
el = r[i]
depth = el[2]
self.assertLessEqual(5, depth)