Merge branch 'develop' of github.com:matrix-org/synapse into anoa/third_party_event_rules_public_rooms

pull/8292/head
Andrew Morgan 2020-09-14 15:38:52 +01:00
commit 83a20a2dc2
52 changed files with 688 additions and 369 deletions


@ -1,3 +1,12 @@
Synapse 1.20.0rc3 (2020-09-11)
==============================
Bugfixes
--------
- Fix a bug introduced in v1.20.0rc1 where the wrong exception was raised when invalid JSON data is encountered. ([\#8291](https://github.com/matrix-org/synapse/issues/8291))
Synapse 1.20.0rc2 (2020-09-09)
==============================


@ -127,6 +127,20 @@ request to
with the query parameters from the original link, presented as a URL-encoded form. See the file
itself for more details.
Updated Single Sign-on HTML Templates
-------------------------------------
The ``saml_error.html`` template was removed from Synapse and replaced with the
``sso_error.html`` template. If your Synapse is configured to use SAML and a
custom ``sso_redirect_confirm_template_dir`` configuration then any customisations
of the ``saml_error.html`` template will need to be merged into the ``sso_error.html``
template. These templates are similar, but the parameters are slightly different (see the sketch after this list):
* The ``msg`` parameter should be renamed to ``error_description``.
* There is no longer a ``code`` parameter for the response code.
* A string ``error`` parameter is available that includes a short hint of why a
user is seeing the error page.
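For illustration, a minimal sketch of the rename when rendering the new template directly (assuming Jinja2, which Synapse uses for these templates; the file path is illustrative):

    from jinja2 import Template

    template = Template(open("sso_error.html").read())
    # the old saml_error.html took code=400 and msg="..."; the new template takes:
    html = template.render(
        error="unauthorised",  # short hint of why the user is seeing the error page
        error_description="You are not allowed to log in here.",
    )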
ThirdPartyEventRules breaking changes
-------------------------------------

changelog.d/8248.feature Normal file

@ -0,0 +1 @@
Consolidate the SSO error template across all configurations.

changelog.d/8281.misc Normal file

@ -0,0 +1 @@
Change `StreamToken.room_key` to be a `RoomStreamToken` instance.

changelog.d/8294.feature Normal file

@ -0,0 +1 @@
Add experimental support for sharding event persister.

changelog.d/8305.feature Normal file

@ -0,0 +1 @@
Add the room topic and avatar to the room details admin API.


@ -275,6 +275,8 @@ The following fields are possible in the JSON response body:
* `room_id` - The ID of the room.
* `name` - The name of the room.
* `topic` - The topic of the room.
* `avatar` - The `mxc` URI to the avatar of the room.
* `canonical_alias` - The canonical (main) alias address of the room.
* `joined_members` - How many users are currently in the room.
* `joined_local_members` - How many local users are currently in the room.
@ -304,6 +306,8 @@ Response:
{
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
"name": "Music Theory",
"avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV",
"topic": "Theory, Composition, Notation, Analysis",
"canonical_alias": "#musictheory:matrix.org",
"joined_members": 127
"joined_local_members": 2,


@ -1485,11 +1485,14 @@ trusted_key_servers:
# At least one of `sp_config` or `config_path` must be set in this section to
# enable SAML login.
#
# (You will probably also want to set the following options to `false` to
# You will probably also want to set the following options to `false` to
# disable the regular login/registration flows:
# * enable_registration
# * password_config.enabled
#
# You will also want to investigate the settings under the "sso" configuration
# section below.
#
# Once SAML support is enabled, a metadata file will be exposed at
# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
# use to configure your SAML IdP with. Alternatively, you can manually configure
@ -1612,31 +1615,6 @@ saml2_config:
# - attribute: department
# value: "sales"
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
# When rendering, this template is given the following variables:
# * code: an HTML error code corresponding to the error that is being
# returned (typically 400 or 500)
#
# * msg: a textual message describing the error.
#
# The variables will automatically be HTML-escaped.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
# OpenID Connect integration. The following settings can be used to make Synapse
# use an OpenID Connect Provider for authentication, instead of its internal


@ -381,7 +381,6 @@ Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions:
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
^/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*$
When using this worker you must also set `update_user_directory: False` in the
shared configuration file to stop the main synapse running background


@ -46,10 +46,12 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,
synapse/storage/persist_events.py,
synapse/storage/state.py,
synapse/storage/util,
synapse/streams,


@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.20.0rc2"
__version__ = "1.20.0rc3"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when


@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
# If multiple instances are not defined we always return true.
# If multiple instances are not defined we always return true
if not self.instances or len(self.instances) == 1:
return True
return self.get_instance(key) == instance_name
def get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key.
Note: For things like federation sending the config for which instance
is sending is known only to the sender instance if there is only one.
Therefore `should_handle` should be used where possible.
"""
if not self.instances:
return "master"
if len(self.instances) == 1:
return self.instances[0]
# We shard by taking the hash of the key, modulo the number of instances,
# and then returning the instance at that index.
@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig:
dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder] == instance_name
return self.instances[remainder]
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]


@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
def get_instance(self, key: str) -> str: ...


@ -169,10 +169,6 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "15m")
)
self.saml2_error_html_template = self.read_templates(
["saml_error.html"], saml2_config.get("template_dir")
)[0]
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
@ -225,11 +221,14 @@ class SAML2Config(Config):
# At least one of `sp_config` or `config_path` must be set in this section to
# enable SAML login.
#
# (You will probably also want to set the following options to `false` to
# You will probably also want to set the following options to `false` to
# disable the regular login/registration flows:
# * enable_registration
# * password_config.enabled
#
# You will also want to investigate the settings under the "sso" configuration
# section below.
#
# Once SAML support is enabled, a metadata file will be exposed at
# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
# use to configure your SAML IdP with. Alternatively, you can manually configure
@ -351,31 +350,6 @@ class SAML2Config(Config):
# value: "staff"
# - attribute: department
# value: "sales"
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
# When rendering, this template is given the following variables:
# * code: an HTML error code corresponding to the error that is being
# returned (typically 400 or 500)
#
# * msg: a textual message describing the error.
#
# The variables will automatically be HTML-escaped.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
""" % {
"config_dir_path": config_dir_path
}


@ -13,12 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
from .server import ListenerConfig, parse_listener_def
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config
option expecting a list of strings.
"""
if isinstance(obj, str):
return [obj]
return obj
@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
@ -33,11 +45,13 @@ class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
events: The instance that writes to the event and backfill streams.
events: The instance that writes to the typing stream.
events: The instances that write to the event and backfill streams.
typing: The instance that writes to the typing stream.
"""
events = attr.ib(default="master", type=str)
events = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter
)
typing = attr.ib(default="master", type=str)
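A quick check of the converter's effect on the class defined above (attrs applies converters on construction, so both spellings yield a list):

    assert WriterLocations(events="master").events == ["master"]
    assert WriterLocations(events=["w1", "w2"]).events == ["w1", "w2"]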
@ -105,15 +119,18 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writer for events and typing also appears in
# Check that the configured writers for events and typing also appears in
# `instance_map`.
for stream in ("events", "typing"):
instance = getattr(self.writers, stream)
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
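For context, a hedged sketch of the configuration this parses, with hypothetical worker names (the YAML is shown as the equivalent Python dict; this is the experimental event-persister sharding from changelog 8294):

    config = {
        "stream_writers": {"events": ["event_persister1", "event_persister2"]},
        "instance_map": {
            "event_persister1": {"host": "localhost", "port": 8034},
            "event_persister2": {"host": "localhost", "port": 8035},
        },
    }

Each instance named under `stream_writers.events` must appear in `instance_map`, which is exactly what the loop above enforces.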
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\


@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
else:
stream_ordering = room.stream_ordering
from_key = str(RoomStreamToken(0, 0))
to_key = str(RoomStreamToken(None, stream_ordering))
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
written_events = set() # Events that we've processed in this room
@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
if not events:
break
from_key = events[-1].internal_metadata.after
from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
events = await filter_events_for_client(self.storage, user_id, events)


@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
RoomStreamToken,
StreamToken,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler):
@trace
@measure_func("device.get_user_ids_changed")
async def get_user_ids_changed(self, user_id, from_token):
async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
set_tag("user_id", user_id)
set_tag("from_token", from_token)
now_room_key = await self.store.get_room_events_max_id()
now_room_id = self.store.get_room_max_stream_ordering()
now_room_key = RoomStreamToken(None, now_room_id)
room_ids = await self.store.get_rooms_for_user(user_id)
@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler):
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
stream_ordering = from_token.room_key.stream
possibly_changed = set(changed)
possibly_left = set()


@ -896,7 +896,8 @@ class FederationHandler(BaseHandler):
)
)
await self._handle_new_events(dest, ev_infos, backfilled=True)
if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@ -1189,7 +1190,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
destination, event_infos,
destination, room_id, event_infos,
)
def _sanity_check_event(self, ev):
@ -1336,15 +1337,15 @@ class FederationHandler(BaseHandler):
)
max_stream_id = await self._persist_auth_tree(
origin, auth_chain, state, event, room_version_obj
origin, room_id, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.config.worker.writers.events, "events", max_stream_id
self.config.worker.events_shard_config.get_instance(room_id),
"events",
max_stream_id,
)
# Check whether this room is the result of an upgrade of a room we already know
@ -1593,7 +1594,7 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify([(event, context)])
await self.persist_events_and_notify(event.room_id, [(event, context)])
return event
@ -1620,7 +1621,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify([(event, context)])
stream_id = await self.persist_events_and_notify(
event.room_id, [(event, context)]
)
return event, stream_id
@ -1868,7 +1871,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
[(event, context)], backfilled=backfilled
event.room_id, [(event, context)], backfilled=backfilled
)
except Exception:
run_in_background(
@ -1881,6 +1884,7 @@ class FederationHandler(BaseHandler):
async def _handle_new_events(
self,
origin: str,
room_id: str,
event_infos: Iterable[_NewEventInfo],
backfilled: bool = False,
) -> None:
@ -1912,6 +1916,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
room_id,
[
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
@ -1922,6 +1927,7 @@ class FederationHandler(BaseHandler):
async def _persist_auth_tree(
self,
origin: str,
room_id: str,
auth_events: List[EventBase],
state: List[EventBase],
event: EventBase,
@ -1936,6 +1942,7 @@ class FederationHandler(BaseHandler):
Args:
origin: Where the events came from
room_id
auth_events
state
event
@ -2010,17 +2017,20 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
await self.persist_events_and_notify(
room_id,
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
]
],
)
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
return await self.persist_events_and_notify([(event, new_event_context)])
return await self.persist_events_and_notify(
room_id, [(event, new_event_context)]
)
async def _prep_event(
self,
@ -2871,6 +2881,7 @@ class FederationHandler(BaseHandler):
async def persist_events_and_notify(
self,
room_id: str,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> int:
@ -2878,14 +2889,19 @@ class FederationHandler(BaseHandler):
necessary.
Args:
event_and_contexts:
room_id: The room ID of events being persisted.
event_and_contexts: Sequence of events with their associated
context that should be persisted. All events must belong to
the same room.
backfilled: Whether these events are a result of
backfilling or not
"""
if self.config.worker.writers.events != self._instance_name:
instance = self.config.worker.events_shard_config.get_instance(room_id)
if instance != self._instance_name:
result = await self._send_events(
instance_name=self.config.worker.writers.events,
instance_name=instance,
store=self.store,
room_id=room_id,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)


@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@ -167,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
room_end_token = RoomStreamToken(None, event.stream_ordering,)
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)


@ -376,9 +376,8 @@ class EventCreationHandler:
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._is_event_writer = (
self.config.worker.writers.events == hs.get_instance_name()
)
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.room_invite_state_types
@ -902,9 +901,10 @@ class EventCreationHandler:
try:
# If we're a worker we need to hit out to the master.
if not self._is_event_writer:
writer_instance = self._events_shard_config.get_instance(event.room_id)
if writer_instance != self._instance_name:
result = await self.send_event(
instance_name=self.config.worker.writers.events,
instance_name=writer_instance,
event_id=event.event_id,
store=self.store,
requester=requester,
@ -972,7 +972,10 @@ class EventCreationHandler:
This should only be run on the instance in charge of persisting events.
"""
assert self._is_event_writer
assert self.storage.persistence is not None
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
if ratelimit:
# We check if this is a room admin redacting an event so that we


@ -131,10 +131,10 @@ class OidcHandler:
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and respond with it.
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under ``synapse/res/templates/sso_error.html``.
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.


@ -344,7 +344,7 @@ class PaginationHandler:
# gets called.
raise Exception("limit not set")
room_token = RoomStreamToken.parse(from_token.room_key)
room_token = from_token.room_key
with await self.pagination_lock.read(room_id):
(
@ -381,7 +381,7 @@ class PaginationHandler:
if leave_token.topological < max_topo:
from_token = from_token.copy_and_replace(
"room_key", leave_token_str
"room_key", leave_token
)
await self.hs.get_handlers().federation_handler.maybe_backfill(


@ -813,7 +813,9 @@ class RoomCreationHandler(BaseHandler):
# Always wait for room creation to propagate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", last_stream_id
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
last_stream_id,
)
return result, last_stream_id
@ -1100,20 +1102,19 @@ class RoomEventSource:
async def get_new_events(
self,
user: UserID,
from_key: str,
from_key: RoomStreamToken,
limit: int,
room_ids: List[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
# We just ignore the key for now.
to_key = self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
if from_key.topological:
logger.warning("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
from_key = RoomStreamToken(None, from_key.stream)
app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
@ -1142,14 +1143,14 @@ class RoomEventSource:
events[:] = events[:limit]
if events:
end_key = events[-1].internal_metadata.after
end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
else:
end_key = to_key
return (events, end_key)
def get_current_key(self) -> str:
return "s%d" % (self.store.get_room_max_stream_ordering(),)
def get_current_key(self) -> RoomStreamToken:
return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)
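For reference, a hedged sketch of the two serialised token forms being replaced by `RoomStreamToken` instances here (formats as handled by `RoomStreamToken.parse` elsewhere in this diff; values illustrative):

    from synapse.types import RoomStreamToken

    RoomStreamToken.parse("s2633508")      # live token     -> RoomStreamToken(None, 2633508)
    RoomStreamToken.parse("t426-2633508")  # historic token -> RoomStreamToken(426, 2633508)
    str(RoomStreamToken(None, 5))          # -> "s5"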
@ -1269,10 +1270,10 @@ class RoomShutdownHandler:
# We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propagated before
# we try and auto join below.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
self.hs.config.worker.events_shard_config.get_instance(new_room_id),
"events",
stream_id,
)
else:
new_room_id = None
@ -1302,7 +1303,9 @@ class RoomShutdownHandler:
# Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
stream_id,
)
await self.room_member_handler.forget(target_requester.user, room_id)


@ -82,13 +82,6 @@ class RoomMemberHandler:
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._event_stream_writer_instance = hs.config.worker.writers.events
self._is_on_event_persistence_instance = (
self._event_stream_writer_instance == hs.get_instance_name()
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
self._join_rate_limiter_local = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,


@ -21,9 +21,10 @@ import saml2
import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import AuthError, SynapseError
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.server import respond_with_html
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@ -41,6 +42,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the SAML2 response to a user."""
@attr.s
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@ -68,6 +73,7 @@ class SamlHandler:
hs.config.saml2_grandfathered_mxid_source_attribute
)
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
self._error_template = hs.config.sso_error_template
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@ -84,6 +90,25 @@ class SamlHandler:
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
) -> bytes:
@ -134,49 +159,6 @@ class SamlHandler:
# the dict.
self.expire_sessions()
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
user_id, current_session = await self._map_saml_response_to_user(
resp_bytes, relay_state, user_agent, ip_address
)
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)
else:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self,
resp_bytes: str,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]:
"""
Given a SAML response, retrieve the cached session and user for it.
Args:
resp_bytes: The SAML response.
client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
Tuple of the user ID and SAML session associated with this response.
Raises:
SynapseError if there was a problem with the response.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@ -189,12 +171,23 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
raise SynapseError(400, "Unexpected SAML2 login.")
self._render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
self._render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
)
return
if saml2_auth.not_signed:
raise SynapseError(400, "SAML2 response was not signed.")
self._render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions:
@ -213,15 +206,73 @@ class SamlHandler:
saml2_auth.in_response_to, None
)
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
_check_attribute_requirement(saml2_auth.ava, requirement)
if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
# Call the mapper to register/login the user
try:
user_id = await self._map_saml_response_to_user(
saml2_auth, relay_state, user_agent, ip_address
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)
else:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self,
saml2_auth: saml2.response.AuthnResponse,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> str:
"""
Given a SAML response, retrieve the user ID for it and possibly register the user.
Args:
saml2_auth: The parsed SAML2 response.
client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
The user ID associated with this response.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
if not remote_user_id:
raise Exception("Failed to extract remote user id from SAML response")
raise MappingException(
"Failed to extract remote user id from SAML response"
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
@ -235,7 +286,7 @@ class SamlHandler:
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id, current_session
return registered_user_id
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
@ -260,7 +311,7 @@ class SamlHandler:
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id, current_session
return registered_user_id
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
@ -277,7 +328,7 @@ class SamlHandler:
localpart = attribute_dict.get("mxid_localpart")
if not localpart:
raise Exception(
raise MappingException(
"Error parsing SAML2 response: SAML mapping provider plugin "
"did not return a mxid_localpart value"
)
@ -294,8 +345,8 @@ class SamlHandler:
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise SynapseError(
500, "Unable to generate a Matrix ID from the SAML response"
raise MappingException(
"Unable to generate a Matrix ID from the SAML response"
)
logger.info("Mapped SAML user to local part %s", localpart)
@ -310,7 +361,7 @@ class SamlHandler:
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id, current_session
return registered_user_id
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@ -323,11 +374,11 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return
return True
logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
@ -335,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
req.value,
values,
)
raise AuthError(403, "You are not authorized to log in here.")
return False
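A standalone sketch of the new boolean contract (the real helper also logs; `ava` maps SAML attribute names to lists of asserted values):

    def check_attribute(ava: dict, attribute: str, value: str) -> bool:
        # True iff the required value appears among the asserted values
        return value in ava.get(attribute, [])

    assert check_attribute({"userGroup": ["staff", "admin"]}, "userGroup", "staff")
    assert not check_attribute({}, "userGroup", "staff")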
DOT_REPLACE_PATTERN = re.compile(
@ -390,7 +441,7 @@ class DefaultSamlMappingProvider:
return saml_response.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise SynapseError(400, "'uid' not in SAML2 response")
raise MappingException("'uid' not in SAML2 response")
def saml_response_to_user_attributes(
self,


@ -378,7 +378,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
typing_key = since_token.typing_key if since_token else 0
room_ids = sync_result_builder.joined_room_ids
@ -402,7 +402,7 @@ class SyncHandler:
event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
receipt_key = since_token.receipt_key if since_token else 0
receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = await receipt_source.get_new_events(
@ -533,7 +533,7 @@ class SyncHandler:
if len(recents) > timeline_limit:
limited = True
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
@ -1322,6 +1322,7 @@ class SyncHandler:
is_guest=sync_config.is_guest,
include_offline=include_offline,
)
assert presence_key
sync_result_builder.now_token = now_token.copy_and_replace(
"presence_key", presence_key
)
@ -1484,7 +1485,7 @@ class SyncHandler:
if rooms_changed:
return True
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
stream_id = since_token.room_key.stream
for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
return True
@ -1750,7 +1751,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
"room_key", RoomStreamToken(None, event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(


@ -54,6 +54,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@ -164,7 +165,9 @@ async def _handle_json_response(
try:
check_content_type_is_json(response.headers)
d = treq.json_content(response)
# Use the custom JSON decoder (partially re-implements treq.json_content).
d = treq.text_content(response, encoding="utf-8")
d.addCallback(json_decoder.decode)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = await make_deferred_yieldable(d)
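Conceptually, the change takes the raw text body and decodes it with Synapse's own JSON decoder instead of delegating to `treq.json_content`. A hedged stdlib sketch (assuming `json_decoder` behaves like a configured `json.JSONDecoder`):

    import json

    json_decoder = json.JSONDecoder()

    def parse_body(text: str) -> dict:
        # invalid JSON raises json.JSONDecodeError from a decoder Synapse
        # controls, rather than whatever treq would raise internally
        return json_decoder.decode(text)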


@ -25,6 +25,7 @@ from typing import (
Set,
Tuple,
TypeVar,
Union,
)
from prometheus_client import Counter
@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.streams.config import PaginationConfig
from synapse.types import Collection, StreamToken, UserID
from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@ -111,7 +112,9 @@ class _NotifierUserStream:
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
def notify(
self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
):
"""Notify any listeners for this user of a new event from an
event source.
Args:
@ -294,7 +297,12 @@ class Notifier:
rooms.add(event.room_id)
if users or rooms:
self.on_new_event("room_key", max_room_stream_id, users=users, rooms=rooms)
self.on_new_event(
"room_key",
RoomStreamToken(None, max_room_stream_id),
users=users,
rooms=rooms,
)
self._on_updated_room_token(max_room_stream_id)
def _on_updated_room_token(self, max_room_stream_id: int):
@ -329,7 +337,7 @@ class Notifier:
def on_new_event(
self,
stream_key: str,
new_token: int,
new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [],
rooms: Collection[str] = [],
):


@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
async def _serialize_payload(store, event_and_contexts, backfilled):
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
"""
Args:
store
room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
}
)
payload = {"events": event_payloads, "backfilled": backfilled}
payload = {
"events": event_payloads,
"backfilled": backfilled,
"room_id": room_id,
}
return payload
@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
room_id = content["room_id"]
backfilled = content["backfilled"]
event_payloads = content["events"]
@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
max_stream_id = await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
room_id, event_and_contexts, backfilled
)
return 200, {"max_stream_id": max_stream_id}
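A hedged sketch of the wire payload after this change (values illustrative):

    payload = {
        "events": [],                   # serialised (event, context) pairs elided
        "backfilled": False,
        "room_id": "!abc:example.org",  # new field, passed through to persist_events_and_notify
    }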


@ -109,7 +109,7 @@ class ReplicationCommandHandler:
if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
if hs.config.worker.writers.events == hs.get_instance_name():
if hs.get_instance_name() in hs.config.worker.writers.events:
self._streams_to_replicate.append(stream)
continue


@ -19,7 +19,7 @@ from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
from ._base import Stream, StreamUpdateResult, Token
"""Handling of the 'events' replication stream
@ -117,7 +117,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self._store.get_current_events_token),
self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
)


@ -1,52 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>SSO login error</title>
</head>
<body>
{# a 403 means we have actively rejected their login #}
{% if code == 403 %}
<p>You are not allowed to log in here.</p>
{% else %}
<p>
There was an error during authentication:
</p>
<div id="errormsg" style="margin:20px 80px">{{ msg }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
validation link in the same client you're logging in from.
</p>
<p>
Try logging in again from your Matrix client and if the problem persists
please contact the server's administrator.
</p>
<script type="text/javascript">
// Error handling to support Auth0 errors that we might get through a GET request
// to the validation endpoint. If an error is provided, it's either going to be
// located in the query string or in a query string-like URI fragment.
// We try to locate the error from any of these two locations, but if we can't
// we just don't print anything specific.
let searchStr = "";
if (window.location.search) {
// window.location.searchParams isn't always defined when
// window.location.search is, so it's more reliable to parse the latter.
searchStr = window.location.search;
} else if (window.location.hash) {
// Replace the # with a ? so that URLSearchParams does the right thing and
// doesn't parse the first parameter incorrectly.
searchStr = window.location.hash.replace("#", "?");
}
// We might end up with no error in the URL, so we need to check if we have one
// to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) {
document.getElementById("errormsg").innerText = errorDesc;
}
</script>
{% endif %}
</body>
</html>


@ -5,14 +5,49 @@
<title>SSO error</title>
</head>
<body>
<p>Oops! Something went wrong during authentication.</p>
{# If an error of unauthorised is returned it means we have actively rejected their login #}
{% if error == "unauthorised" %}
<p>You are not allowed to log in here.</p>
{% else %}
<p>
There was an error during authentication:
</p>
<div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
validation link in the same client you're logging in from.
</p>
<p>
Try logging in again from your Matrix client and if the problem persists
please contact the server's administrator.
</p>
<p>Error: <code>{{ error }}</code></p>
{% if error_description %}
<pre><code>{{ error_description }}</code></pre>
{% endif %}
<script type="text/javascript">
// Error handling to support Auth0 errors that we might get through a GET request
// to the validation endpoint. If an error is provided, it's either going to be
// located in the query string or in a query string-like URI fragment.
// We try to locate the error from any of these two locations, but if we can't
// we just don't print anything specific.
let searchStr = "";
if (window.location.search) {
// window.location.searchParams isn't always defined when
// window.location.search is, so it's more reliable to parse the latter.
searchStr = window.location.search;
} else if (window.location.hash) {
// Replace the # with a ? so that URLSearchParams does the right thing and
// doesn't parse the first parameter incorrectly.
searchStr = window.location.hash.replace("#", "?");
}
// We might end up with no error in the URL, so we need to check if we have one
// to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) {
document.getElementById("errormsg").innerText = errorDesc;
}
</script>
{% endif %}
</body>
</html>


@ -13,10 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.python import failure
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeHtmlResource, return_html_error
from synapse.http.server import DirectServeHtmlResource
class SAML2ResponseResource(DirectServeHtmlResource):
@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource):
def __init__(self, hs):
super().__init__()
self._saml_handler = hs.get_saml_handler()
self._error_html_template = hs.config.saml2.saml2_error_html_template
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
f = failure.Failure(
SynapseError(400, "Unexpected GET request on /saml2/authn_response")
self._saml_handler._render_error(
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
return_html_error(f, request, self._error_html_template)
async def _async_render_POST(self, request):
try:
await self._saml_handler.handle_saml_response(request)
except Exception:
f = failure.Failure()
return_html_error(f, request, self._error_html_template)
await self._saml_handler.handle_saml_response(request)


@ -47,6 +47,9 @@ class Storage:
# interfaces.
self.main = stores.main
self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorage(hs, stores)


@ -75,7 +75,7 @@ class Databases:
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
if hs.config.worker.writers.events == hs.get_instance_name():
if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:


@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old")
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
SELECT event_id FROM stream_ordering_to_exterm


@ -32,7 +32,7 @@ from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
@ -97,18 +97,21 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator
self._backfill_id_gen = (
self.store._backfill_id_gen
) # type: MultiWriterIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator
# This should only exist on instances that are configured to write
assert (
hs.config.worker.writers.events == hs.get_instance_name()
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@ -213,7 +216,7 @@ class PersistEventsStore:
Returns:
Filtered event ids
"""
results = []
results = [] # type: List[str]
def _get_events_which_are_prevs_txn(txn, batch):
sql = """
@ -631,7 +634,9 @@ class PersistEventsStore:
)
@classmethod
def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
def _filter_events_and_contexts_for_duplicates(
cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
) -> List[Tuple[EventBase, EventContext]]:
"""Ensure that we don't have the same event twice.
Pick the earliest non-outlier if there is one, else the earliest one.
@ -641,7 +646,9 @@ class PersistEventsStore:
Returns:
list[(EventBase, EventContext)]: filtered list
"""
new_events_and_contexts = OrderedDict()
new_events_and_contexts = (
OrderedDict()
) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
@ -655,7 +662,12 @@ class PersistEventsStore:
new_events_and_contexts[event.event_id] = (event, context)
return list(new_events_and_contexts.values())
def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
def _update_room_depths_txn(
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
):
"""Update min_depth for each room
Args:
@ -664,7 +676,7 @@ class PersistEventsStore:
we are persisting
backfilled (bool): True if the events were backfilled
"""
depth_updates = {}
depth_updates = {} # type: Dict[str, int]
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@ -800,6 +812,7 @@ class PersistEventsStore:
table="events",
values=[
{
"instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,
@ -1436,7 +1449,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {}
events_by_room = {} # type: Dict[str, List[EventBase]]
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)


@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.writers.events == hs.get_instance_name():
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres then we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq",
positive=False,
)
else:
# Another process is in charge of persisting events and generating
# stream IDs: rely on the replication streams to let us know which
# IDs we can process.
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
# We shouldn't be running in worker mode with SQLite, but it's useful
# to support it for unit tests.
#
# If this process is the writer then we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering"
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
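The branch above reduces to a small decision rule; a hypothetical summary helper:

    def choose_id_generator(is_postgres: bool, is_writer: bool) -> str:
        if is_postgres:
            # multiple writers are supported, so every instance can use it
            return "MultiWriterIdGenerator"
        # SQLite supports a single writer only
        return "StreamIdGenerator" if is_writer else "SlavedIdTracker"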
self._get_event_cache = Cache(
"*getEvent*",


@ -104,7 +104,8 @@ class RoomWorkerStore(SQLBaseStore):
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
rooms.creator, state.encryption, state.is_federatable AS federatable,
rooms.is_public AS public, state.join_rules, state.guest_access,
state.history_visibility, curr.current_state_events AS state_events
state.history_visibility, curr.current_state_events AS state_events,
state.avatar, state.topic
FROM rooms
LEFT JOIN room_stats_state state USING (room_id)
LEFT JOIN room_stats_current curr USING (room_id)


@ -0,0 +1,16 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE events ADD COLUMN instance_name TEXT;


@ -0,0 +1,26 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
SELECT setval('events_stream_seq', (
SELECT COALESCE(MAX(stream_ordering), 1) FROM events
));
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
SELECT setval('events_backfill_stream_seq', (
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
));
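On Postgres the new generators then allocate IDs from these sequences; roughly (hedged sketch, assuming a DB-API cursor `cur`):

    # forward stream: next positive stream_ordering
    cur.execute("SELECT nextval('events_stream_seq')")
    (stream_ordering,) = cur.fetchone()

    # backfill stream: the sequence counts upward and Synapse negates the
    # value, matching the COALESCE(-MIN(stream_ordering), 1) seed above
    cur.execute("SELECT nextval('events_backfill_stream_seq')")
    (magnitude,) = cur.fetchone()
    backfill_ordering = -magnitude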


@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
from_key: str,
to_key: str,
from_key: RoomStreamToken,
to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
) -> Dict[str, Tuple[List[EventBase], str]]:
) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
"""Get new room events in stream ordering since `from_key`.
Args:
@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
room_ids = self._events_stream_cache.get_entities_changed(
room_ids, from_key.stream
)
if not room_ids:
return {}
@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str
self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids
from_key: The room_key portion of a StreamToken
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
from_id = from_key.stream
return {
room_id
for room_id in room_ids
@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_room(
self,
room_id: str,
from_key: str,
to_key: str,
from_key: RoomStreamToken,
to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get new room events in stream ordering since `from_key`.
Args:
@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key:
return [], from_key
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
from_id = from_key.stream
to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
key = "s%d" % min(r.stream_ordering for r in rows)
key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
from_id = from_key.stream
to_id = to_key.stream
if from_key == to_key:
return []
@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
async def get_recent_events_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[EventBase], str]:
self, room_id: str, limit: int, end_token: RoomStreamToken
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[_EventDictReturn], str]:
self, room_id: str, limit: int, end_token: RoomStreamToken
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@ -535,13 +531,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
parsed_end_token = RoomStreamToken.parse(end_token)
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
from_token=parsed_end_token,
from_token=end_token,
limit=limit,
)
@ -619,17 +613,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none,
)
async def get_stream_token_for_event(self, event_id: str) -> str:
async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A "s%d" stream token.
A stream token.
"""
stream_id = await self.get_stream_id_for_event(event_id)
return "s%d" % (stream_id,)
return RoomStreamToken(None, stream_id)
async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], str]:
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
return rows, str(next_token)
return rows, next_token
async def paginate_room_events(
self,
room_id: str,
from_key: str,
to_key: Optional[str] = None,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@ -1083,17 +1077,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
parsed_from_key = RoomStreamToken.parse(from_key)
parsed_to_key = None
if to_key:
parsed_to_key = RoomStreamToken.parse(to_key)
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
parsed_from_key,
parsed_to_key,
from_key,
to_key,
direction,
limit,
event_filter,

@ -18,7 +18,7 @@
import itertools
import logging
from collections import deque, namedtuple
from typing import Iterable, List, Optional, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Histogram
@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
from synapse.types import StateMap
from synapse.types import Collection, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@ -185,6 +185,8 @@ class EventsPersistenceStorage:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
assert stores.persist_events
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
@ -208,7 +210,7 @@ class EventsPersistenceStorage:
Returns:
the stream ordering of the latest persisted event
"""
partitioned = {}
partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
@ -305,7 +307,9 @@ class EventsPersistenceStorage:
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
events_by_room = {}
events_by_room = (
{}
) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
@ -436,7 +440,7 @@ class EventsPersistenceStorage:
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str],
latest_event_ids: Collection[str],
):
"""Calculates the new forward extremities for a room given events to
persist.
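
In outline, an event stops being a forward extremity as soon as something else references it as a prev_event; a simplified sketch of that rule (it ignores the soft-failure handling below; event IDs invented):

def naive_forward_extremities(latest_event_ids, new_events):
    """new_events: iterable of (event_id, prev_event_ids) pairs."""
    result = set(latest_event_ids)
    referenced = set()
    for event_id, prev_event_ids in new_events:
        result.add(event_id)
        referenced.update(prev_event_ids)
    # Anything referenced as a prev_event can no longer be an extremity.
    return result - referenced

assert naive_forward_extremities(
    ["$a", "$b"], [("$c", ["$a"]), ("$d", ["$c"])]
) == {"$b", "$d"}
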
@ -470,7 +474,7 @@ class EventsPersistenceStorage:
# Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
result
)
) # type: Collection[str]
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev

@ -240,8 +240,12 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
#
# We start at 1 here as a) the first generated stream ID will be 2, and
# b) other parts of the code assume that stream IDs are strictly greater
# than 0.
self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 0
min(self._current_positions.values()) if self._current_positions else 1
)
self._known_persisted_positions = [] # type: List[int]
@ -398,9 +402,7 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted.
"""
# Currently we don't support this operation, as it's not obvious how to
# condense the stream positions of multiple writers into a single int.
raise NotImplementedError()
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.

@ -425,7 +425,9 @@ class RoomStreamToken:
@attr.s(slots=True, frozen=True)
class StreamToken:
room_key = attr.ib(type=str)
room_key = attr.ib(
type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
)
presence_key = attr.ib(type=int)
typing_key = attr.ib(type=int)
receipt_key = attr.ib(type=int)
@ -445,21 +447,16 @@ class StreamToken:
while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
return cls(keys[0], *(int(k) for k in keys[1:]))
return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
except Exception:
raise SynapseError(400, "Invalid Token")
def to_string(self):
return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])
return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
@property
def room_stream_id(self):
# TODO(markjh): Awful hack to work around hacks in the presence tests
# which assume that the keys are integers.
if type(self.room_key) is int:
return self.room_key
else:
return int(self.room_key[1:].split("-")[-1])
return self.room_key.stream
def is_after(self, other):
"""Does this token contain events that the other doesn't?"""
@ -475,7 +472,7 @@ class StreamToken:
or (int(other.groups_key) < int(self.groups_key))
)
def copy_and_advance(self, key, new_value):
def copy_and_advance(self, key, new_value) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
"""
@ -491,7 +488,7 @@ class StreamToken:
else:
return self
def copy_and_replace(self, key, new_value):
def copy_and_replace(self, key, new_value) -> "StreamToken":
return attr.evolve(self, **{key: new_value})
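
End to end, a serialised sync token now round-trips through the typed field like this (token value invented; `from_string` and the underscore separator are the existing conventions):

from synapse.types import StreamToken

tok = StreamToken.from_string("s2633508_17_338_107_129_15_33_5_1")
assert tok.room_key.stream == 2633508     # room_key is now a RoomStreamToken
assert tok.room_stream_id == 2633508      # property delegates to room_key.stream
assert tok.to_string().startswith("s2633508_")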

@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
def _reject_invalid_json(val):
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise json.JSONDecodeError("Invalid JSON value: '%s'" % val)
raise ValueError("Invalid JSON value: '%s'" % val)
# Create a custom encoder to reduce the whitespace produced by JSON encoding and
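
The exception type matters here: `json.JSONDecodeError.__init__` requires `(msg, doc, pos)`, so raising it with only a message produced a `TypeError` of its own rather than the intended rejection. A plain `ValueError` propagates cleanly and matches what callers already catch. A self-contained repro of the fixed behaviour:

import json

def _reject_invalid_json(val):
    """Do not allow Infinity, -Infinity, or NaN values in JSON."""
    raise ValueError("Invalid JSON value: '%s'" % val)

try:
    json.loads('{"limit": Infinity}', parse_constant=_reject_invalid_json)
except ValueError as e:
    print(e)  # Invalid JSON value: 'Infinity'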

@ -15,6 +15,8 @@
# limitations under the License.
import logging
from parameterized import parameterized
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
@ -23,6 +25,37 @@ from synapse.rest.client.v1 import login, room
from tests import unittest
class FederationServerTests(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
@parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
def test_bad_request(self, query_content):
"""
Querying with bad data returns a reasonable error code.
"""
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.inject_room_member(room_1, "@user:other.example.com", "join")
"/get_missing_events/(?P<room_id>[^/]*)/?"
request, channel = self.make_request(
"POST",
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
self.render(request)
self.assertEquals(400, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self):
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})

@ -16,6 +16,7 @@
from mock import Mock
from netaddr import IPSet
from parameterized import parameterized
from twisted.internet import defer
from twisted.internet.defer import TimeoutError
@ -511,3 +512,50 @@ class FederationClientTests(HomeserverTestCase):
self.reactor.advance(120)
self.assertTrue(conn.disconnecting)
@parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
def test_json_error(self, return_value):
"""
Test what happens if invalid JSON is returned from the remote endpoint.
"""
test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
self.pump()
# Nothing happened yet
self.assertNoResult(test_d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8008)
# complete the connection and wire it up to a fake transport
protocol = factory.buildProtocol(None)
transport = StringTransport()
protocol.makeConnection(transport)
# that should have made it send the request to the transport
self.assertRegex(transport.value(), b"^GET /foo/bar")
self.assertRegex(transport.value(), b"Host: testserv:8008")
# Deferred is still without a result
self.assertNoResult(test_d)
# Send it the HTTP response
protocol.dataReceived(
b"HTTP/1.1 200 OK\r\n"
b"Server: Fake\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: %i\r\n"
b"\r\n"
b"%s" % (len(return_value), return_value)
)
self.pump()
f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, ValueError)

@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from io import BytesIO
from mock import Mock
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
parse_json_object_from_request,
parse_json_value_from_request,
)
from tests import unittest
def make_request(content):
"""Make an object that acts enough like a request."""
request = Mock(spec=["content"])
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
request.content = BytesIO(content)
return request
class TestServletUtils(unittest.TestCase):
def test_parse_json_value(self):
"""Basic tests for parse_json_value_from_request."""
# Test round-tripping.
obj = {"foo": 1}
result = parse_json_value_from_request(make_request(obj))
self.assertEqual(result, obj)
# Results don't have to be objects.
result = parse_json_value_from_request(make_request(b'["foo"]'))
self.assertEqual(result, ["foo"])
# Test empty.
with self.assertRaises(SynapseError):
parse_json_value_from_request(make_request(b""))
result = parse_json_value_from_request(make_request(b""), allow_empty_body=True)
self.assertIsNone(result)
# Invalid UTF-8.
with self.assertRaises(SynapseError):
parse_json_value_from_request(make_request(b"\xFF\x00"))
# Invalid JSON.
with self.assertRaises(SynapseError):
parse_json_value_from_request(make_request(b"foo"))
with self.assertRaises(SynapseError):
parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
def test_parse_json_object(self):
"""Basic tests for parse_json_object_from_request."""
# Test empty.
result = parse_json_object_from_request(
make_request(b""), allow_empty_body=True
)
self.assertEqual(result, {})
# Test not an object
with self.assertRaises(SynapseError):
parse_json_object_from_request(make_request(b'["foo"]'))

@ -1174,6 +1174,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
self.assertIn("topic", channel.json_body)
self.assertIn("avatar", channel.json_body)
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)

@ -71,7 +71,10 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
await hs.get_storage().persistence.persist_event(event, context)
persistence = hs.get_storage().persistence
assert persistence is not None
await persistence.persist_event(event, context)
return event