Merge branch 'develop' into matrix-org-hotfixes

pull/8675/head
Brendan Abolivier 2020-09-04 11:02:10 +01:00
commit cc23d81a74
33 changed files with 400 additions and 379 deletions

View File

@ -1 +0,0 @@
Add experimental support for sharding event persister.

1
changelog.d/8233.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

1
changelog.d/8240.misc Normal file
View File

@ -0,0 +1 @@
Fix type hints for functions decorated with `@cached`.

1
changelog.d/8242.feature Normal file
View File

@ -0,0 +1 @@
Back out experimental support for sharding event persister. **PLEASE REMOVE THIS LINE FROM THE FINAL CHANGELOG**

1
changelog.d/8244.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to pagination, initial sync and events handlers.

1
changelog.d/8245.misc Normal file
View File

@ -0,0 +1 @@
Remove obsolete `order` field from federation send queues.

View File

@ -1,6 +1,6 @@
[mypy] [mypy]
namespace_packages = True namespace_packages = True
plugins = mypy_zope:plugin plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = silent follow_imports = silent
check_untyped_defs = True check_untyped_defs = True
show_error_codes = True show_error_codes = True
@ -17,10 +17,13 @@ files =
synapse/handlers/auth.py, synapse/handlers/auth.py,
synapse/handlers/cas_handler.py, synapse/handlers/cas_handler.py,
synapse/handlers/directory.py, synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py, synapse/handlers/federation.py,
synapse/handlers/identity.py, synapse/handlers/identity.py,
synapse/handlers/initial_sync.py,
synapse/handlers/message.py, synapse/handlers/message.py,
synapse/handlers/oidc_handler.py, synapse/handlers/oidc_handler.py,
synapse/handlers/pagination.py,
synapse/handlers/presence.py, synapse/handlers/presence.py,
synapse/handlers/room.py, synapse/handlers/room.py,
synapse/handlers/room_member.py, synapse/handlers/room_member.py,
@ -51,6 +54,7 @@ files =
synapse/storage/util, synapse/storage/util,
synapse/streams, synapse/streams,
synapse/types.py, synapse/types.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/stream_change_cache.py, synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py, synapse/util/metrics.py,
tests/replication, tests/replication,

View File

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is a mypy plugin for Synapse to deal with some of the funky typing that
can crop up, e.g. the cache descriptors.
"""
from typing import Callable, Optional
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType
class SynapsePlugin(Plugin):
    """Mypy plugin that corrects method signatures mypy gets wrong for Synapse.

    The only hook installed is for calls to `_CachedFunction.__call__`.
    """

    def get_method_signature_hook(
        self, fullname: str
    ) -> Optional[Callable[[MethodSigContext], CallableType]]:
        """Pick the signature fixer for the method called `fullname`.

        Args:
            fullname: Fully qualified name of the method being checked.

        Returns:
            `cached_function_method_signature` when `fullname` starts with the
            qualified name of `_CachedFunction.__call__`, otherwise None.
        """
        cached_call_prefix = (
            "synapse.util.caches.descriptors._CachedFunction.__call__"
        )
        if not fullname.startswith(cached_call_prefix):
            return None
        return cached_function_method_signature
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
    """Fixes the `_CachedFunction.__call__` signature to be correct.

    It already has *almost* the correct signature, except:

        1. the `self` argument needs to be marked as "bound"; and
        2. any `cache_context` argument should be removed.

    Args:
        ctx: The mypy method-signature context; only its `default_signature`
            is read.

    Returns:
        The bound signature, with the first argument named `cache_context`
        (if there is one) dropped.
    """
    # First we mark this as a bound function signature.
    signature = bind_self(ctx.default_signature)

    # Secondly, we remove any "cache_context" args.
    #
    # Note: We should be only doing this if `cache_context=True` is set, but if
    # it isn't then the code will raise an exception when it's called anyway,
    # so it's not the end of the world.
    context_arg_index = None
    for idx, name in enumerate(signature.arg_names):
        if name == "cache_context":
            context_arg_index = idx
            break

    # Compare against None explicitly: index 0 is a valid position for the
    # argument but is falsy, so a plain truthiness check would skip it.
    if context_arg_index is not None:
        arg_types = list(signature.arg_types)
        arg_types.pop(context_arg_index)

        arg_names = list(signature.arg_names)
        arg_names.pop(context_arg_index)

        arg_kinds = list(signature.arg_kinds)
        arg_kinds.pop(context_arg_index)

        signature = signature.copy_modified(
            arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
        )

    return signature
def plugin(version: str):
    """Entry point of the plugin: hand mypy the plugin class to use.

    `version` is the mypy version string, which lets us deal with the fact that
    the mypy plugin interface is *not* stable. It is ignored here: since we pin
    the version of mypy Synapse uses in CI, we don't really care.
    """
    return SynapsePlugin

View File

@ -832,26 +832,11 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool: def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key. """Whether this instance is responsible for handling the given key.
""" """
# If multiple instances are not defined we always return true
# If multiple instances are not defined we always return true.
if not self.instances or len(self.instances) == 1: if not self.instances or len(self.instances) == 1:
return True return True
return self.get_instance(key) == instance_name
def get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key.
Note: For things like federation sending the config for which instance
is sending is known only to the sender instance if there is only one.
Therefore `should_handle` should be used where possible.
"""
if not self.instances:
return "master"
if len(self.instances) == 1:
return self.instances[0]
# We shard by taking the hash, modulo it by the number of instances and # We shard by taking the hash, modulo it by the number of instances and
# then checking whether this instance matches the instance at that # then checking whether this instance matches the instance at that
# index. # index.
@ -861,7 +846,7 @@ class ShardedWorkerHandlingConfig:
dest_hash = sha256(key.encode("utf8")).digest() dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little") dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances)) remainder = dest_int % (len(self.instances))
return self.instances[remainder] return self.instances[remainder] == instance_name
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]

View File

@ -142,4 +142,3 @@ class ShardedWorkerHandlingConfig:
instances: List[str] instances: List[str]
def __init__(self, instances: List[str]) -> None: ... def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ... def should_handle(self, instance_name: str, key: str) -> bool: ...
def get_instance(self, key: str) -> str: ...

View File

@ -13,24 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Union
import attr import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
from .server import ListenerConfig, parse_listener_def from .server import ListenerConfig, parse_listener_def
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config
option expecting a list of strings.
"""
if isinstance(obj, str):
return [obj]
return obj
@attr.s @attr.s
class InstanceLocationConfig: class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication. """The host and port to talk to an instance via HTTP replication.
@ -45,13 +33,11 @@ class WriterLocations:
"""Specifies the instances that write various streams. """Specifies the instances that write various streams.
Attributes: Attributes:
events: The instances that write to the event and backfill streams. events: The instance that writes to the event and backfill streams.
typing: The instance that writes to the typing stream. typing: The instance that writes to the typing stream.
""" """
events = attr.ib( events = attr.ib(default="master", type=str)
default=["master"], type=List[str], converter=_instance_to_list_converter
)
typing = attr.ib(default="master", type=str) typing = attr.ib(default="master", type=str)
@ -119,18 +105,15 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {} writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers) self.writers = WriterLocations(**writers)
# Check that the configured writers for events and typing also appears in # Check that the configured writer for events and typing also appears in
# `instance_map`. # `instance_map`.
for stream in ("events", "typing"): for stream in ("events", "typing"):
instances = _instance_to_list_converter(getattr(self.writers, stream)) instance = getattr(self.writers, stream)
for instance in instances: if instance != "master" and instance not in self.instance_map:
if instance != "master" and instance not in self.instance_map: raise ConfigError(
raise ConfigError( "Instance %r is configured to write %s but does not appear in `instance_map` config."
"Instance %r is configured to write %s but does not appear in `instance_map` config." % (instance, stream)
% (instance, stream) )
)
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\ return """\

View File

@ -108,8 +108,6 @@ class FederationSender(object):
), ),
) )
self._order = 1
self._is_processing = False self._is_processing = False
self._last_poked_id = -1 self._last_poked_id = -1
@ -290,9 +288,6 @@ class FederationSender(object):
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later. # table and we'll get back to it later.
order = self._order
self._order += 1
destinations = set(destinations) destinations = set(destinations)
destinations.discard(self.server_name) destinations.discard(self.server_name)
logger.debug("Sending to: %s", str(destinations)) logger.debug("Sending to: %s", str(destinations))
@ -304,7 +299,7 @@ class FederationSender(object):
sent_pdus_destination_dist_count.inc() sent_pdus_destination_dist_count.inc()
for destination in destinations: for destination in destinations:
self._get_per_destination_queue(destination).send_pdu(pdu, order) self._get_per_destination_queue(destination).send_pdu(pdu)
async def send_read_receipt(self, receipt: ReadReceipt) -> None: async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""Send a RR to any other servers in the room """Send a RR to any other servers in the room

View File

@ -95,8 +95,8 @@ class PerDestinationQueue(object):
self._destination = destination self._destination = destination
self.transmission_loop_running = False self.transmission_loop_running = False
# a list of tuples of (pending pdu, order) # a list of pending PDUs
self._pending_pdus = [] # type: List[Tuple[EventBase, int]] self._pending_pdus = [] # type: List[EventBase]
# XXX this is never actually used: see # XXX this is never actually used: see
# https://github.com/matrix-org/synapse/issues/7549 # https://github.com/matrix-org/synapse/issues/7549
@ -135,14 +135,13 @@ class PerDestinationQueue(object):
+ len(self._pending_edus_keyed) + len(self._pending_edus_keyed)
) )
def send_pdu(self, pdu: EventBase, order: int) -> None: def send_pdu(self, pdu: EventBase) -> None:
"""Add a PDU to the queue, and start the transmission loop if necessary """Add a PDU to the queue, and start the transmission loop if necessary
Args: Args:
pdu: pdu to send pdu: pdu to send
order
""" """
self._pending_pdus.append((pdu, order)) self._pending_pdus.append(pdu)
self.attempt_new_transaction() self.attempt_new_transaction()
def send_presence(self, states: Iterable[UserPresenceState]) -> None: def send_presence(self, states: Iterable[UserPresenceState]) -> None:
@ -188,7 +187,7 @@ class PerDestinationQueue(object):
returns immediately. Otherwise kicks off the process of sending a returns immediately. Otherwise kicks off the process of sending a
transaction in the background. transaction in the background.
""" """
# list of (pending_pdu, deferred, order)
if self.transmission_loop_running: if self.transmission_loop_running:
# XXX: this can get stuck on by a never-ending # XXX: this can get stuck on by a never-ending
# request at which point pending_pdus just keeps growing. # request at which point pending_pdus just keeps growing.
@ -213,7 +212,7 @@ class PerDestinationQueue(object):
) )
async def _transaction_transmission_loop(self) -> None: async def _transaction_transmission_loop(self) -> None:
pending_pdus = [] # type: List[Tuple[EventBase, int]] pending_pdus = [] # type: List[EventBase]
try: try:
self.transmission_loop_running = True self.transmission_loop_running = True
@ -388,13 +387,13 @@ class PerDestinationQueue(object):
"TX [%s] Failed to send transaction: %s", self._destination, e "TX [%s] Failed to send transaction: %s", self._destination, e
) )
for p, _ in pending_pdus: for p in pending_pdus:
logger.info( logger.info(
"Failed to send event %s to %s", p.event_id, self._destination "Failed to send event %s to %s", p.event_id, self._destination
) )
except Exception: except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination) logger.exception("TX [%s] Failed to send transaction", self._destination)
for p, _ in pending_pdus: for p in pending_pdus:
logger.info( logger.info(
"Failed to send event %s to %s", p.event_id, self._destination "Failed to send event %s to %s", p.event_id, self._destination
) )

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.events import EventBase from synapse.events import EventBase
@ -57,11 +57,17 @@ class TransactionManager(object):
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
async def send_new_transaction( async def send_new_transaction(
self, self, destination: str, pdus: List[EventBase], edus: List[Edu],
destination: str, ) -> bool:
pending_pdus: List[Tuple[EventBase, int]], """
pending_edus: List[Edu], Args:
): destination: The destination to send to (e.g. 'example.org')
pdus: In-order list of PDUs to send
edus: List of EDUs to send
Returns:
True iff the transaction was successful
"""
# Make a transaction-sending opentracing span. This span follows on from # Make a transaction-sending opentracing span. This span follows on from
# all the edus in that transaction. This needs to be done since there is # all the edus in that transaction. This needs to be done since there is
@ -71,7 +77,7 @@ class TransactionManager(object):
span_contexts = [] span_contexts = []
keep_destination = whitelisted_homeserver(destination) keep_destination = whitelisted_homeserver(destination)
for edu in pending_edus: for edu in edus:
context = edu.get_context() context = edu.get_context()
if context: if context:
span_contexts.append(extract_text_map(json_decoder.decode(context))) span_contexts.append(extract_text_map(json_decoder.decode(context)))
@ -79,12 +85,6 @@ class TransactionManager(object):
edu.strip_context() edu.strip_context()
with start_active_span_follows_from("send_transaction", span_contexts): with start_active_span_follows_from("send_transaction", span_contexts):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
success = True success = True
logger.debug("TX [%s] _attempt_new_transaction", destination) logger.debug("TX [%s] _attempt_new_transaction", destination)

View File

@ -15,29 +15,30 @@
import logging import logging
import random import random
from typing import TYPE_CHECKING, Iterable, List, Optional
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.types import UserID from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler): class EventStreamHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super(EventStreamHandler, self).__init__(hs) super(EventStreamHandler, self).__init__(hs)
# Count of active streams per user
self._streams_per_user = {}
# Grace timers per user to delay the "stopped" signal
self._stop_timer_per_user = {}
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("started_user_eventstream") self.distributor.declare("started_user_eventstream")
self.distributor.declare("stopped_user_eventstream") self.distributor.declare("stopped_user_eventstream")
@ -52,14 +53,14 @@ class EventStreamHandler(BaseHandler):
@log_function @log_function
async def get_stream( async def get_stream(
self, self,
auth_user_id, auth_user_id: str,
pagin_config, pagin_config: PaginationConfig,
timeout=0, timeout: int = 0,
as_client_event=True, as_client_event: bool = True,
affect_presence=True, affect_presence: bool = True,
room_id=None, room_id: Optional[str] = None,
is_guest=False, is_guest: bool = False,
): ) -> JsonDict:
"""Fetches the events stream for a given user. """Fetches the events stream for a given user.
""" """
@ -98,7 +99,7 @@ class EventStreamHandler(BaseHandler):
# When the user joins a new room, or another user joins a currently # When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users. # joined room, we need to send down presence for those users.
to_add = [] to_add = [] # type: List[JsonDict]
for event in events: for event in events:
if not isinstance(event, EventBase): if not isinstance(event, EventBase):
continue continue
@ -110,7 +111,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence for everyone in the room. # Send down presence for everyone in the room.
users = await self.state.get_current_users_in_room( users = await self.state.get_current_users_in_room(
event.room_id event.room_id
) ) # type: Iterable[str]
else: else:
users = [event.state_key] users = [event.state_key]
@ -144,20 +145,22 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler): class EventHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super(EventHandler, self).__init__(hs) super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage() self.storage = hs.get_storage()
async def get_event(self, user, room_id, event_id): async def get_event(
self, user: UserID, room_id: Optional[str], event_id: str
) -> Optional[EventBase]:
"""Retrieve a single specified event. """Retrieve a single specified event.
Args: Args:
user (synapse.types.UserID): The user requesting the event user: The user requesting the event
room_id (str|None): The expected room id. We'll return None if the room_id: The expected room id. We'll return None if the
event's room does not match. event's room does not match.
event_id (str): The event ID to obtain. event_id: The event ID to obtain.
Returns: Returns:
dict: An event, or None if there is no event matching this ID. An event, or None if there is no event matching this ID.
Raises: Raises:
SynapseError if there was a problem retrieving this event, or SynapseError if there was a problem retrieving this event, or
AuthError if the user does not have the rights to inspect this AuthError if the user does not have the rights to inspect this

View File

@ -440,11 +440,11 @@ class FederationHandler(BaseHandler):
if not prevs - seen: if not prevs - seen:
return return
latest = await self.store.get_latest_event_ids_in_room(room_id) latest_list = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest # We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us # list to ensure the remote server doesn't give them to us
latest = set(latest) latest = set(latest_list)
latest |= seen latest |= seen
logger.info( logger.info(
@ -781,7 +781,7 @@ class FederationHandler(BaseHandler):
# keys across all devices. # keys across all devices.
current_keys = [ current_keys = [
key key
for device in cached_devices for device in cached_devices.values()
for key in device.get("keys", {}).get("keys", {}).values() for key in device.get("keys", {}).get("keys", {}).values()
] ]
@ -923,8 +923,7 @@ class FederationHandler(BaseHandler):
) )
) )
if ev_infos: await self._handle_new_events(dest, ev_infos, backfilled=True)
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one # Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
@ -1217,7 +1216,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth)) event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events( await self._handle_new_events(
destination, room_id, event_infos, destination, event_infos,
) )
def _sanity_check_event(self, ev): def _sanity_check_event(self, ev):
@ -1364,15 +1363,15 @@ class FederationHandler(BaseHandler):
) )
max_stream_id = await self._persist_auth_tree( max_stream_id = await self._persist_auth_tree(
origin, room_id, auth_chain, state, event, room_version_obj origin, auth_chain, state, event, room_version_obj
) )
# We wait here until this instance has seen the events come down # We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches. # replication (if we're using replication) as the below uses caches.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position( await self._replication.wait_for_stream_position(
self.config.worker.events_shard_config.get_instance(room_id), self.config.worker.writers.events, "events", max_stream_id
"events",
max_stream_id,
) )
# Check whether this room is the result of an upgrade of a room we already know # Check whether this room is the result of an upgrade of a room we already know
@ -1626,7 +1625,7 @@ class FederationHandler(BaseHandler):
) )
context = await self.state_handler.compute_event_context(event) context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify(event.room_id, [(event, context)]) await self.persist_events_and_notify([(event, context)])
return event return event
@ -1653,9 +1652,7 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event) await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event) context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify( stream_id = await self.persist_events_and_notify([(event, context)])
event.room_id, [(event, context)]
)
return event, stream_id return event, stream_id
@ -1903,7 +1900,7 @@ class FederationHandler(BaseHandler):
) )
await self.persist_events_and_notify( await self.persist_events_and_notify(
event.room_id, [(event, context)], backfilled=backfilled [(event, context)], backfilled=backfilled
) )
except Exception: except Exception:
run_in_background( run_in_background(
@ -1916,7 +1913,6 @@ class FederationHandler(BaseHandler):
async def _handle_new_events( async def _handle_new_events(
self, self,
origin: str, origin: str,
room_id: str,
event_infos: Iterable[_NewEventInfo], event_infos: Iterable[_NewEventInfo],
backfilled: bool = False, backfilled: bool = False,
) -> None: ) -> None:
@ -1948,7 +1944,6 @@ class FederationHandler(BaseHandler):
) )
await self.persist_events_and_notify( await self.persist_events_and_notify(
room_id,
[ [
(ev_info.event, context) (ev_info.event, context)
for ev_info, context in zip(event_infos, contexts) for ev_info, context in zip(event_infos, contexts)
@ -1959,7 +1954,6 @@ class FederationHandler(BaseHandler):
async def _persist_auth_tree( async def _persist_auth_tree(
self, self,
origin: str, origin: str,
room_id: str,
auth_events: List[EventBase], auth_events: List[EventBase],
state: List[EventBase], state: List[EventBase],
event: EventBase, event: EventBase,
@ -1974,7 +1968,6 @@ class FederationHandler(BaseHandler):
Args: Args:
origin: Where the events came from origin: Where the events came from
room_id,
auth_events auth_events
state state
event event
@ -2049,20 +2042,17 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
await self.persist_events_and_notify( await self.persist_events_and_notify(
room_id,
[ [
(e, events_to_context[e.event_id]) (e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state) for e in itertools.chain(auth_events, state)
], ]
) )
new_event_context = await self.state_handler.compute_event_context( new_event_context = await self.state_handler.compute_event_context(
event, old_state=state event, old_state=state
) )
return await self.persist_events_and_notify( return await self.persist_events_and_notify([(event, new_event_context)])
room_id, [(event, new_event_context)]
)
async def _prep_event( async def _prep_event(
self, self,
@ -2119,8 +2109,8 @@ class FederationHandler(BaseHandler):
if backfilled or event.internal_metadata.is_outlier(): if backfilled or event.internal_metadata.is_outlier():
return return
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids) extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids()) prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids: if extrem_ids == prev_event_ids:
@ -2913,7 +2903,6 @@ class FederationHandler(BaseHandler):
async def persist_events_and_notify( async def persist_events_and_notify(
self, self,
room_id: str,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]], event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
) -> int: ) -> int:
@ -2921,19 +2910,14 @@ class FederationHandler(BaseHandler):
necessary. necessary.
Args: Args:
room_id: The room ID of events being persisted. event_and_contexts:
event_and_contexts: Sequence of events with their associated
context that should be persisted. All events must belong to
the same room.
backfilled: Whether these events are a result of backfilled: Whether these events are a result of
backfilling or not backfilling or not
""" """
instance = self.config.worker.events_shard_config.get_instance(room_id) if self.config.worker.writers.events != self._instance_name:
if instance != self._instance_name:
result = await self._send_events( result = await self._send_events(
instance_name=instance, instance_name=self.config.worker.writers.events,
store=self.store, store=self.store,
room_id=room_id,
event_and_contexts=event_and_contexts, event_and_contexts=event_and_contexts,
backfilled=backfilled, backfilled=backfilled,
) )

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from twisted.internet import defer from twisted.internet import defer
@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler): class InitialSyncHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super(InitialSyncHandler, self).__init__(hs) super(InitialSyncHandler, self).__init__(hs)
self.hs = hs self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler):
def snapshot_all_rooms( def snapshot_all_rooms(
self, self,
user_id=None, user_id: str,
pagin_config=None, pagin_config: PaginationConfig,
as_client_event=True, as_client_event: bool = True,
include_archived=False, include_archived: bool = False,
): ) -> JsonDict:
"""Retrieve a snapshot of all rooms the user is invited or has joined. """Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config. joined, depending on the pagination config.
Args: Args:
user_id (str): The ID of the user making the request. user_id: The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config: The pagination config used to determine how many
config used to determine how many messages *PER ROOM* to return. messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format. as_client_event: True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left include_archived: True to get rooms that the user has left
Returns: Returns:
A list of dicts with "room_id" and "membership" keys for all rooms A JsonDict with the same format as the response to `/initialSync`
the user is currently invited or joined in on. Rooms where the user API
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
""" """
key = ( key = (
user_id, user_id,
@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler):
async def _snapshot_all_rooms( async def _snapshot_all_rooms(
self, self,
user_id=None, user_id: str,
pagin_config=None, pagin_config: PaginationConfig,
as_client_event=True, as_client_event: bool = True,
include_archived=False, include_archived: bool = False,
): ) -> JsonDict:
memberships = [Membership.INVITE, Membership.JOIN] memberships = [Membership.INVITE, Membership.JOIN]
if include_archived: if include_archived:
@ -134,7 +138,7 @@ class InitialSyncHandler(BaseHandler):
if limit is None: if limit is None:
limit = 10 limit = 10
async def handle_room(event): async def handle_room(event: RoomsForUser):
d = { d = {
"room_id": event.room_id, "room_id": event.room_id,
"membership": event.membership, "membership": event.membership,
@ -251,17 +255,18 @@ class InitialSyncHandler(BaseHandler):
return ret return ret
async def room_initial_sync(self, requester, room_id, pagin_config=None): async def room_initial_sync(
self, requester: Requester, room_id: str, pagin_config: PaginationConfig
) -> JsonDict:
"""Capture a snapshot of a room. If user is currently a member of """Capture a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left. the room this will be what was in the room when they left.
Args: Args:
requester(Requester): The user to get a snapshot for. requester: The user to get a snapshot for.
room_id(str): The room to get a snapshot of. room_id: The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig): pagin_config: The pagination config used to determine how many
The pagination config used to determine how many messages to messages to return.
return.
Raises: Raises:
AuthError if the user wasn't in the room. AuthError if the user wasn't in the room.
Returns: Returns:
@ -305,8 +310,14 @@ class InitialSyncHandler(BaseHandler):
return result return result
async def _room_initial_sync_parted( async def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking self,
): user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id] room_state = room_state[member_event_id]
@ -350,8 +361,13 @@ class InitialSyncHandler(BaseHandler):
} }
async def _room_initial_sync_joined( async def _room_initial_sync_joined(
self, user_id, room_id, pagin_config, membership, is_peeking self,
): user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id) current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently # TODO: These concurrently

View File

@ -376,8 +376,9 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.config = hs.config self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._events_shard_config = self.config.worker.events_shard_config self._is_event_writer = (
self._instance_name = hs.get_instance_name() self.config.worker.writers.events == hs.get_instance_name()
)
self.room_invite_state_types = self.hs.config.room_invite_state_types self.room_invite_state_types = self.hs.config.room_invite_state_types
@ -905,10 +906,9 @@ class EventCreationHandler(object):
try: try:
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id) if not self._is_event_writer:
if writer_instance != self._instance_name:
result = await self.send_event( result = await self.send_event(
instance_name=writer_instance, instance_name=self.config.worker.writers.events,
event_id=event.event_id, event_id=event.event_id,
store=self.store, store=self.store,
requester=requester, requester=requester,
@ -976,9 +976,7 @@ class EventCreationHandler(object):
This should only be run on the instance in charge of persisting events. This should only be run on the instance in charge of persisting events.
""" """
assert self._events_shard_config.should_handle( assert self._is_event_writer
self._instance_name, event.room_id
)
if ratelimit: if ratelimit:
# We check if this is a room admin redacting an event so that we # We check if this is a room admin redacting an event so that we

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional, Set
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -30,6 +30,10 @@ from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,7 +72,7 @@ class PaginationHandler(object):
paginating during a purge. paginating during a purge.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -78,9 +82,9 @@ class PaginationHandler(object):
self._server_name = hs.hostname self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set() self._purges_in_progress_by_room = set() # type: Set[str]
# map from purge id to PurgeStatus # map from purge id to PurgeStatus
self._purges_by_id = {} self._purges_by_id = {} # type: Dict[str, PurgeStatus]
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
@ -102,7 +106,9 @@ class PaginationHandler(object):
job["longest_max_lifetime"], job["longest_max_lifetime"],
) )
async def purge_history_for_rooms_in_range(self, min_ms, max_ms): async def purge_history_for_rooms_in_range(
self, min_ms: Optional[int], max_ms: Optional[int]
):
"""Purge outdated events from rooms within the given retention range. """Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its If a default retention policy is defined in the server's configuration and its
@ -110,10 +116,10 @@ class PaginationHandler(object):
retention policy. retention policy.
Args: Args:
min_ms (int|None): Duration in milliseconds that defines the lower limit of min_ms: Duration in milliseconds that defines the lower limit of
the range to handle (exclusive). If None, it means that the range has no the range to handle (exclusive). If None, it means that the range has no
lower limit. lower limit.
max_ms (int|None): Duration in milliseconds that defines the upper limit of max_ms: Duration in milliseconds that defines the upper limit of
the range to handle (inclusive). If None, it means that the range has no the range to handle (inclusive). If None, it means that the range has no
upper limit. upper limit.
""" """
@ -220,18 +226,19 @@ class PaginationHandler(object):
"_purge_history", self._purge_history, purge_id, room_id, token, True, "_purge_history", self._purge_history, purge_id, room_id, token, True,
) )
def start_purge_history(self, room_id, token, delete_local_events=False): def start_purge_history(
self, room_id: str, token: str, delete_local_events: bool = False
) -> str:
"""Start off a history purge on a room. """Start off a history purge on a room.
Args: Args:
room_id (str): The room to purge from room_id: The room to purge from
token: topological token to delete events before
token (str): topological token to delete events before delete_local_events: True to delete local events as well as
delete_local_events (bool): True to delete local events as well as
remote ones remote ones
Returns: Returns:
str: unique ID for this purge transaction. unique ID for this purge transaction.
""" """
if room_id in self._purges_in_progress_by_room: if room_id in self._purges_in_progress_by_room:
raise SynapseError( raise SynapseError(
@ -284,14 +291,11 @@ class PaginationHandler(object):
self.hs.get_reactor().callLater(24 * 3600, clear_purge) self.hs.get_reactor().callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id): def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]:
"""Get the current status of an active purge """Get the current status of an active purge
Args: Args:
purge_id (str): purge_id returned by start_purge_history purge_id: purge_id returned by start_purge_history
Returns:
PurgeStatus|None
""" """
return self._purges_by_id.get(purge_id) return self._purges_by_id.get(purge_id)
@ -312,8 +316,8 @@ class PaginationHandler(object):
async def get_messages( async def get_messages(
self, self,
requester: Requester, requester: Requester,
room_id: Optional[str] = None, room_id: str,
pagin_config: Optional[PaginationConfig] = None, pagin_config: PaginationConfig,
as_client_event: bool = True, as_client_event: bool = True,
event_filter: Optional[Filter] = None, event_filter: Optional[Filter] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -368,11 +372,15 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the # they left the room, to save the effort of loading from the
# database. # database.
# This is only None if the room is world_readable, in which
# case "JOIN" would have been returned.
assert member_event_id
leave_token = await self.store.get_topological_token_for_event( leave_token = await self.store.get_topological_token_for_event(
member_event_id member_event_id
) )
leave_token = RoomStreamToken.parse(leave_token) if RoomStreamToken.parse(leave_token).topological < max_topo:
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token) source_config.from_key = str(leave_token)
await self.hs.get_handlers().federation_handler.maybe_backfill( await self.hs.get_handlers().federation_handler.maybe_backfill(
@ -419,8 +427,8 @@ class PaginationHandler(object):
) )
if state_ids: if state_ids:
state = await self.store.get_events(list(state_ids.values())) state_dict = await self.store.get_events(list(state_ids.values()))
state = state.values() state = state_dict.values()
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -804,9 +804,7 @@ class RoomCreationHandler(BaseHandler):
# Always wait for room creation to propagate before returning # Always wait for room creation to propagate before returning
await self._replication.wait_for_stream_position( await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(room_id), self.hs.config.worker.writers.events, "events", last_stream_id
"events",
last_stream_id,
) )
return result, last_stream_id return result, last_stream_id
@ -1262,10 +1260,10 @@ class RoomShutdownHandler(object):
# We now wait for the create room to come back in via replication so # We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propagated before # that we can assume that all the joins/invites have propagated before
# we try and auto join below. # we try and auto join below.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position( await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(new_room_id), self.hs.config.worker.writers.events, "events", stream_id
"events",
stream_id,
) )
else: else:
new_room_id = None new_room_id = None
@ -1295,9 +1293,7 @@ class RoomShutdownHandler(object):
# Wait for leave to come in over replication before trying to forget. # Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position( await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(room_id), self.hs.config.worker.writers.events, "events", stream_id
"events",
stream_id,
) )
await self.room_member_handler.forget(target_requester.user, room_id) await self.room_member_handler.forget(target_requester.user, room_id)

View File

@ -83,6 +83,13 @@ class RoomMemberHandler(object):
self._enable_lookup = hs.config.enable_3pid_lookup self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._event_stream_writer_instance = hs.config.worker.writers.events
self._is_on_event_persistence_instance = (
self._event_stream_writer_instance == hs.get_instance_name()
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
self._join_rate_limiter_local = Ratelimiter( self._join_rate_limiter_local = Ratelimiter(
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,

View File

@ -65,11 +65,10 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler self.federation_handler = hs.get_handlers().federation_handler
@staticmethod @staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled): async def _serialize_payload(store, event_and_contexts, backfilled):
""" """
Args: Args:
store store
room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]]) event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of backfilled (bool): Whether or not the events are the result of
backfilling backfilling
@ -89,11 +88,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
} }
) )
payload = { payload = {"events": event_payloads, "backfilled": backfilled}
"events": event_payloads,
"backfilled": backfilled,
"room_id": room_id,
}
return payload return payload
@ -101,7 +96,6 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
with Measure(self.clock, "repl_fed_send_events_parse"): with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
room_id = content["room_id"]
backfilled = content["backfilled"] backfilled = content["backfilled"]
event_payloads = content["events"] event_payloads = content["events"]
@ -126,7 +120,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts)) logger.info("Got %d events from federation", len(event_and_contexts))
max_stream_id = await self.federation_handler.persist_events_and_notify( max_stream_id = await self.federation_handler.persist_events_and_notify(
room_id, event_and_contexts, backfilled event_and_contexts, backfilled
) )
return 200, {"max_stream_id": max_stream_id} return 200, {"max_stream_id": max_stream_id}

View File

@ -109,7 +109,7 @@ class ReplicationCommandHandler:
if isinstance(stream, (EventsStream, BackfillStream)): if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the # Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence. # instance in charge of event persistence.
if hs.get_instance_name() in hs.config.worker.writers.events: if hs.config.worker.writers.events == hs.get_instance_name():
self._streams_to_replicate.append(stream) self._streams_to_replicate.append(stream)
continue continue

View File

@ -19,7 +19,7 @@ from typing import List, Tuple, Type
import attr import attr
from ._base import Stream, StreamUpdateResult, Token from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream """Handling of the 'events' replication stream
@ -117,7 +117,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore() self._store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
self._store._stream_id_gen.get_current_token_for_writer, current_token_without_instance(self._store.get_current_events_token),
self._update_function, self._update_function,
) )

View File

@ -68,7 +68,7 @@ class Databases(object):
# If we're on a process that can persist events also # If we're on a process that can persist events also
# instantiate a `PersistEventsStore` # instantiate a `PersistEventsStore`
if hs.get_instance_name() in hs.config.worker.writers.events: if hs.config.worker.writers.events == hs.get_instance_name():
persist_events = PersistEventsStore(hs, database, main) persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases: if "state" in database_config.databases:

View File

@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
# key) and "signatures" (a signature of the structure by the ed25519 key) # key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str]) key_json = attr.ib(type=Optional[str])
# cross-signing sigs # cross-signing sigs on this device.
signatures = attr.ib(type=Optional[Dict], default=None) # dict from (signing user_id)->(signing device_id)->sig
signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
class EndToEndKeyWorkerStore(SQLBaseStore): class EndToEndKeyWorkerStore(SQLBaseStore):
@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
include_all_devices: bool = False, include_all_devices: bool = False,
include_deleted_devices: bool = False, include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures. """Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
included.
Args: Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None query_list: List of pairs of user_ids and device_ids. Device id can be None
@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result = await self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
"get_e2e_device_keys", "get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn, self._get_e2e_device_keys_txn,
query_list, query_list,
include_all_devices, include_all_devices,
include_deleted_devices, include_deleted_devices,
) )
# get the (user_id, device_id) tuples to look up cross-signatures for
signature_query = (
(user_id, device_id)
for user_id, dev in result.items()
for device_id, d in dev.items()
if d is not None
)
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)
# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
target_device_signatures = target_device_result.signatures
signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature
log_kv(result) log_kv(result)
return result return result
def _get_e2e_device_keys_and_signatures_txn( def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
The results include the device's keys and self-signatures, but *not* any
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = [] query_clauses = []
query_params = [] query_params = []
signature_query_clauses = []
signature_query_params = []
if include_all_devices is False: if include_all_devices is False:
include_deleted_devices = False include_deleted_devices = False
@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for (user_id, device_id) in query_list: for (user_id, device_id) in query_list:
query_clause = "user_id = ?" query_clause = "user_id = ?"
query_params.append(user_id) query_params.append(user_id)
signature_query_clause = "target_user_id = ?"
signature_query_params.append(user_id)
if device_id is not None: if device_id is not None:
query_clause += " AND device_id = ?" query_clause += " AND device_id = ?"
query_params.append(device_id) query_params.append(device_id)
signature_query_clause += " AND target_device_id = ?"
signature_query_params.append(device_id)
signature_query_clause += " AND user_id = ?"
signature_query_params.append(user_id)
query_clauses.append(query_clause) query_clauses.append(query_clause)
signature_query_clauses.append(signature_query_clause)
sql = ( sql = (
"SELECT user_id, device_id, " "SELECT user_id, device_id, "
@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_id in deleted_devices: for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None result.setdefault(user_id, {})[device_id] = None
# get signatures on the device return result
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
Returns signatures made by the owners of the devices.
Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
"""
signature_query_clauses = []
signature_query_params = []
for (user_id, device_id) in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signature_query_params.extend([user_id, device_id, user_id])
signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (
" OR ".join("(" + q + ")" for q in signature_query_clauses) " OR ".join("(" + q + ")" for q in signature_query_clauses)
) )
txn.execute(signature_sql, signature_query_params) txn.execute(signature_sql, signature_query_params)
rows = self.db_pool.cursor_to_dict(txn) return txn.fetchall()
# add each cross-signing signature to the correct device in the result dict.
for row in rows:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]
target_user_result = result.get(target_user_id)
if not target_user_result:
continue
target_device_result = target_user_result.get(target_device_id)
if not target_device_result:
# note that target_device_result will be None for deleted devices.
continue
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature
return result
async def get_e2e_one_time_keys( async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str] self, user_id: str, device_id: str, key_ids: List[str]

View File

@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
""" """
if stream_ordering <= self.stream_ordering_month_ago: if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) raise StoreError(400, "stream_ordering too old")
sql = """ sql = """
SELECT event_id FROM stream_ordering_to_exterm SELECT event_id FROM stream_ordering_to_exterm

View File

@ -97,7 +97,6 @@ class PersistEventsStore:
self.store = main_data_store self.store = main_data_store
self.database_engine = db.engine self.database_engine = db.engine
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@ -109,7 +108,7 @@ class PersistEventsStore:
# This should only exist on instances that are configured to write # This should only exist on instances that are configured to write
assert ( assert (
hs.get_instance_name() in hs.config.worker.writers.events hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master" ), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates( async def _persist_events_and_state_updates(
@ -801,7 +800,6 @@ class PersistEventsStore:
table="events", table="events",
values=[ values=[
{ {
"instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering, "stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth, "topological_ordering": event.depth,
"depth": event.depth, "depth": event.depth,

View File

@ -42,8 +42,7 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -79,54 +78,27 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs) super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if isinstance(database.engine, PostgresEngine): if hs.config.worker.writers.events == hs.get_instance_name():
# If we're using Postgres than we can use `MultiWriterIdGenerator` # We are the process in charge of generating stream ids for events,
# regardless of whether this process writes to the streams or not. # so instantiate ID generators based on the database
self._stream_id_gen = MultiWriterIdGenerator( self._stream_id_gen = StreamIdGenerator(
db_conn=db_conn, db_conn, "events", "stream_ordering",
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq",
) )
self._backfill_id_gen = MultiWriterIdGenerator( self._backfill_id_gen = StreamIdGenerator(
db_conn=db_conn, db_conn,
db=database, "events",
instance_name=hs.get_instance_name(), "stream_ordering",
table="events", step=-1,
instance_column="instance_name", extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq",
positive=False,
) )
else: else:
# We shouldn't be running in worker mode with SQLite, but its useful # Another process is in charge of persisting events and generating
# to support it for unit tests. # stream IDs: rely on the replication streams to let us know which
# # IDs we can process.
# If this process is the writer than we need to use self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets self._backfill_id_gen = SlavedIdTracker(
# updated over replication. (Multiple writers are not supported for db_conn, "events", "stream_ordering", step=-1
# SQLite). )
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering"
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
self._get_event_cache = Cache( self._get_event_cache = Cache(
"*getEvent*", "*getEvent*",

View File

@ -1,16 +0,0 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE events ADD COLUMN instance_name TEXT;

View File

@ -1,26 +0,0 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
SELECT setval('events_stream_seq', (
SELECT COALESCE(MAX(stream_ordering), 1) FROM events
));
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
SELECT setval('events_backfill_stream_seq', (
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
));

View File

@ -231,12 +231,8 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare it's still worth doing the book keeping # gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of # that allows us to skip forwards when there are gapless runs of
# positions. # positions.
#
# We start at 1 here as a) the first generated stream ID will be 2, and
# b) other parts of the code assume that stream IDs are strictly greater
# than 0.
self._persisted_upto_position = ( self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 1 min(self._current_positions.values()) if self._current_positions else 0
) )
self._known_persisted_positions = [] # type: List[int] self._known_persisted_positions = [] # type: List[int]
@ -366,7 +362,9 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted. equal to it have been successfully persisted.
""" """
return self.get_persisted_upto_position() # Currently we don't support this operation, as it's not obvious how to
# condense the stream positions of multiple writers into a single int.
raise NotImplementedError()
def get_current_token_for_writer(self, instance_name: str) -> int: def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer. """Returns the position of the given writer.

View File

@ -18,11 +18,10 @@ import functools
import inspect import inspect
import logging import logging
import threading import threading
from typing import Any, Tuple, Union, cast from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from prometheus_client import Gauge from prometheus_client import Gauge
from typing_extensions import Protocol
from twisted.internet import defer from twisted.internet import defer
@ -38,8 +37,10 @@ logger = logging.getLogger(__name__)
CacheKey = Union[Tuple, Any] CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
class _CachedFunction(Protocol):
class _CachedFunction(Generic[F]):
invalidate = None # type: Any invalidate = None # type: Any
invalidate_all = None # type: Any invalidate_all = None # type: Any
invalidate_many = None # type: Any invalidate_many = None # type: Any
@ -47,8 +48,11 @@ class _CachedFunction(Protocol):
cache = None # type: Any cache = None # type: Any
num_args = None # type: Any num_args = None # type: Any
def __name__(self): __name__ = None # type: str
...
# Note: This function signature is actually fiddled with by the synapse mypy
# plugin to a) make it a bound method, and b) remove any `cache_context` arg.
__call__ = None # type: F
cache_pending_metric = Gauge( cache_pending_metric = Gauge(
@ -123,7 +127,7 @@ class Cache(object):
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
self.thread = None self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache( self.metrics = register_cache(
"cache", "cache",
name, name,
@ -662,9 +666,13 @@ class _CacheContext:
def cached( def cached(
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False max_entries: int = 1000,
): num_args: Optional[int] = None,
return lambda orig: CacheDescriptor( tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
@ -673,8 +681,12 @@ def cached(
iterable=iterable, iterable=iterable,
) )
return cast(Callable[[F], _CachedFunction[F]], func)
def cachedList(cached_method_name, list_name, num_args=None):
def cachedList(
cached_method_name: str, list_name: str, num_args: Optional[int] = None
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`. """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument Used to do batch lookups for an already created cache. A single argument
@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
cache. cache.
Args: Args:
cached_method_name (str): The name of the single-item lookup method. cached_method_name: The name of the single-item lookup method.
This is only used to find the cache to use. This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to list_name: The name of the argument that is the list to use to
do batch lookups in the cache. do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache num_args: Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters. (including list_name). Defaults to all named parameters.
Example: Example:
@ -702,9 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
def batch_do_something(self, first_arg, second_args): def batch_do_something(self, first_arg, second_args):
... ...
""" """
return lambda orig: CacheListDescriptor( func = lambda orig: CacheListDescriptor(
orig, orig,
cached_method_name=cached_method_name, cached_method_name=cached_method_name,
list_name=list_name, list_name=list_name,
num_args=num_args, num_args=num_args,
) )
return cast(Callable[[F], _CachedFunction[F]], func)