Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

michaelkaye/matrix_org_hotfixes_increase_replication_timeout
Erik Johnston 2020-01-07 14:24:36 +00:00
commit 45bf455948
21 changed files with 274 additions and 175 deletions

changelog.d/6629.misc (new file)
@@ -0,0 +1 @@
+Simplify event creation code by removing redundant queries on the event_reference_hashes table.

changelog.d/6645.bugfix (new file)
@@ -0,0 +1 @@
+Fix exceptions in the synchrotron worker log when events are rejected.

changelog.d/6647.misc (new file)
@@ -0,0 +1 @@
+Port core background update routines to async/await.

changelog.d/6648.bugfix (new file)
@@ -0,0 +1 @@
+Ensure that upgraded rooms are removed from the directory.

changelog.d/6652.bugfix (new file)
@@ -0,0 +1 @@
+Fix a bug causing Synapse not to fetch missing events when it believes it has every event in the room.

changelog.d/6653.misc (new file)
@@ -0,0 +1 @@
+Port core background update routines to async/await.

@@ -166,6 +166,11 @@ class Store(
             logger.exception("Failed to insert: %s", table)
             raise

+    def set_room_is_public(self, room_id, is_public):
+        raise Exception(
+            "Attempt to set room_is_public during port_db: database not empty?"
+        )
+

 class MockHomeserver:
     def __init__(self, config):

@@ -48,7 +48,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams.events import EventsStreamEventRow
+from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
 from synapse.rest.client.v1 import events
 from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
 from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
@@ -371,8 +371,7 @@ class SyncReplicationHandler(ReplicationClientHandler):
     def get_currently_syncing_users(self):
         return self.presence_handler.get_currently_syncing_users()

-    @defer.inlineCallbacks
-    def process_and_notify(self, stream_name, token, rows):
+    async def process_and_notify(self, stream_name, token, rows):
         try:
             if stream_name == "events":
                 # We shouldn't get multiple rows per token for events stream, so
@@ -380,7 +379,14 @@ class SyncReplicationHandler(ReplicationClientHandler):
                 for row in rows:
                     if row.type != EventsStreamEventRow.TypeId:
                         continue
-                    event = yield self.store.get_event(row.data.event_id)
+                    assert isinstance(row, EventsStreamRow)
+
+                    event = await self.store.get_event(
+                        row.data.event_id, allow_rejected=True
+                    )
+                    if event.rejected_reason:
+                        continue
+
                     extra_users = ()
                     if event.type == EventTypes.Member:
                         extra_users = (event.state_key,)
@@ -412,11 +418,11 @@ class SyncReplicationHandler(ReplicationClientHandler):
             elif stream_name == "device_lists":
                 all_room_ids = set()
                 for row in rows:
-                    room_ids = yield self.store.get_rooms_for_user(row.user_id)
+                    room_ids = await self.store.get_rooms_for_user(row.user_id)
                     all_room_ids.update(room_ids)
                 self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
             elif stream_name == "presence":
-                yield self.presence_handler.process_replication_rows(token, rows)
+                await self.presence_handler.process_replication_rows(token, rows)
             elif stream_name == "receipts":
                 self.notifier.on_new_event(
                     "groups_key", token, users=[row.user_id for row in rows]

View File

@ -248,13 +248,13 @@ class FederationHandler(BaseHandler):
prevs = set(pdu.prev_event_ids()) prevs = set(pdu.prev_event_ids())
seen = await self.store.have_seen_events(prevs) seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth: if min_depth is not None and pdu.depth < min_depth:
# This is so that we don't notify the user about this # This is so that we don't notify the user about this
# message, to work around the fact that some events will # message, to work around the fact that some events will
# reference really really old events we really don't want to # reference really really old events we really don't want to
# send to the clients. # send to the clients.
pdu.internal_metadata.outlier = True pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth: elif min_depth is not None and pdu.depth > min_depth:
missing_prevs = prevs - seen missing_prevs = prevs - seen
if sent_to_us_directly and missing_prevs: if sent_to_us_directly and missing_prevs:
# If we're missing stuff, ensure we only fetch stuff one # If we're missing stuff, ensure we only fetch stuff one
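Reviewer note: the switch from a truthiness check to an explicit None comparison matters because a stored min_depth of 0 is falsy, so the old check silently skipped the "fetch missing prev_events" branch. A minimal sketch (values invented for illustration):

min_depth = 0  # a legitimate stored value, not "unknown"
pdu_depth = 5

# old check: falls through when min_depth == 0, so missing events are never fetched
if min_depth and pdu_depth > min_depth:
    print("old check: would fetch missing prev_events")

# new check: only skips when min_depth is genuinely unknown (None)
if min_depth is not None and pdu_depth > min_depth:
    print("new check: fetches missing prev_events as expected")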

@@ -48,7 +48,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, UserID, create_requester
+from synapse.types import Collection, RoomAlias, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.metrics import measure_func
@@ -422,7 +422,7 @@ class EventCreationHandler(object):
         event_dict,
         token_id=None,
         txn_id=None,
-        prev_events_and_hashes=None,
+        prev_event_ids: Optional[Collection[str]] = None,
         require_consent=True,
     ):
         """
@@ -439,10 +439,9 @@ class EventCreationHandler(object):
             token_id (str)
             txn_id (str)

-            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+            prev_event_ids:
                 the forward extremities to use as the prev_events for the
-                new event. For each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
+                new event.

                 If None, they will be requested from the database.
@@ -498,9 +497,7 @@ class EventCreationHandler(object):
             builder.internal_metadata.txn_id = txn_id

         event, context = yield self.create_new_client_event(
-            builder=builder,
-            requester=requester,
-            prev_events_and_hashes=prev_events_and_hashes,
+            builder=builder, requester=requester, prev_event_ids=prev_event_ids,
         )

         # In an ideal world we wouldn't need the second part of this condition. However,
@@ -714,7 +711,7 @@ class EventCreationHandler(object):
     @measure_func("create_new_client_event")
     @defer.inlineCallbacks
     def create_new_client_event(
-        self, builder, requester=None, prev_events_and_hashes=None
+        self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None
     ):
         """Create a new event for a local client

@@ -723,10 +720,9 @@ class EventCreationHandler(object):
             requester (synapse.types.Requester|None):

-            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+            prev_event_ids:
                 the forward extremities to use as the prev_events for the
-                new event. For each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
+                new event.

                 If None, they will be requested from the database.
@@ -734,22 +730,15 @@ class EventCreationHandler(object):
             Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
         """

-        if prev_events_and_hashes is not None:
-            assert len(prev_events_and_hashes) <= 10, (
+        if prev_event_ids is not None:
+            assert len(prev_event_ids) <= 10, (
                 "Attempting to create an event with %i prev_events"
-                % (len(prev_events_and_hashes),)
+                % (len(prev_event_ids),)
             )
         else:
-            prev_events_and_hashes = yield self.store.get_prev_events_for_room(
-                builder.room_id
-            )
+            prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id)

-        prev_events = [
-            (event_id, prev_hashes)
-            for event_id, prev_hashes, _ in prev_events_and_hashes
-        ]
-
-        event = yield builder.build(prev_event_ids=[p for p, _ in prev_events])
+        event = yield builder.build(prev_event_ids=prev_event_ids)
         context = yield self.state.compute_event_context(event)
         if requester:
             context.app_service = requester.app_service
@@ -1042,9 +1031,7 @@ class EventCreationHandler(object):
                 # For each room we need to find a joined member we can use to send
                 # the dummy event with.

-                prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
-                latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
+                latest_event_ids = yield self.store.get_prev_events_for_room(room_id)

                 members = yield self.state.get_current_users_in_room(
                     room_id, latest_event_ids=latest_event_ids
@@ -1063,7 +1050,7 @@ class EventCreationHandler(object):
                         "room_id": room_id,
                         "sender": user_id,
                     },
-                    prev_events_and_hashes=prev_events_and_hashes,
+                    prev_event_ids=latest_event_ids,
                 )

                 event.internal_metadata.proactively_send = False

@@ -25,7 +25,7 @@ from twisted.internet import defer
 from synapse import types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.types import RoomID, UserID
+from synapse.types import Collection, RoomID, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
@@ -150,7 +150,7 @@ class RoomMemberHandler(object):
         target,
         room_id,
         membership,
-        prev_events_and_hashes,
+        prev_event_ids: Collection[str],
         txn_id=None,
         ratelimit=True,
         content=None,
@@ -178,7 +178,7 @@ class RoomMemberHandler(object):
             },
             token_id=requester.access_token_id,
             txn_id=txn_id,
-            prev_events_and_hashes=prev_events_and_hashes,
+            prev_event_ids=prev_event_ids,
             require_consent=require_consent,
         )
@@ -390,8 +390,7 @@ class RoomMemberHandler(object):
         if block_invite:
             raise SynapseError(403, "Invites have been disabled on this server")

-        prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
-        latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
+        latest_event_ids = yield self.store.get_prev_events_for_room(room_id)

         current_state_ids = yield self.state_handler.get_current_state_ids(
             room_id, latest_event_ids=latest_event_ids
@@ -505,7 +504,7 @@ class RoomMemberHandler(object):
             membership=effective_membership_state,
             txn_id=txn_id,
             ratelimit=ratelimit,
-            prev_events_and_hashes=prev_events_and_hashes,
+            prev_event_ids=latest_event_ids,
             content=content,
             require_consent=require_consent,
         )

@@ -16,6 +16,7 @@
 # limitations under the License.
 import logging
 import random
+from abc import ABCMeta

 from six import PY2
 from six.moves import builtins
@@ -30,7 +31,8 @@ from synapse.types import get_domain_from_id
 logger = logging.getLogger(__name__)


-class SQLBaseStore(object):
+# some of our subclasses have abstract methods, so we use the ABCMeta metaclass.
+class SQLBaseStore(metaclass=ABCMeta):
     """Base class for data stores that holds helper functions.

     Note that multiple instances of this class will exist as there will be one
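Reviewer note: the ABCMeta metaclass is what lets store subclasses declare @abstractmethod members (as RoomBackgroundUpdateStore.set_room_is_public does further down) while concrete leaf classes remain instantiable. A minimal sketch of the pattern, with invented class names standing in for the real store hierarchy:

from abc import ABCMeta, abstractmethod


class BaseStore(metaclass=ABCMeta):
    """Shared helpers (stands in for SQLBaseStore)."""


class DirectoryStoreMixin(BaseStore):
    @abstractmethod
    def set_room_is_public(self, room_id, is_public):
        """Must be provided by whichever concrete store composes this mixin."""


class MasterStore(DirectoryStoreMixin):
    def set_room_is_public(self, room_id, is_public):
        print("setting %s is_public=%s" % (room_id, is_public))


MasterStore().set_room_is_public("!room:example.org", False)
# DirectoryStoreMixin() would raise TypeError: can't instantiate abstract class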

@@ -14,6 +14,7 @@
 # limitations under the License.

 import logging
+from typing import Optional

 from canonicaljson import json
@@ -97,15 +98,14 @@ class BackgroundUpdater(object):
     def start_doing_background_updates(self):
         run_as_background_process("background_updates", self.run_background_updates)

-    @defer.inlineCallbacks
-    def run_background_updates(self, sleep=True):
+    async def run_background_updates(self, sleep=True):
         logger.info("Starting background schema updates")
         while True:
             if sleep:
-                yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+                await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)

             try:
-                result = yield self.do_next_background_update(
+                result = await self.do_next_background_update(
                     self.BACKGROUND_UPDATE_DURATION_MS
                 )
             except Exception:
@@ -170,20 +170,21 @@ class BackgroundUpdater(object):
         return not update_exists

-    @defer.inlineCallbacks
-    def do_next_background_update(self, desired_duration_ms):
+    async def do_next_background_update(
+        self, desired_duration_ms: float
+    ) -> Optional[int]:
         """Does some amount of work on the next queued background update

+        Returns once some amount of work is done.
+
         Args:
             desired_duration_ms(float): How long we want to spend
                 updating.
         Returns:
-            A deferred that completes once some amount of work is done.
-            The deferred will have a value of None if there is currently
-            no more work to do.
+            None if there is no more work to do, otherwise an int
         """
         if not self._background_update_queue:
-            updates = yield self.db.simple_select_list(
+            updates = await self.db.simple_select_list(
                 "background_updates",
                 keyvalues=None,
                 retcols=("update_name", "depends_on"),
@@ -201,11 +202,12 @@ class BackgroundUpdater(object):
         update_name = self._background_update_queue.pop(0)
         self._background_update_queue.append(update_name)

-        res = yield self._do_background_update(update_name, desired_duration_ms)
+        res = await self._do_background_update(update_name, desired_duration_ms)
         return res

-    @defer.inlineCallbacks
-    def _do_background_update(self, update_name, desired_duration_ms):
+    async def _do_background_update(
+        self, update_name: str, desired_duration_ms: float
+    ) -> int:
         logger.info("Starting update batch on background update '%s'", update_name)

         update_handler = self._background_update_handlers[update_name]
@@ -225,7 +227,7 @@ class BackgroundUpdater(object):
         else:
             batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE

-        progress_json = yield self.db.simple_select_one_onecol(
+        progress_json = await self.db.simple_select_one_onecol(
             "background_updates",
             keyvalues={"update_name": update_name},
             retcol="progress_json",
@@ -234,7 +236,7 @@ class BackgroundUpdater(object):
         progress = json.loads(progress_json)

         time_start = self._clock.time_msec()
-        items_updated = yield update_handler(progress, batch_size)
+        items_updated = await update_handler(progress, batch_size)
         time_stop = self._clock.time_msec()

         duration_ms = time_stop - time_start
@@ -263,7 +265,9 @@ class BackgroundUpdater(object):
             * A dict of the current progress
             * An integer count of the number of items to update in this batch.

-        The handler should return a deferred integer count of items updated.
+        The handler should return a deferred or coroutine which returns an integer count
+        of items updated.
+
         The handler is responsible for updating the progress of the update.

         Args:
@@ -432,6 +436,21 @@ class BackgroundUpdater(object):
             "background_updates", keyvalues={"update_name": update_name}
         )

+    def _background_update_progress(self, update_name: str, progress: dict):
+        """Update the progress of a background update
+
+        Args:
+            update_name: The name of the background update task
+            progress: The progress of the update.
+        """
+
+        return self.db.runInteraction(
+            "background_update_progress",
+            self._background_update_progress_txn,
+            update_name,
+            progress,
+        )
+
     def _background_update_progress_txn(self, txn, update_name, progress):
         """Update the progress of a background update

@@ -14,13 +14,10 @@
 # limitations under the License.
 import itertools
 import logging
-import random

 from six.moves import range
 from six.moves.queue import Empty, PriorityQueue

-from unpaddedbase64 import encode_base64
-
 from twisted.internet import defer

 from synapse.api.errors import StoreError
@@ -148,8 +145,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             retcol="event_id",
         )

-    @defer.inlineCallbacks
-    def get_prev_events_for_room(self, room_id):
+    def get_prev_events_for_room(self, room_id: str):
         """
         Gets a subset of the current forward extremities in the given room.
@@ -160,41 +156,30 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             room_id (str): room_id

         Returns:
-            Deferred[list[(str, dict[str, str], int)]]
-                for each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
-        """
-
-        res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
-        if len(res) > 10:
-            # Sort by reverse depth, so we point to the most recent.
-            res.sort(key=lambda a: -a[2])
-
-            # we use half of the limit for the actual most recent events, and
-            # the other half to randomly point to some of the older events, to
-            # make sure that we don't completely ignore the older events.
-            res = res[0:5] + random.sample(res[5:], 5)
-
-        return res
-
-    def get_latest_event_ids_and_hashes_in_room(self, room_id):
-        """
-        Gets the current forward extremities in the given room
-
-        Args:
-            room_id (str): room_id
-
-        Returns:
-            Deferred[list[(str, dict[str, str], int)]]
-                for each event, a tuple of (event_id, hashes, depth)
-                where *hashes* is a map from algorithm to hash.
+            Deferred[List[str]]: the event ids of the forward extremites
         """

         return self.db.runInteraction(
-            "get_latest_event_ids_and_hashes_in_room",
-            self._get_latest_event_ids_and_hashes_in_room,
-            room_id,
+            "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
         )

+    def _get_prev_events_for_room_txn(self, txn, room_id: str):
+        # we just use the 10 newest events. Older events will become
+        # prev_events of future events.
+
+        sql = """
+            SELECT e.event_id FROM event_forward_extremities AS f
+            INNER JOIN events AS e USING (event_id)
+            WHERE f.room_id = ?
+            ORDER BY e.depth DESC
+            LIMIT 10
+        """
+
+        txn.execute(sql, (room_id,))
+
+        return [row[0] for row in txn]
+
     def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
         """Get the top rooms with at least N extremities.
@@ -243,27 +228,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             desc="get_latest_event_ids_in_room",
         )

-    def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
-        sql = (
-            "SELECT e.event_id, e.depth FROM events as e "
-            "INNER JOIN event_forward_extremities as f "
-            "ON e.event_id = f.event_id "
-            "AND e.room_id = f.room_id "
-            "WHERE f.room_id = ?"
-        )
-
-        txn.execute(sql, (room_id,))
-
-        results = []
-        for event_id, depth in txn.fetchall():
-            hashes = self._get_event_reference_hashes_txn(txn, event_id)
-            prev_hashes = {
-                k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
-            }
-            results.append((event_id, prev_hashes, depth))
-        return results
-
     def get_min_depth(self, room_id):
         """ For hte given room, get the minimum depth we have seen for it.
         """
@@ -506,7 +470,7 @@ class EventFederationStore(EventFederationWorkerStore):
     def _update_min_depth_for_room_txn(self, txn, room_id, depth):
         min_depth = self._get_min_depth_interaction(txn, room_id)

-        if min_depth and depth >= min_depth:
+        if min_depth is not None and depth >= min_depth:
             return

         self.db.simple_upsert_txn(

@@ -137,7 +137,7 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_event(
         self,
-        event_id: List[str],
+        event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
@@ -148,15 +148,22 @@ class EventsWorkerStore(SQLBaseStore):
         Args:
             event_id: The event_id of the event to fetch
+
             redact_behaviour: Determine what to do with a redacted event. Possible values:
                 * AS_IS - Return the full event body with no redacted content
                 * REDACT - Return the event but with a redacted body
-                * DISALLOW - Do not return redacted events
+                * DISALLOW - Do not return redacted events (behave as per allow_none
+                    if the event is redacted)
+
             get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected: If True return rejected events.
+
+            allow_rejected: If True, return rejected events. Otherwise,
+                behave as per allow_none.
+
             allow_none: If True, return None if no event found, if
                 False throw a NotFoundError
+
             check_room_id: if not None, check the room of the found event.
                 If there is a mismatch, behave as per allow_none.
@@ -196,14 +203,18 @@ class EventsWorkerStore(SQLBaseStore):
         Args:
             event_ids: The event_ids of the events to fetch
+
             redact_behaviour: Determine what to do with a redacted event. Possible
                 values:
                 * AS_IS - Return the full event body with no redacted content
                 * REDACT - Return the event but with a redacted body
-                * DISALLOW - Do not return redacted events
+                * DISALLOW - Do not return redacted events (omit them from the response)
+
             get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected: If True return rejected events.
+
+            allow_rejected: If True, return rejected events. Otherwise,
+                omits rejeted events from the response.

         Returns:
             Deferred : Dict from event_id to event.
@@ -228,15 +239,21 @@ class EventsWorkerStore(SQLBaseStore):
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.

+        Unknown events will be omitted from the response.
+
         Args:
             event_ids: The event_ids of the events to fetch
+
             redact_behaviour: Determine what to do with a redacted event. Possible values:
                 * AS_IS - Return the full event body with no redacted content
                 * REDACT - Return the event but with a redacted body
-                * DISALLOW - Do not return redacted events
+                * DISALLOW - Do not return redacted events (omit them from the response)
+
             get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected: If True, return rejected events.
+
+            allow_rejected: If True, return rejected events. Otherwise,
+                omits rejected events from the response.

         Returns:
             Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -369,9 +386,14 @@ class EventsWorkerStore(SQLBaseStore):

         If events are pulled from the database, they will be cached for future lookups.

+        Unknown events are omitted from the response.
+
         Args:
             event_ids (Iterable[str]): The event_ids of the events to fetch
-            allow_rejected (bool): Whether to include rejected events
+
+            allow_rejected (bool): Whether to include rejected events. If False,
+                rejected events are omitted from the response.

         Returns:
             Deferred[Dict[str, _EventCacheEntry]]:
@@ -506,9 +528,13 @@ class EventsWorkerStore(SQLBaseStore):

         Returned events will be added to the cache for future lookups.

+        Unknown events are omitted from the response.
+
         Args:
             event_ids (Iterable[str]): The event_ids of the events to fetch
-            allow_rejected (bool): Whether to include rejected events
+
+            allow_rejected (bool): Whether to include rejected events. If False,
+                rejected events are omitted from the response.

         Returns:
             Deferred[Dict[str, _EventCacheEntry]]:

@@ -17,6 +17,7 @@
 import collections
 import logging
 import re
+from abc import abstractmethod
 from typing import Optional, Tuple

 from six import integer_types
@@ -367,6 +368,8 @@ class RoomWorkerStore(SQLBaseStore):


 class RoomBackgroundUpdateStore(SQLBaseStore):
+    REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
+
     def __init__(self, database: Database, db_conn, hs):
         super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
@@ -376,6 +379,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             "insert_room_retention", self._background_insert_retention,
         )

+        self.db.updates.register_background_update_handler(
+            self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
+            self._remove_tombstoned_rooms_from_directory,
+        )
+
     @defer.inlineCallbacks
     def _background_insert_retention(self, progress, batch_size):
         """Retrieves a list of all rooms within a range and inserts an entry for each of
@@ -444,6 +452,62 @@ class RoomBackgroundUpdateStore(SQLBaseStore):

         defer.returnValue(batch_size)

+    async def _remove_tombstoned_rooms_from_directory(
+        self, progress, batch_size
+    ) -> int:
+        """Removes any rooms with tombstone events from the room directory
+
+        Nowadays this is handled by the room upgrade handler, but we may have some
+        that got left behind
+        """
+
+        last_room = progress.get("room_id", "")
+
+        def _get_rooms(txn):
+            txn.execute(
+                """
+                SELECT room_id
+                FROM rooms r
+                INNER JOIN current_state_events cse USING (room_id)
+                WHERE room_id > ? AND r.is_public
+                AND cse.type = '%s' AND cse.state_key = ''
+                ORDER BY room_id ASC
+                LIMIT ?;
+                """
+                % EventTypes.Tombstone,
+                (last_room, batch_size),
+            )
+
+            return [row[0] for row in txn]
+
+        rooms = await self.db.runInteraction(
+            "get_tombstoned_directory_rooms", _get_rooms
+        )
+
+        if not rooms:
+            await self.db.updates._end_background_update(
+                self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
+            )
+            return 0
+
+        for room_id in rooms:
+            logger.info("Removing tombstoned room %s from the directory", room_id)
+            await self.set_room_is_public(room_id, False)
+
+        await self.db.updates._background_update_progress(
+            self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
+        )
+
+        return len(rooms)
+
+    @abstractmethod
+    def set_room_is_public(self, room_id, is_public):
+        # this will need to be implemented if a background update is performed with
+        # existing (tombstoned, public) rooms in the database.
+        #
+        # It's overridden by RoomStore for the synapse master.
+        raise NotImplementedError()
+

 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: Database, db_conn, hs):

@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Now that #6232 is a thing, we can remove old rooms from the directory.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('remove_tombstoned_rooms_from_directory', '{}');

@@ -15,6 +15,7 @@
 # limitations under the License.

 import re
 import string
+import sys
 from collections import namedtuple

 import attr
@@ -23,6 +24,17 @@ from unpaddedbase64 import decode_base64

 from synapse.api.errors import SynapseError

+# define a version of typing.Collection that works on python 3.5
+if sys.version_info[:3] >= (3, 6, 0):
+    from typing import Collection
+else:
+    from typing import Sized, Iterable, Container, TypeVar
+
+    T_co = TypeVar("T_co", covariant=True)
+
+    class Collection(Iterable[T_co], Container[T_co], Sized):
+        __slots__ = ()
+
+
 class Requester(
     namedtuple(
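Reviewer note: this shim exists so the new annotations elsewhere in the commit (e.g. prev_event_ids: Optional[Collection[str]]) also import on Python 3.5, where typing.Collection is missing. A quick sketch of what the hint accepts, assuming the shim above is importable as synapse.types.Collection (the function below is invented for illustration):

from typing import Optional

from synapse.types import Collection


def count_prev_events(prev_event_ids: Optional[Collection[str]]) -> int:
    # any sized, iterable container of event ids satisfies Collection[str]
    if prev_event_ids is None:
        return 0
    return len(prev_event_ids)


print(count_prev_events(["$a:example.org", "$b:example.org"]))  # 2
print(count_prev_events(("$a:example.org",)))                   # 1
print(count_prev_events(None))                                  # 0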

@@ -2,44 +2,37 @@ from mock import Mock

 from twisted.internet import defer

+from synapse.storage.background_updates import BackgroundUpdater
+
 from tests import unittest
-from tests.utils import setup_test_homeserver


-class BackgroundUpdateTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
+class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, homeserver):
+        self.updates = self.hs.get_datastore().db.updates  # type: BackgroundUpdater
+        # the base test class should have run the real bg updates for us
+        self.assertTrue(self.updates.has_completed_background_updates())

         self.update_handler = Mock()
-
-        yield self.store.db.updates.register_background_update_handler(
+        self.updates.register_background_update_handler(
             "test_update", self.update_handler
         )

-        # run the real background updates, to get them out the way
-        # (perhaps we should run them as part of the test HS setup, since we
-        # run all of the other schema setup stuff there?)
-        while True:
-            res = yield self.store.db.updates.do_next_background_update(1000)
-            if res is None:
-                break
-
-    @defer.inlineCallbacks
     def test_do_background_update(self):
-        desired_count = 1000
-
+        # the time we claim each update takes
         duration_ms = 42

+        # the target runtime for each bg update
+        target_background_update_duration_ms = 50000
+
         # first step: make a bit of progress
         @defer.inlineCallbacks
         def update(progress, count):
-            self.clock.advance_time_msec(count * duration_ms)
+            yield self.clock.sleep((count * duration_ms) / 1000)
             progress = {"my_key": progress["my_key"] + 1}
-            yield self.store.db.runInteraction(
+            yield self.hs.get_datastore().db.runInteraction(
                 "update_progress",
-                self.store.db.updates._background_update_progress_txn,
+                self.updates._background_update_progress_txn,
                 "test_update",
                 progress,
             )
@@ -47,37 +40,46 @@ class BackgroundUpdateTestCase(unittest.TestCase):

         self.update_handler.side_effect = update

-        yield self.store.db.updates.start_background_update(
-            "test_update", {"my_key": 1}
+        self.get_success(
+            self.updates.start_background_update("test_update", {"my_key": 1})
         )
         self.update_handler.reset_mock()
-        result = yield self.store.db.updates.do_next_background_update(
-            duration_ms * desired_count
+        res = self.get_success(
+            self.updates.do_next_background_update(
+                target_background_update_duration_ms
+            ),
+            by=0.1,
         )
-        self.assertIsNotNone(result)
+        self.assertIsNotNone(res)
+
+        # on the first call, we should get run with the default background update size
         self.update_handler.assert_called_once_with(
-            {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE
+            {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE
         )

         # second step: complete the update
+        # we should now get run with a much bigger number of items to update
         @defer.inlineCallbacks
         def update(progress, count):
-            yield self.store.db.updates._end_background_update("test_update")
+            self.assertEqual(progress, {"my_key": 2})
+            self.assertAlmostEqual(
+                count, target_background_update_duration_ms / duration_ms, places=0,
+            )
+            yield self.updates._end_background_update("test_update")
             return count

         self.update_handler.side_effect = update
         self.update_handler.reset_mock()
-        result = yield self.store.db.updates.do_next_background_update(
-            duration_ms * desired_count
+        result = self.get_success(
+            self.updates.do_next_background_update(target_background_update_duration_ms)
         )
         self.assertIsNotNone(result)
-        self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
+        self.update_handler.assert_called_once()

         # third step: we don't expect to be called any more
         self.update_handler.reset_mock()
-        result = yield self.store.db.updates.do_next_background_update(
-            duration_ms * desired_count
+        result = self.get_success(
+            self.updates.do_next_background_update(target_background_update_duration_ms)
         )
         self.assertIsNone(result)
         self.assertFalse(self.update_handler.called)

@@ -60,21 +60,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
                 (event_id, bytearray(b"ffff")),
             )

-        for i in range(0, 11):
+        for i in range(0, 20):
             yield self.store.db.runInteraction("insert", insert_event, i)

-        # this should get the last five and five others
+        # this should get the last ten
         r = yield self.store.get_prev_events_for_room(room_id)
         self.assertEqual(10, len(r))
-        for i in range(0, 5):
-            el = r[i]
-            depth = el[2]
-            self.assertEqual(10 - i, depth)
-
-        for i in range(5, 10):
-            el = r[i]
-            depth = el[2]
-            self.assertLessEqual(5, depth)
+        for i in range(0, 10):
+            self.assertEqual("$event_%i:local" % (19 - i), r[i])

     @defer.inlineCallbacks
     def test_get_rooms_with_many_extremities(self):

@@ -531,10 +531,6 @@ class HomeserverTestCase(TestCase):
         secrets = self.hs.get_secrets()
         requester = Requester(user, None, False, None, None)

-        prev_events_and_hashes = None
-        if prev_event_ids:
-            prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
-
         event, context = self.get_success(
             event_creator.create_event(
                 requester,
@@ -544,7 +540,7 @@ class HomeserverTestCase(TestCase):
                     "sender": user.to_string(),
                     "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
                 },
-                prev_events_and_hashes=prev_events_and_hashes,
+                prev_event_ids=prev_event_ids,
             )
         )