Merge branch 'develop' into matrix-org-hotfixes

pull/8675/head
Brendan Abolivier 2020-03-09 15:06:56 +00:00
commit 74050d0c1c
83 changed files with 1912 additions and 491 deletions


@@ -6,12 +6,7 @@
 set -ex
 apt-get update
-apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev
-# workaround for https://github.com/jaraco/zipp/issues/40
-python3.5 -m pip install 'setuptools>=34.4.0'
-python3.5 -m pip install tox
+apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
 export LANG="C.UTF-8"


@@ -1,3 +1,18 @@
+Synapse 1.11.1 (2020-03-03)
+===========================
+
+This release includes a security fix impacting installations using Single Sign-On (i.e. SAML2 or CAS) for authentication. Administrators of such installations are encouraged to upgrade as soon as possible.
+
+The release also includes fixes for a couple of other bugs.
+
+Bugfixes
+--------
+
+- Add a confirmation step to the SSO login flow before redirecting users to the redirect URL. ([b2bd54a2](https://github.com/matrix-org/synapse/commit/b2bd54a2e31d9a248f73fadb184ae9b4cbdb49f9), [65c73cdf](https://github.com/matrix-org/synapse/commit/65c73cdfec1876a9fec2fd2c3a74923cd146fe0b), [a0178df1](https://github.com/matrix-org/synapse/commit/a0178df10422a76fd403b82d2b2a4ed28a9a9d1e))
+- Fix setting a user as an admin with the admin API `PUT /_synapse/admin/v2/users/<user_id>`. Contributed by @dklimpel. ([\#6910](https://github.com/matrix-org/synapse/issues/6910))
+- Fix bug introduced in Synapse 1.11.0 which sometimes caused errors when joining rooms over federation, with `'coroutine' object has no attribute 'event_id'`. ([\#6996](https://github.com/matrix-org/synapse/issues/6996))
+
 Synapse 1.11.0 (2020-02-21)
 ===========================


@@ -418,7 +418,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
 for having Synapse automatically provision and renew federation
 certificates through ACME can be found at [ACME.md](docs/ACME.md).
 Note that, as pointed out in that document, this feature will not
-work with installs set up after November 2020.
+work with installs set up after November 2019.
 If you are using your own certificate, be sure to use a `.pem` file that
 includes the full certificate chain including any intermediate certificates

changelog.d/6309.misc Normal file

@@ -0,0 +1 @@
+Add type hints to `logging/context.py`.

changelog.d/6315.feature Normal file

@@ -0,0 +1 @@
+Expose the `synctl`, `hash_password` and `generate_config` commands in the snapcraft package. Contributed by @devec0.

changelog.d/6874.misc Normal file

@@ -0,0 +1 @@
+Refactoring work in preparation for changing the event redaction algorithm.

changelog.d/6875.misc Normal file

@@ -0,0 +1 @@
+Refactoring work in preparation for changing the event redaction algorithm.

changelog.d/6971.feature Normal file

@@ -0,0 +1 @@
+Validate the alt_aliases property of canonical alias events.

changelog.d/6986.feature Normal file

@@ -0,0 +1 @@
+Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases.

changelog.d/6987.misc Normal file

@@ -0,0 +1 @@
+Add some type annotations to the database storage classes.

changelog.d/6995.misc Normal file

@@ -0,0 +1 @@
+Add some type annotations to the federation base & client classes.

changelog.d/7002.misc Normal file

@@ -0,0 +1 @@
+Merge worker apps together.

changelog.d/7003.misc Normal file

@@ -0,0 +1 @@
+Refactoring work in preparation for changing the event redaction algorithm.

changelog.d/7015.misc Normal file

@@ -0,0 +1 @@
+Change date in INSTALL.md#tls-certificates for last date of getting TLS certificates to November 2019.

changelog.d/7018.bugfix Normal file

@@ -0,0 +1 @@
+Fix py35-old CI by using native tox package.

changelog.d/7019.misc Normal file

@@ -0,0 +1 @@
+Port `synapse.handlers.presence` to async/await.

changelog.d/7020.misc Normal file

@@ -0,0 +1 @@
+Port `synapse.rest.keys` to async/await.

changelog.d/7030.feature Normal file

@@ -0,0 +1 @@
+Break down monthly active users by `appservice_id` and emit via Prometheus.

changelog.d/7035.bugfix Normal file

@@ -0,0 +1 @@
+Fix a bug causing `org.matrix.dummy_event` to be included in responses from `/sync`.

changelog.d/7037.feature Normal file

@@ -0,0 +1 @@
+Implement updated authorization rules and redaction rules for aliases events, from [MSC2261](https://github.com/matrix-org/matrix-doc/pull/2261) and [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432).

changelog.d/7045.misc Normal file

@@ -0,0 +1 @@
+Add a type check to `is_verified` when processing room keys.

changelog.d/7048.doc Normal file

@@ -0,0 +1 @@
+Document that the fallback auth endpoints must be routed to the same worker node as the register endpoints.

changelog.d/7055.misc Normal file

@@ -0,0 +1 @@
+Merge worker apps together.


@@ -15,10 +15,9 @@ services:
     restart: unless-stopped
     # See the readme for a full documentation of the environment settings
     environment:
-      - SYNAPSE_CONFIG_PATH=/etc/homeserver.yaml
+      - SYNAPSE_CONFIG_PATH=/data/homeserver.yaml
     volumes:
       # You may either store all the files in a local folder
-      - ./matrix-config/homeserver.yaml:/etc/homeserver.yaml
       - ./files:/data
       # .. or you may split this between different storage points
       # - ./files:/data


@@ -1,6 +1,6 @@
 # Using the Synapse Grafana dashboard
 
 0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
-1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst
+1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
 2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
 3. Set up additional recording rules

debian/changelog vendored

@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.11.1) stable; urgency=medium
+
+  * New synapse release 1.11.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 03 Mar 2020 15:01:22 +0000
+
 matrix-synapse-py3 (1.11.0) stable; urgency=medium
 
   * New synapse release 1.11.0.


@@ -1360,6 +1360,56 @@ saml2_config:
   #      # name: value
 
+# Additional settings to use with single-sign on systems such as SAML2 and CAS.
+#
+sso:
+    # A list of client URLs which are whitelisted so that the user does not
+    # have to confirm giving access to their account to the URL. Any client
+    # whose URL starts with an entry in the following list will not be subject
+    # to an additional confirmation step after the SSO login is completed.
+    #
+    # WARNING: An entry such as "https://my.client" is insecure, because it
+    # will also match "https://my.client.evil.site", exposing your users to
+    # phishing attacks from evil.site. To avoid this, include a slash after the
+    # hostname: "https://my.client/".
+    #
+    # By default, this list is empty.
+    #
+    #client_whitelist:
+    #  - https://riot.im/develop
+    #  - https://my.custom.client/
+
+    # Directory in which Synapse will try to find the template files below.
+    # If not set, default templates from within the Synapse package will be used.
+    #
+    # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
+    # If you *do* uncomment it, you will need to make sure that all the templates
+    # below are in the directory.
+    #
+    # Synapse will look for the following templates in this directory:
+    #
+    # * HTML page for a confirmation step before redirecting back to the client
+    #   with the login token: 'sso_redirect_confirm.html'.
+    #
+    #   When rendering, this template is given three variables:
+    #     * redirect_url: the URL the user is about to be redirected to. Needs
+    #                     manual escaping (see
+    #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+    #
+    #     * display_url: the same as `redirect_url`, but with the query
+    #                    parameters stripped. The intention is to have a
+    #                    human-readable URL to show to users, not to use it as
+    #                    the final address to redirect to. Needs manual escaping
+    #                    (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+    #
+    #     * server_name: the homeserver's name.
+    #
+    # You can see the default templates at:
+    # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
+    #
+    #template_dir: "res/templates"
+
 # The JWT needs to contain a globally unique "sub" (subject) claim.
 #
 #jwt_config:
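The whitelist warning above comes down to URL prefix matching. A minimal sketch (not Synapse's actual implementation; names are illustrative) of why the trailing slash matters:

```python
# Illustrative sketch only: whitelist entries are matched as URL prefixes,
# so an entry without a trailing slash also matches lookalike hostnames.
def is_whitelisted(client_url, whitelist):
    return any(client_url.startswith(entry) for entry in whitelist)

# "https://my.client" also matches a phishing domain that merely starts
# with the same characters:
print(is_whitelisted("https://my.client.evil.site/cb", ["https://my.client"]))   # True
# A trailing slash pins the hostname:
print(is_whitelisted("https://my.client.evil.site/cb", ["https://my.client/"]))  # False
print(is_whitelisted("https://my.client/cb", ["https://my.client/"]))            # True
```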


@@ -273,6 +273,7 @@ Additionally, the following REST endpoints can be handled, but all requests must
 be routed to the same instance:
 
     ^/_matrix/client/(r0|unstable)/register$
+    ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$
 
 Pagination requests can also be handled, but all requests with the same path
 room must be routed to the same instance. Additionally, care must be taken to


@@ -1,20 +1,31 @@
 name: matrix-synapse
 base: core18
 version: git
 summary: Reference Matrix homeserver
 description: |
   Synapse is the reference Matrix homeserver.
   Matrix is a federated and decentralised instant messaging and VoIP system.
 grade: stable
 confinement: strict
 
 apps:
   matrix-synapse:
     command: synctl --no-daemonize start $SNAP_COMMON/homeserver.yaml
     stop-command: synctl -c $SNAP_COMMON stop
     plugs: [network-bind, network]
     daemon: simple
+  hash-password:
+    command: hash_password
+  generate-config:
+    command: generate_config
+  generate-signing-key:
+    command: generate_signing_key.py
+  register-new-matrix-user:
+    command: register_new_matrix_user
+    plugs: [network]
+  synctl:
+    command: synctl
 
 parts:
   matrix-synapse:
     source: .


@@ -36,7 +36,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.11.0"
+__version__ = "1.11.1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when


@@ -539,7 +539,7 @@ class Auth(object):
     @defer.inlineCallbacks
     def check_can_change_room_list(self, room_id: str, user: UserID):
-        """Check if the user is allowed to edit the room's entry in the
+        """Determine whether the user is allowed to edit the room's entry in the
         published room list.
 
         Args:
@@ -570,12 +570,7 @@
         )
 
         user_level = event_auth.get_user_power_level(user_id, auth_events)
-        if user_level < send_level:
-            raise AuthError(
-                403,
-                "This server requires you to be a moderator in the room to"
-                " edit its room list entry",
-            )
+        return user_level >= send_level
 
     @staticmethod
     def has_access_token(request):
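The hunk above replaces a raised 403 `AuthError` with a boolean return, so callers now decide how to report a failed check. A reduced sketch of the new control flow (power levels simplified to plain ints; not the full Synapse method):

```python
# Simplified sketch: the check compares the user's power level against the
# level required to send the relevant state event, and returns a bool
# instead of raising AuthError(403, ...) as before.
def check_can_change_room_list(user_level, send_level):
    return user_level >= send_level

# A moderator (50) at the default send level (50) passes; a plain user (0) does not.
print(check_can_change_room_list(50, 50))  # True
print(check_can_change_room_list(0, 50))   # False
```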


@@ -66,6 +66,7 @@ class Codes(object):
     EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
     INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
+    BAD_ALIAS = "M_BAD_ALIAS"
 
 
 class CodeMessageException(RuntimeError):


@@ -57,7 +57,7 @@ class RoomVersion(object):
     state_res = attr.ib()  # int; one of the StateResolutionVersions
     enforce_key_validity = attr.ib()  # bool
 
-    # bool: before MSC2260, anyone was allowed to send an aliases event
+    # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool, default=False)
 
@@ -102,12 +102,13 @@ class RoomVersions(object):
         enforce_key_validity=True,
         special_case_aliases_auth=True,
     )
-    MSC2260_DEV = RoomVersion(
-        "org.matrix.msc2260",
+    MSC2432_DEV = RoomVersion(
+        "org.matrix.msc2432",
         RoomDisposition.UNSTABLE,
         EventFormatVersions.V3,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
+        special_case_aliases_auth=False,
     )
 
@@ -119,6 +120,6 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V3,
         RoomVersions.V4,
         RoomVersions.V5,
-        RoomVersions.MSC2260_DEV,
+        RoomVersions.MSC2432_DEV,
     )
 }  # type: Dict[str, RoomVersion]


@@ -494,20 +494,26 @@ class GenericWorkerServer(HomeServer):
                 elif name == "federation":
                     resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
                 elif name == "media":
-                    media_repo = self.get_media_repository_resource()
+                    if self.config.can_load_media_repo:
+                        media_repo = self.get_media_repository_resource()
 
-                    # We need to serve the admin servlets for media on the
-                    # worker.
-                    admin_resource = JsonResource(self, canonical_json=False)
-                    register_servlets_for_media_repo(self, admin_resource)
+                        # We need to serve the admin servlets for media on the
+                        # worker.
+                        admin_resource = JsonResource(self, canonical_json=False)
+                        register_servlets_for_media_repo(self, admin_resource)
 
-                    resources.update(
-                        {
-                            MEDIA_PREFIX: media_repo,
-                            LEGACY_MEDIA_PREFIX: media_repo,
-                            "/_synapse/admin": admin_resource,
-                        }
-                    )
+                        resources.update(
+                            {
+                                MEDIA_PREFIX: media_repo,
+                                LEGACY_MEDIA_PREFIX: media_repo,
+                                "/_synapse/admin": admin_resource,
+                            }
+                        )
+                    else:
+                        logger.warning(
+                            "A 'media' listener is configured but the media"
+                            " repository is disabled. Ignoring."
+                        )
 
             if name == "openid" and "federation" not in res["names"]:
                 # Only load the openid resource separately if federation resource


@@ -298,6 +298,11 @@ class SynapseHomeServer(HomeServer):
 
 # Gauges to expose monthly active user control metrics
 current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
+current_mau_by_service_gauge = Gauge(
+    "synapse_admin_mau_current_mau_by_service",
+    "Current MAU by service",
+    ["app_service"],
+)
 max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
 registered_reserved_users_mau_gauge = Gauge(
     "synapse_admin_mau:registered_reserved_users",
@@ -585,12 +590,20 @@ def run(hs):
     @defer.inlineCallbacks
     def generate_monthly_active_users():
         current_mau_count = 0
+        current_mau_count_by_service = {}
         reserved_users = ()
         store = hs.get_datastore()
         if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
             current_mau_count = yield store.get_monthly_active_count()
+            current_mau_count_by_service = (
+                yield store.get_monthly_active_count_by_service()
+            )
             reserved_users = yield store.get_registered_reserved_users()
         current_mau_gauge.set(float(current_mau_count))
+        for app_service, count in current_mau_count_by_service.items():
+            current_mau_by_service_gauge.labels(app_service).set(float(count))
         registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
         max_mau_gauge.set(float(hs.config.max_mau_value))
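The per-appservice metric above relies on a labelled Prometheus gauge: each distinct `app_service` label value becomes its own time series. A stdlib-only stand-in for the `labels(...).set(...)` pattern (the real code uses `prometheus_client.Gauge`; the counts here are made up):

```python
# Stdlib-only stand-in for a labelled gauge: each distinct label value
# gets its own stored sample, mirroring
# current_mau_by_service_gauge.labels(app_service).set(float(count)).
class LabelledGauge:
    def __init__(self, name, labelnames):
        self.name = name
        self.labelnames = labelnames
        self.samples = {}  # label-value tuple -> latest sample

    def labels(self, *labelvalues):
        outer = self

        class _Child:
            def set(self, value):
                outer.samples[labelvalues] = value

        return _Child()

gauge = LabelledGauge("synapse_admin_mau_current_mau_by_service", ["app_service"])
# Hypothetical per-appservice MAU counts:
for app_service, count in {"irc_bridge": 45, "native": 120}.items():
    gauge.labels(app_service).set(float(count))

print(gauge.samples)  # {('irc_bridge',): 45.0, ('native',): 120.0}
```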


@@ -24,6 +24,7 @@ from synapse.config import (
     server,
     server_notices_config,
     spam_checker,
+    sso,
     stats,
     third_party_event_rules,
     tls,
@@ -57,6 +58,7 @@ class RootConfig:
     key: key.KeyConfig
     saml2: saml2_config.SAML2Config
     cas: cas.CasConfig
+    sso: sso.SSOConfig
     jwt: jwt_config.JWTConfig
     password: password.PasswordConfig
     email: emailconfig.EmailConfig


@@ -38,6 +38,7 @@ from .saml2_config import SAML2Config
 from .server import ServerConfig
 from .server_notices_config import ServerNoticesConfig
 from .spam_checker import SpamCheckerConfig
+from .sso import SSOConfig
 from .stats import StatsConfig
 from .third_party_event_rules import ThirdPartyRulesConfig
 from .tls import TlsConfig
@@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig):
         KeyConfig,
         SAML2Config,
         CasConfig,
+        SSOConfig,
         JWTConfig,
         PasswordConfig,
         EmailConfig,

synapse/config/sso.py Normal file

@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict
+
+import pkg_resources
+
+from ._base import Config
+
+
+class SSOConfig(Config):
+    """SSO Configuration
+    """
+
+    section = "sso"
+
+    def read_config(self, config, **kwargs):
+        sso_config = config.get("sso") or {}  # type: Dict[str, Any]
+
+        # Pick a template directory in order of:
+        # * The sso-specific template_dir
+        # * /path/to/synapse/install/res/templates
+        template_dir = sso_config.get("template_dir")
+        if not template_dir:
+            template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
+
+        self.sso_redirect_confirm_template_dir = template_dir
+
+        self.sso_client_whitelist = sso_config.get("client_whitelist") or []
+
+    def generate_config_section(self, **kwargs):
+        return """\
+        # Additional settings to use with single-sign on systems such as SAML2 and CAS.
+        #
+        sso:
+            # A list of client URLs which are whitelisted so that the user does not
+            # have to confirm giving access to their account to the URL. Any client
+            # whose URL starts with an entry in the following list will not be subject
+            # to an additional confirmation step after the SSO login is completed.
+            #
+            # WARNING: An entry such as "https://my.client" is insecure, because it
+            # will also match "https://my.client.evil.site", exposing your users to
+            # phishing attacks from evil.site. To avoid this, include a slash after the
+            # hostname: "https://my.client/".
+            #
+            # By default, this list is empty.
+            #
+            #client_whitelist:
+            #  - https://riot.im/develop
+            #  - https://my.custom.client/
+
+            # Directory in which Synapse will try to find the template files below.
+            # If not set, default templates from within the Synapse package will be used.
+            #
+            # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
+            # If you *do* uncomment it, you will need to make sure that all the templates
+            # below are in the directory.
+            #
+            # Synapse will look for the following templates in this directory:
+            #
+            # * HTML page for a confirmation step before redirecting back to the client
+            #   with the login token: 'sso_redirect_confirm.html'.
+            #
+            #   When rendering, this template is given three variables:
+            #     * redirect_url: the URL the user is about to be redirected to. Needs
+            #                     manual escaping (see
+            #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #
+            #     * display_url: the same as `redirect_url`, but with the query
+            #                    parameters stripped. The intention is to have a
+            #                    human-readable URL to show to users, not to use it as
+            #                    the final address to redirect to. Needs manual escaping
+            #                    (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #
+            #     * server_name: the homeserver's name.
+            #
+            # You can see the default templates at:
+            # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
+            #
+            #template_dir: "res/templates"
+        """


@@ -140,7 +140,7 @@ def compute_event_signature(
     Returns:
         a dictionary in the same format of an event's signatures field.
     """
-    redact_json = prune_event_dict(event_dict)
+    redact_json = prune_event_dict(room_version, event_dict)
     redact_json.pop("age_ts", None)
     redact_json.pop("unsigned", None)
     if logger.isEnabledFor(logging.DEBUG):


@@ -137,7 +137,7 @@ def check(
         raise AuthError(403, "This room has been marked as unfederatable.")
 
     # 4. If type is m.room.aliases
-    if event.type == EventTypes.Aliases:
+    if event.type == EventTypes.Aliases and room_version_obj.special_case_aliases_auth:
         # 4a. If event has no state_key, reject
         if not event.is_state():
             raise AuthError(403, "Alias event must be a state event")
@@ -152,10 +152,8 @@
         )
 
         # 4c. Otherwise, allow.
-        # This is removed by https://github.com/matrix-org/matrix-doc/pull/2260
-        if room_version_obj.special_case_aliases_auth:
-            logger.debug("Allowing! %s", event)
-            return
+        logger.debug("Allowing! %s", event)
+        return
 
     if logger.isEnabledFor(logging.DEBUG):
         logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])


@@ -15,9 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import os
 from distutils.util import strtobool
-from typing import Optional, Type
+from typing import Dict, Optional, Type
 
 import six
 
@@ -199,15 +200,25 @@ class _EventInternalMetadata(object):
         return self._dict.get("redacted", False)
 
 
-class EventBase(object):
+class EventBase(metaclass=abc.ABCMeta):
+    @property
+    @abc.abstractmethod
+    def format_version(self) -> int:
+        """The EventFormatVersion implemented by this event"""
+        ...
+
     def __init__(
         self,
-        event_dict,
-        signatures={},
-        unsigned={},
-        internal_metadata_dict={},
-        rejected_reason=None,
+        event_dict: JsonDict,
+        room_version: RoomVersion,
+        signatures: Dict[str, Dict[str, str]],
+        unsigned: JsonDict,
+        internal_metadata_dict: JsonDict,
+        rejected_reason: Optional[str],
     ):
+        assert room_version.event_format == self.format_version
+
+        self.room_version = room_version
         self.signatures = signatures
         self.unsigned = unsigned
         self.rejected_reason = rejected_reason
@@ -303,7 +314,13 @@ class EventBase(object):
 class FrozenEvent(EventBase):
     format_version = EventFormatVersions.V1  # All events of this type are V1
 
-    def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
+    def __init__(
+        self,
+        event_dict: JsonDict,
+        room_version: RoomVersion,
+        internal_metadata_dict: JsonDict = {},
+        rejected_reason: Optional[str] = None,
+    ):
         event_dict = dict(event_dict)
 
         # Signatures is a dict of dicts, and this is faster than doing a
@@ -326,8 +343,9 @@ class FrozenEvent(EventBase):
 
         self._event_id = event_dict["event_id"]
 
-        super(FrozenEvent, self).__init__(
+        super().__init__(
             frozen_dict,
+            room_version=room_version,
             signatures=signatures,
             unsigned=unsigned,
             internal_metadata_dict=internal_metadata_dict,
@@ -352,7 +370,13 @@ class FrozenEvent(EventBase):
 class FrozenEventV2(EventBase):
     format_version = EventFormatVersions.V2  # All events of this type are V2
 
-    def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
+    def __init__(
+        self,
+        event_dict: JsonDict,
+        room_version: RoomVersion,
+        internal_metadata_dict: JsonDict = {},
+        rejected_reason: Optional[str] = None,
+    ):
         event_dict = dict(event_dict)
 
         # Signatures is a dict of dicts, and this is faster than doing a
@@ -377,8 +401,9 @@ class FrozenEventV2(EventBase):
 
         self._event_id = None
 
-        super(FrozenEventV2, self).__init__(
+        super().__init__(
             frozen_dict,
+            room_version=room_version,
             signatures=signatures,
             unsigned=unsigned,
             internal_metadata_dict=internal_metadata_dict,
@@ -445,7 +470,7 @@ class FrozenEventV3(FrozenEventV2):
         return self._event_id
 
 
-def event_type_from_format_version(format_version: int) -> Type[EventBase]:
+def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
     """Returns the python type to use to construct an Event object for the
     given event format version.
 
@@ -474,5 +499,5 @@ def make_event_from_dict(
     rejected_reason: Optional[str] = None,
 ) -> EventBase:
     """Construct an EventBase from the given event dict"""
-    event_type = event_type_from_format_version(room_version.event_format)
-    return event_type(event_dict, internal_metadata_dict, rejected_reason)
+    event_type = _event_type_from_format_version(room_version.event_format)
+    return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason)
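The change above makes `format_version` an abstract property, so `EventBase` itself can no longer be instantiated and every concrete event class must declare which format it implements. A self-contained sketch of that pattern (class names reused for illustration only):

```python
import abc

class EventBase(metaclass=abc.ABCMeta):
    @property
    @abc.abstractmethod
    def format_version(self) -> int:
        """The event format version implemented by this event."""
        ...

class FrozenEvent(EventBase):
    # A plain class attribute is enough to satisfy the abstract property.
    format_version = 1

print(FrozenEvent().format_version)  # 1

try:
    EventBase()  # the abstract base refuses to be instantiated
except TypeError as e:
    print("refused:", type(e).__name__)  # refused: TypeError
```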


@@ -23,6 +23,7 @@ from frozendict import frozendict
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.room_versions import RoomVersion
 from synapse.util.async_helpers import yieldable_gather_results
 
 from . import EventBase
@@ -35,26 +36,20 @@ from . import EventBase
 SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
 
 
-def prune_event(event):
+def prune_event(event: EventBase) -> EventBase:
     """ Returns a pruned version of the given event, which removes all keys we
     don't know about or think could potentially be dodgy.
 
     This is used when we "redact" an event. We want to remove all fields that
     the user has specified, but we do want to keep necessary information like
     type, state_key etc.
-
-    Args:
-        event (FrozenEvent)
-
-    Returns:
-        FrozenEvent
     """
-    pruned_event_dict = prune_event_dict(event.get_dict())
+    pruned_event_dict = prune_event_dict(event.room_version, event.get_dict())
 
-    from . import event_type_from_format_version
+    from . import make_event_from_dict
 
-    pruned_event = event_type_from_format_version(event.format_version)(
-        pruned_event_dict, event.internal_metadata.get_dict()
+    pruned_event = make_event_from_dict(
+        pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
     )
 
     # Mark the event as redacted
@@ -63,15 +58,12 @@ def prune_event(event):
     return pruned_event
 
 
-def prune_event_dict(event_dict):
+def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     """Redacts the event_dict in the same way as `prune_event`, except it
     operates on dicts rather than event objects
 
-    Args:
-        event_dict (dict)
-
     Returns:
-        dict: A copy of the pruned event dict
+        A copy of the pruned event dict
     """
 
     allowed_keys = [
@@ -118,7 +110,7 @@ def prune_event_dict(event_dict):
             "kick",
             "redact",
         )
-    elif event_type == EventTypes.Aliases:
+    elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
         add_fields("aliases")
     elif event_type == EventTypes.RoomHistoryVisibility:
         add_fields("history_visibility")
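The effect of threading `room_version` into `prune_event_dict` is that redaction itself becomes version-dependent: only room versions that still special-case aliases keep the `aliases` content key. A reduced sketch of just that branch (not the full allowed-keys logic):

```python
# Reduced sketch: redacting an m.room.aliases event keeps its "aliases"
# content key only for room versions with special_case_aliases_auth set.
def prune_aliases_content(content, special_case_aliases_auth):
    allowed_fields = ["aliases"] if special_case_aliases_auth else []
    return {k: v for k, v in content.items() if k in allowed_fields}

content = {"aliases": ["#room:example.org"], "other": "stripped"}
print(prune_aliases_content(content, True))   # {'aliases': ['#room:example.org']}
print(prune_aliases_content(content, False))  # {}
```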


@@ -15,11 +15,13 @@
 # limitations under the License.
 
 import logging
 from collections import namedtuple
+from typing import Iterable, List
 
 import six
 
 from twisted.internet import defer
-from twisted.internet.defer import DeferredList
+from twisted.internet.defer import Deferred, DeferredList
+from twisted.python.failure import Failure
 
 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
@@ -29,6 +31,7 @@ from synapse.api.room_versions import (
     RoomVersion,
 )
 from synapse.crypto.event_signing import check_event_content_hash
+from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.http.servlet import assert_params_in_dict
@@ -56,7 +59,12 @@ class FederationBase(object):
     @defer.inlineCallbacks
     def _check_sigs_and_hash_and_fetch(
-        self, origin, pdus, room_version, outlier=False, include_none=False
+        self,
+        origin: str,
+        pdus: List[EventBase],
+        room_version: str,
+        outlier: bool = False,
+        include_none: bool = False,
     ):
         """Takes a list of PDUs and checks the signatures and hashs of each
         one. If a PDU fails its signature check then we check if we have it in
@@ -69,11 +77,11 @@ class FederationBase(object):
             a new list.
 
         Args:
-            origin (str)
-            pdu (list)
+            origin
+            pdu
room_version (str) room_version
outlier (bool): Whether the events are outliers or not outlier: Whether the events are outliers or not
include_none (str): Whether to include None in the returned list include_none: Whether to include None in the returned list
for events that have failed their checks for events that have failed their checks
Returns: Returns:
@ -82,7 +90,7 @@ class FederationBase(object):
deferreds = self._check_sigs_and_hashes(room_version, pdus) deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_check_result(pdu, deferred): def handle_check_result(pdu: EventBase, deferred: Deferred):
try: try:
res = yield make_deferred_yieldable(deferred) res = yield make_deferred_yieldable(deferred)
except SynapseError: except SynapseError:
@ -96,12 +104,16 @@ class FederationBase(object):
if not res and pdu.origin != origin: if not res and pdu.origin != origin:
try: try:
res = yield self.get_pdu( # This should not exist in the base implementation, until
destinations=[pdu.origin], # this is fixed, ignore it for typing. See issue #6997.
event_id=pdu.event_id, res = yield defer.ensureDeferred(
room_version=room_version, self.get_pdu( # type: ignore
outlier=outlier, destinations=[pdu.origin],
timeout=10000, event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
)
) )
except SynapseError: except SynapseError:
pass pass
@ -125,21 +137,23 @@ class FederationBase(object):
else: else:
return [p for p in valid_pdus if p] return [p for p in valid_pdus if p]
def _check_sigs_and_hash(self, room_version, pdu): def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
return make_deferred_yieldable( return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0] self._check_sigs_and_hashes(room_version, [pdu])[0]
) )
def _check_sigs_and_hashes(self, room_version, pdus): def _check_sigs_and_hashes(
self, room_version: str, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the """Checks that each of the received events is correctly signed by the
sending server. sending server.
Args: Args:
room_version (str): The room version of the PDUs room_version: The room version of the PDUs
pdus (list[FrozenEvent]): the events to be checked pdus: the events to be checked
Returns: Returns:
list[Deferred]: for each input event, a deferred which: For each input event, a deferred which:
* returns the original event if the checks pass * returns the original event if the checks pass
* returns a redacted version of the event (if the signature * returns a redacted version of the event (if the signature
matched but the hash did not) matched but the hash did not)
@ -150,7 +164,7 @@ class FederationBase(object):
ctx = LoggingContext.current_context() ctx = LoggingContext.current_context()
def callback(_, pdu): def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx): with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu): if not check_event_content_hash(pdu):
# let's try to distinguish between failures because the event was # let's try to distinguish between failures because the event was
@ -187,7 +201,7 @@ class FederationBase(object):
return pdu return pdu
def errback(failure, pdu): def errback(failure: Failure, pdu: EventBase):
failure.trap(SynapseError) failure.trap(SynapseError)
with PreserveLoggingContext(ctx): with PreserveLoggingContext(ctx):
logger.warning( logger.warning(
@ -213,16 +227,18 @@ class PduToCheckSig(
pass pass
def _check_sigs_on_pdus(keyring, room_version, pdus): def _check_sigs_on_pdus(
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed """Check that the given events are correctly signed
Args: Args:
keyring (synapse.crypto.Keyring): keyring object to do the checks keyring: keyring object to do the checks
room_version (str): the room version of the PDUs room_version: the room version of the PDUs
pdus (Collection[EventBase]): the events to be checked pdus: the events to be checked
Returns: Returns:
List[Deferred]: a Deferred for each event in pdus, which will either succeed if A Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not. the signatures are valid, or fail (with a SynapseError) if not.
""" """
@ -327,7 +343,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
def _flatten_deferred_list(deferreds): def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
"""Given a list of deferreds, either return the single deferred, """Given a list of deferreds, either return the single deferred,
combine into a DeferredList, or return an already resolved deferred. combine into a DeferredList, or return an already resolved deferred.
""" """
@ -339,7 +355,7 @@ def _flatten_deferred_list(deferreds):
return defer.succeed(None) return defer.succeed(None)
def _is_invite_via_3pid(event): def _is_invite_via_3pid(event: EventBase) -> bool:
return ( return (
event.type == EventTypes.Member event.type == EventTypes.Member
and event.membership == Membership.INVITE and event.membership == Membership.INVITE


@ -187,7 +187,7 @@ class FederationClient(FederationBase):
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str] self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
) -> List[EventBase]: ) -> Optional[List[EventBase]]:
"""Requests some more historic PDUs for the given room from the """Requests some more historic PDUs for the given room from the
given destination server. given destination server.
@ -199,9 +199,9 @@ class FederationClient(FederationBase):
""" """
logger.debug("backfill extrem=%s", extremities) logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start. # If there are no extremities then we've (probably) reached the start.
if not extremities: if not extremities:
return return None
transaction_data = await self.transport_layer.backfill( transaction_data = await self.transport_layer.backfill(
dest, room_id, extremities, limit dest, room_id, extremities, limit
@ -284,7 +284,7 @@ class FederationClient(FederationBase):
pdu_list = [ pdu_list = [
event_from_pdu_json(p, room_version, outlier=outlier) event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ] # type: List[EventBase]
if pdu_list and pdu_list[0]: if pdu_list and pdu_list[0]:
pdu = pdu_list[0] pdu = pdu_list[0]
@ -615,7 +615,7 @@ class FederationClient(FederationBase):
] ]
if auth_chain_create_events != [create_event.event_id]: if auth_chain_create_events != [create_event.event_id]:
raise InvalidResponseError( raise InvalidResponseError(
"Unexpected create event(s) in auth chain" "Unexpected create event(s) in auth chain: %s"
% (auth_chain_create_events,) % (auth_chain_create_events,)
) )


@ -17,6 +17,8 @@
import logging import logging
import time import time
import unicodedata import unicodedata
import urllib.parse
from typing import Any
import attr import attr
import bcrypt import bcrypt
@ -38,8 +40,11 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -108,6 +113,16 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock() self._clock = self.hs.get_clock()
# Load the SSO redirect confirmation page HTML template
self._sso_redirect_confirm_template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
)[0]
self._server_name = hs.config.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
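The cast to tuple above matters because `str.startswith` accepts a tuple of prefixes but raises `TypeError` on a list. A quick illustration with hypothetical whitelist entries:

```python
whitelist = ["https://client.example/", "https://riot.example/"]
whitelisted = tuple(whitelist)  # str.startswith needs a tuple, not a list

url = "https://client.example/?loginToken=abc"
print(url.startswith(whitelisted))  # a tuple of prefixes is accepted

try:
    url.startswith(whitelist)  # a list is not
except TypeError as e:
    print("TypeError:", e)
```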
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_user_via_ui_auth(self, requester, request_body, clientip): def validate_user_via_ui_auth(self, requester, request_body, clientip):
""" """
@ -927,6 +942,65 @@ class AuthHandler(BaseHandler):
else: else:
return defer.succeed(False) return defer.succeed(False)
def complete_sso_login(
self,
registered_user_id: str,
request: SynapseRequest,
client_redirect_url: str,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
"""
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
)
# Append the login token to the original redirect URL (i.e. with its query
# parameters kept intact) to build the URL to which the template needs to
# redirect the users once they have clicked on the confirmation link.
redirect_url = self.add_query_param_to_url(
client_redirect_url, "loginToken", login_token
)
# if the client is whitelisted, we can redirect straight to it
if client_redirect_url.startswith(self._whitelisted_sso_clients):
request.redirect(redirect_url)
finish_request(request)
return
# Otherwise, serve the redirect confirmation page.
# Remove the query parameters from the redirect URL to get a shorter version of
# it. This is only to display a human-readable URL in the template, but not the
# URL we redirect users to.
redirect_url_no_params = client_redirect_url.split("?")[0]
html = self._sso_redirect_confirm_template.render(
display_url=redirect_url_no_params,
redirect_url=redirect_url,
server_name=self._server_name,
).encode("utf-8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html),))
request.write(html)
finish_request(request)
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({param_name: param})
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
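The `add_query_param_to_url` helper above can be exercised on its own: it keeps any query parameters already present on the client's redirect URL and appends the new one. A self-contained copy with a hypothetical client URL:

```python
import urllib.parse

def add_query_param_to_url(url, param_name, param):
    # Parse the URL, merge the new parameter into the existing query
    # string, and reassemble.
    url_parts = list(urllib.parse.urlparse(url))
    query = dict(urllib.parse.parse_qsl(url_parts[4]))
    query.update({param_name: param})
    url_parts[4] = urllib.parse.urlencode(query)
    return urllib.parse.urlunparse(url_parts)

result = add_query_param_to_url(
    "https://client.example/?state=abc", "loginToken", "tok123"
)
print(result)  # https://client.example/?state=abc&loginToken=tok123
```

This is why the confirmation page can redirect "with its query parameters kept intact": the original `state=abc` survives alongside the appended `loginToken`.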
@attr.s @attr.s
class MacaroonGenerator(object): class MacaroonGenerator(object):


@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import logging import logging
import string import string
from typing import List from typing import Iterable, List, Optional
from twisted.internet import defer from twisted.internet import defer
@ -30,6 +28,7 @@ from synapse.api.errors import (
StoreError, StoreError,
SynapseError, SynapseError,
) )
from synapse.appservice import ApplicationService
from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
from ._base import BaseHandler from ._base import BaseHandler
@ -57,7 +56,13 @@ class DirectoryHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None, creator=None): def _create_association(
self,
room_alias: RoomAlias,
room_id: str,
servers: Optional[Iterable[str]] = None,
creator: Optional[str] = None,
):
# general association creation for both human users and app services # general association creation for both human users and app services
for wchar in string.whitespace: for wchar in string.whitespace:
@ -83,17 +88,21 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_association( def create_association(
self, requester, room_alias, room_id, servers=None, check_membership=True, self,
requester: Requester,
room_alias: RoomAlias,
room_id: str,
servers: Optional[List[str]] = None,
check_membership: bool = True,
): ):
"""Attempt to create a new alias """Attempt to create a new alias
Args: Args:
requester (Requester) requester
room_alias (RoomAlias) room_alias
room_id (str) room_id
servers (list[str]|None): List of servers that others servers servers: Iterable of servers that others servers should try and join via
should try and join via check_membership: Whether to check if the user is in the room
check_membership (bool): Whether to check if the user is in the room
before the alias can be set (if the server's config requires it). before the alias can be set (if the server's config requires it).
Returns: Returns:
@ -147,15 +156,15 @@ class DirectoryHandler(BaseHandler):
yield self._create_association(room_alias, room_id, servers, creator=user_id) yield self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_association(self, requester, room_alias): def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory """Remove an alias from the directory
(this is only meant for human users; AS users should call (this is only meant for human users; AS users should call
delete_appservice_association) delete_appservice_association)
Args: Args:
requester (Requester): requester
room_alias (RoomAlias): room_alias
Returns: Returns:
Deferred[unicode]: room id that the alias used to point to Deferred[unicode]: room id that the alias used to point to
@ -191,16 +200,16 @@ class DirectoryHandler(BaseHandler):
room_id = yield self._delete_association(room_alias) room_id = yield self._delete_association(room_alias)
try: try:
yield self._update_canonical_alias( yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
requester, requester.user.to_string(), room_id, room_alias
)
except AuthError as e: except AuthError as e:
logger.info("Failed to update alias events: %s", e) logger.info("Failed to update alias events: %s", e)
return room_id return room_id
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_appservice_association(self, service, room_alias): def delete_appservice_association(
self, service: ApplicationService, room_alias: RoomAlias
):
if not service.is_interested_in_alias(room_alias.to_string()): if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError( raise SynapseError(
400, 400,
@ -210,7 +219,7 @@ class DirectoryHandler(BaseHandler):
yield self._delete_association(room_alias) yield self._delete_association(room_alias)
@defer.inlineCallbacks @defer.inlineCallbacks
def _delete_association(self, room_alias): def _delete_association(self, room_alias: RoomAlias):
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
@ -219,7 +228,7 @@ class DirectoryHandler(BaseHandler):
return room_id return room_id
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association(self, room_alias): def get_association(self, room_alias: RoomAlias):
room_id = None room_id = None
if self.hs.is_mine(room_alias): if self.hs.is_mine(room_alias):
result = yield self.get_association_from_room_alias(room_alias) result = yield self.get_association_from_room_alias(room_alias)
@ -284,7 +293,9 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_canonical_alias(self, requester, user_id, room_id, room_alias): def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
):
""" """
Send an updated canonical alias event if the removed alias was set as Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field. the canonical alias or listed in the alt_aliases field.
@ -307,15 +318,17 @@ class DirectoryHandler(BaseHandler):
send_update = True send_update = True
content.pop("alias", "") content.pop("alias", "")
# Filter alt_aliases for the removed alias. # Filter the alt_aliases property for the removed alias. Note that the
alt_aliases = content.pop("alt_aliases", None) # value is not modified if alt_aliases is of an unexpected form.
# If the aliases are not a list (or not found) do not attempt to modify alt_aliases = content.get("alt_aliases")
# the list. if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases:
if isinstance(alt_aliases, collections.Sequence):
send_update = True send_update = True
alt_aliases = [alias for alias in alt_aliases if alias != alias_str] alt_aliases = [alias for alias in alt_aliases if alias != alias_str]
if alt_aliases: if alt_aliases:
content["alt_aliases"] = alt_aliases content["alt_aliases"] = alt_aliases
else:
del content["alt_aliases"]
if send_update: if send_update:
yield self.event_creation_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
@ -331,7 +344,7 @@ class DirectoryHandler(BaseHandler):
) )
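The rewritten filtering logic above only touches `alt_aliases` when it is a well-formed list containing the removed alias, and deletes the property outright when the list would become empty. A standalone sketch of that behaviour:

```python
def filter_alt_aliases(content, alias_str):
    """Remove alias_str from content["alt_aliases"]. Returns True if an
    update event needs to be sent; malformed values are left untouched."""
    alt_aliases = content.get("alt_aliases")
    if not (isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases):
        return False
    remaining = [alias for alias in alt_aliases if alias != alias_str]
    if remaining:
        content["alt_aliases"] = remaining
    else:
        del content["alt_aliases"]  # don't leave an empty list behind
    return True

content = {"alias": "#main:example.org", "alt_aliases": ["#alt:example.org"]}
print(filter_alt_aliases(content, "#alt:example.org"), content)
```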
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias): def get_association_from_room_alias(self, room_alias: RoomAlias):
result = yield self.store.get_association_from_room_alias(room_alias) result = yield self.store.get_association_from_room_alias(room_alias)
if not result: if not result:
# Query AS to see if it exists # Query AS to see if it exists
@ -339,7 +352,7 @@ class DirectoryHandler(BaseHandler):
result = yield as_handler.query_room_alias_exists(room_alias) result = yield as_handler.query_room_alias_exists(room_alias)
return result return result
def can_modify_alias(self, alias, user_id=None): def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
# Any application service "interested" in an alias they are regexing on # Any application service "interested" in an alias they are regexing on
# can modify the alias. # can modify the alias.
# Users can only modify the alias if ALL the interested services have # Users can only modify the alias if ALL the interested services have
@ -360,22 +373,42 @@ class DirectoryHandler(BaseHandler):
return defer.succeed(True) return defer.succeed(True)
@defer.inlineCallbacks @defer.inlineCallbacks
def _user_can_delete_alias(self, alias, user_id): def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
One of the following must be true:
1. The user created the alias.
2. The user is a server administrator.
3. The user has a power-level sufficient to send a canonical alias event
for the current room.
"""
creator = yield self.store.get_room_alias_creator(alias.to_string()) creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id: if creator is not None and creator == user_id:
return True return True
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) # Resolve the alias to the corresponding room.
return is_admin room_mapping = yield self.get_association(alias)
room_id = room_mapping["room_id"]
if not room_id:
return False
res = yield self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id)
)
return res
@defer.inlineCallbacks @defer.inlineCallbacks
def edit_published_room_list(self, requester, room_id, visibility): def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str
):
"""Edit the entry of the room in the published room list. """Edit the entry of the room in the published room list.
requester requester
room_id (str) room_id
visibility (str): "public" or "private" visibility: "public" or "private"
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -400,7 +433,15 @@ class DirectoryHandler(BaseHandler):
if room is None: if room is None:
raise SynapseError(400, "Unknown room") raise SynapseError(400, "Unknown room")
yield self.auth.check_can_change_room_list(room_id, requester.user) can_change_room_list = yield self.auth.check_can_change_room_list(
room_id, requester.user
)
if not can_change_room_list:
raise AuthError(
403,
"This server requires you to be a moderator in the room to"
" edit its room list entry",
)
making_public = visibility == "public" making_public = visibility == "public"
if making_public: if making_public:
@ -421,16 +462,16 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def edit_published_appservice_room_list( def edit_published_appservice_room_list(
self, appservice_id, network_id, room_id, visibility self, appservice_id: str, network_id: str, room_id: str, visibility: str
): ):
"""Add or remove a room from the appservice/network specific public """Add or remove a room from the appservice/network specific public
room list. room list.
Args: Args:
appservice_id (str): ID of the appservice that owns the list appservice_id: ID of the appservice that owns the list
network_id (str): The ID of the network the list is associated with network_id: The ID of the network the list is associated with
room_id (str) room_id
visibility (str): either "public" or "private" visibility: either "public" or "private"
""" """
if visibility not in ["public", "private"]: if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting") raise SynapseError(400, "Invalid visibility setting")


@ -207,6 +207,13 @@ class E2eRoomKeysHandler(object):
changed = False # if anything has changed, we need to update the etag changed = False # if anything has changed, we need to update the etag
for room_id, room in iteritems(room_keys["rooms"]): for room_id, room in iteritems(room_keys["rooms"]):
for session_id, room_key in iteritems(room["sessions"]): for session_id, room_key in iteritems(room["sessions"]):
if not isinstance(room_key["is_verified"], bool):
msg = (
"is_verified must be a boolean in keys for session %s in"
"room %s" % (session_id, room_id)
)
raise SynapseError(400, msg, Codes.INVALID_PARAM)
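The check above rejects an upload whose `is_verified` field is not a boolean before any keys are processed. A standalone sketch of the same validation over the `room_keys` payload shape, with `ValueError` standing in for `SynapseError`:

```python
def validate_room_keys(room_keys):
    """Raise if any session's is_verified field is not a bool."""
    for room_id, room in room_keys["rooms"].items():
        for session_id, room_key in room["sessions"].items():
            if not isinstance(room_key["is_verified"], bool):
                raise ValueError(
                    "is_verified must be a boolean in keys for session %s in "
                    "room %s" % (session_id, room_id)
                )

good = {"rooms": {"!r:x": {"sessions": {"s1": {"is_verified": True}}}}}
validate_room_keys(good)  # passes silently

bad = {"rooms": {"!r:x": {"sessions": {"s1": {"is_verified": 1}}}}}
try:
    validate_room_keys(bad)  # 1 is an int, not a bool
except ValueError as e:
    print(e)
```

Note that `isinstance(1, bool)` is `False` even though `bool` is a subclass of `int`, so a JSON `1` is rejected while `true`/`false` pass.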
log_kv( log_kv(
{ {
"message": "Trying to upload room key", "message": "Trying to upload room key",


@ -888,19 +888,60 @@ class EventCreationHandler(object):
yield self.base_handler.maybe_kick_guest_users(event, context) yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Validate a newly added alias or newly added alt_aliases.
original_alias = None
original_alt_aliases = set()
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
original_event = yield self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
original_alt_aliases = original_event.content.get("alt_aliases", [])
# Check the alias is currently valid (if it has changed).
room_alias_str = event.content.get("alias", None) room_alias_str = event.content.get("alias", None)
if room_alias_str: directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
room_alias = RoomAlias.from_string(room_alias_str) room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id: if mapping["room_id"] != event.room_id:
raise SynapseError( raise SynapseError(
400, 400,
"Room alias %s does not point to the room" % (room_alias_str,), "Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
) )
# Check that alt_aliases is the proper form.
alt_aliases = event.content.get("alt_aliases", [])
if not isinstance(alt_aliases, (list, tuple)):
raise SynapseError(
400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM
)
# If the old version of alt_aliases is of an unknown form,
# completely replace it.
if not isinstance(original_alt_aliases, (list, tuple)):
original_alt_aliases = []
# Check that each alias is currently valid.
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
room_alias = RoomAlias.from_string(alias_str)
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room"
% (room_alias_str,),
Codes.BAD_ALIAS,
)
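The validation above only re-checks aliases that are newly added relative to the previous canonical-alias event, computed as a set difference. Sketched standalone, again with `ValueError` in place of `SynapseError`:

```python
def aliases_needing_validation(original_alt_aliases, alt_aliases):
    """Return only the aliases added since the previous event."""
    if not isinstance(alt_aliases, (list, tuple)):
        raise ValueError("The alt_aliases property must be a list.")
    # If the old version is of an unknown form, treat it as empty and
    # re-validate everything.
    if not isinstance(original_alt_aliases, (list, tuple)):
        original_alt_aliases = []
    return set(alt_aliases) - set(original_alt_aliases)

print(aliases_needing_validation(["#a:x"], ["#a:x", "#b:x"]))  # {'#b:x'}
```

Aliases carried over unchanged are skipped, so a room whose old aliases have since been deleted can still update unrelated fields of the event.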
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member: if event.type == EventTypes.Member:


@ -25,7 +25,6 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import ( from synapse.types import (
UserID, UserID,
map_username_to_mxid_localpart, map_username_to_mxid_localpart,
@ -48,7 +47,7 @@ class Saml2SessionData:
class SamlHandler: class SamlHandler:
def __init__(self, hs): def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._sso_auth_handler = SSOAuthHandler(hs) self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock() self._clock = hs.get_clock()
@ -116,7 +115,7 @@ class SamlHandler:
self.expire_sessions() self.expire_sessions()
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state) user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state) self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
try: try:


@ -27,10 +27,15 @@ import inspect
import logging import logging
import threading import threading
import types import types
from typing import Any, List from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from typing_extensions import Literal
from twisted.internet import defer, threads from twisted.internet import defer, threads
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
@ -91,7 +96,7 @@ class ContextResourceUsage(object):
"evt_db_fetch_count", "evt_db_fetch_count",
] ]
def __init__(self, copy_from=None): def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
"""Create a new ContextResourceUsage """Create a new ContextResourceUsage
Args: Args:
@ -101,27 +106,28 @@ class ContextResourceUsage(object):
if copy_from is None: if copy_from is None:
self.reset() self.reset()
else: else:
self.ru_utime = copy_from.ru_utime # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
self.ru_stime = copy_from.ru_stime self.ru_utime = copy_from.ru_utime # type: float
self.db_txn_count = copy_from.db_txn_count self.ru_stime = copy_from.ru_stime # type: float
self.db_txn_count = copy_from.db_txn_count # type: int
self.db_txn_duration_sec = copy_from.db_txn_duration_sec self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
self.db_sched_duration_sec = copy_from.db_sched_duration_sec self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
self.evt_db_fetch_count = copy_from.evt_db_fetch_count self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int
def copy(self): def copy(self) -> "ContextResourceUsage":
return ContextResourceUsage(copy_from=self) return ContextResourceUsage(copy_from=self)
def reset(self): def reset(self) -> None:
self.ru_stime = 0.0 self.ru_stime = 0.0
self.ru_utime = 0.0 self.ru_utime = 0.0
self.db_txn_count = 0 self.db_txn_count = 0
self.db_txn_duration_sec = 0 self.db_txn_duration_sec = 0.0
self.db_sched_duration_sec = 0 self.db_sched_duration_sec = 0.0
self.evt_db_fetch_count = 0 self.evt_db_fetch_count = 0
def __repr__(self): def __repr__(self) -> str:
return ( return (
"<ContextResourceUsage ru_stime='%r', ru_utime='%r', " "<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', " "db_txn_count='%r', db_txn_duration_sec='%r', "
@ -135,7 +141,7 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count, self.evt_db_fetch_count,
) )
def __iadd__(self, other): def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
"""Add another ContextResourceUsage's stats to this one's. """Add another ContextResourceUsage's stats to this one's.
Args: Args:
@ -149,7 +155,7 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count += other.evt_db_fetch_count self.evt_db_fetch_count += other.evt_db_fetch_count
return self return self
def __isub__(self, other): def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
self.ru_utime -= other.ru_utime self.ru_utime -= other.ru_utime
self.ru_stime -= other.ru_stime self.ru_stime -= other.ru_stime
self.db_txn_count -= other.db_txn_count self.db_txn_count -= other.db_txn_count
@ -158,17 +164,20 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count -= other.evt_db_fetch_count self.evt_db_fetch_count -= other.evt_db_fetch_count
return self return self
def __add__(self, other): def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self) res = ContextResourceUsage(copy_from=self)
res += other res += other
return res return res
def __sub__(self, other): def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self) res = ContextResourceUsage(copy_from=self)
res -= other res -= other
return res return res
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
class LoggingContext(object): class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a """Additional context for log formatting. Contexts are scoped within a
"with" block. "with" block.
@ -201,7 +210,14 @@ class LoggingContext(object):
class Sentinel(object): class Sentinel(object):
"""Sentinel to represent the root context""" """Sentinel to represent the root context"""
__slots__ = [] # type: List[Any] __slots__ = ["previous_context", "alive", "request", "scope"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None
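The change above swaps an empty `__slots__` for an explicit attribute list, so the sentinel exposes exactly the `LoggingContext`-compatible surface and nothing else. A minimal sketch of the behaviour this buys (simplified, not the real class):

```python
class Sentinel:
    """Sketch: with __slots__, the sentinel carries only the attributes a
    LoggingContext consumer may read, and assigning anything else raises
    AttributeError instead of silently growing a __dict__."""

    __slots__ = ["previous_context", "alive", "request", "scope"]

    def __init__(self) -> None:
        # Minimal set for compatibility with LoggingContext
        self.previous_context = None
        self.alive = None
        self.request = None
        self.scope = None
```

This makes type-checked call sites that accept `LoggingContextOrSentinel` safe to read these fields on either variant.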
def __str__(self): def __str__(self):
return "sentinel" return "sentinel"
@ -235,7 +251,7 @@ class LoggingContext(object):
sentinel = Sentinel() sentinel = Sentinel()
def __init__(self, name=None, parent_context=None, request=None): def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = LoggingContext.current_context() self.previous_context = LoggingContext.current_context()
self.name = name self.name = name
@ -250,7 +266,7 @@ class LoggingContext(object):
self.request = None self.request = None
self.tag = "" self.tag = ""
self.alive = True self.alive = True
self.scope = None self.scope = None # type: Optional[_LogContextScope]
self.parent_context = parent_context self.parent_context = parent_context
@ -261,13 +277,13 @@ class LoggingContext(object):
# the request param overrides the request from the parent context # the request param overrides the request from the parent context
self.request = request self.request = request
def __str__(self): def __str__(self) -> str:
if self.request: if self.request:
return str(self.request) return str(self.request)
return "%s@%x" % (self.name, id(self)) return "%s@%x" % (self.name, id(self))
@classmethod @classmethod
def current_context(cls): def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage """Get the current logging context from thread local storage
Returns: Returns:
@ -276,7 +292,9 @@ class LoggingContext(object):
return getattr(cls.thread_local, "current_context", cls.sentinel) return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod @classmethod
def set_current_context(cls, context): def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage """Set the current logging context in thread local storage
Args: Args:
context(LoggingContext): The context to activate. context(LoggingContext): The context to activate.
@ -291,7 +309,7 @@ class LoggingContext(object):
context.start() context.start()
return current return current
def __enter__(self): def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage""" """Enters this logging context into thread local storage"""
old_context = self.set_current_context(self) old_context = self.set_current_context(self)
if self.previous_context != old_context: if self.previous_context != old_context:
@ -304,7 +322,7 @@ class LoggingContext(object):
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback) -> None:
"""Restore the logging context in thread local storage to the state it """Restore the logging context in thread local storage to the state it
was before this context was entered. was before this context was entered.
Returns: Returns:
@ -318,7 +336,6 @@ class LoggingContext(object):
logger.warning( logger.warning(
"Expected logging context %s but found %s", self, current "Expected logging context %s but found %s", self, current
) )
self.previous_context = None
self.alive = False self.alive = False
# if we have a parent, pass our CPU usage stats on # if we have a parent, pass our CPU usage stats on
@ -330,7 +347,7 @@ class LoggingContext(object):
# reset them in case we get entered again # reset them in case we get entered again
self._resource_usage.reset() self._resource_usage.reset()
def copy_to(self, record): def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or """Copy logging fields from this context to a log record or
another LoggingContext another LoggingContext
""" """
@ -341,14 +358,14 @@ class LoggingContext(object):
# we also track the current scope: # we also track the current scope:
record.scope = self.scope record.scope = self.scope
def copy_to_twisted_log_entry(self, record): def copy_to_twisted_log_entry(self, record) -> None:
""" """
Copy logging fields from this context to a Twisted log record. Copy logging fields from this context to a Twisted log record.
""" """
record["request"] = self.request record["request"] = self.request
record["scope"] = self.scope record["scope"] = self.scope
def start(self): def start(self) -> None:
if get_thread_id() != self.main_thread: if get_thread_id() != self.main_thread:
logger.warning("Started logcontext %s on different thread", self) logger.warning("Started logcontext %s on different thread", self)
return return
@ -358,7 +375,7 @@ class LoggingContext(object):
if not self.usage_start: if not self.usage_start:
self.usage_start = get_thread_resource_usage() self.usage_start = get_thread_resource_usage()
def stop(self): def stop(self) -> None:
if get_thread_id() != self.main_thread: if get_thread_id() != self.main_thread:
logger.warning("Stopped logcontext %s on different thread", self) logger.warning("Stopped logcontext %s on different thread", self)
return return
@ -378,7 +395,7 @@ class LoggingContext(object):
self.usage_start = None self.usage_start = None
def get_resource_usage(self): def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far. """Get resources used by this logcontext so far.
Returns: Returns:
@ -398,11 +415,13 @@ class LoggingContext(object):
return res return res
def _get_cputime(self): def _get_cputime(self) -> Tuple[float, float]:
"""Get the cpu usage time so far """Get the cpu usage time so far
Returns: Tuple[float, float]: seconds in user mode, seconds in system mode Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
""" """
assert self.usage_start is not None
current = get_thread_resource_usage() current = get_thread_resource_usage()
# Indicate to mypy that we know that self.usage_start is not None.
@ -430,13 +449,13 @@ class LoggingContext(object):
return utime_delta, stime_delta return utime_delta, stime_delta
def add_database_transaction(self, duration_sec): def add_database_transaction(self, duration_sec: float) -> None:
if duration_sec < 0: if duration_sec < 0:
raise ValueError("DB txn time can only be non-negative") raise ValueError("DB txn time can only be non-negative")
self._resource_usage.db_txn_count += 1 self._resource_usage.db_txn_count += 1
self._resource_usage.db_txn_duration_sec += duration_sec self._resource_usage.db_txn_duration_sec += duration_sec
def add_database_scheduled(self, sched_sec): def add_database_scheduled(self, sched_sec: float) -> None:
"""Record a use of the database pool """Record a use of the database pool
Args: Args:
@ -447,7 +466,7 @@ class LoggingContext(object):
raise ValueError("DB scheduling time can only be non-negative") raise ValueError("DB scheduling time can only be non-negative")
self._resource_usage.db_sched_duration_sec += sched_sec self._resource_usage.db_sched_duration_sec += sched_sec
def record_event_fetch(self, event_count): def record_event_fetch(self, event_count: int) -> None:
"""Record a number of events being fetched from the db """Record a number of events being fetched from the db
Args: Args:
@ -464,10 +483,10 @@ class LoggingContextFilter(logging.Filter):
missing fields missing fields
""" """
def __init__(self, **defaults): def __init__(self, **defaults) -> None:
self.defaults = defaults self.defaults = defaults
def filter(self, record): def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record. """Add each fields from the logging contexts to the record.
Returns: Returns:
True to include the record in the log output. True to include the record in the log output.
@ -492,12 +511,13 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"] __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=None): def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
if new_context is None: if new_context is None:
new_context = LoggingContext.sentinel self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
self.new_context = new_context else:
self.new_context = new_context
def __enter__(self): def __enter__(self) -> None:
"""Captures the current logging context""" """Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(self.new_context) self.current_context = LoggingContext.set_current_context(self.new_context)
@ -506,7 +526,7 @@ class PreserveLoggingContext(object):
if not self.current_context.alive: if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context) logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context""" """Restores the current logging context"""
context = LoggingContext.set_current_context(self.current_context) context = LoggingContext.set_current_context(self.current_context)
@ -525,7 +545,9 @@ class PreserveLoggingContext(object):
logger.debug("Restoring dead context: %s", self.current_context) logger.debug("Restoring dead context: %s", self.current_context)
def nested_logging_context(suffix, parent_context=None): def nested_logging_context(
suffix: str, parent_context: Optional[LoggingContext] = None
) -> LoggingContext:
"""Creates a new logging context as a child of another. """Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's The nested logging context will have a 'request' made up of the parent context's
@ -546,10 +568,12 @@ def nested_logging_context(suffix, parent_context=None):
Returns: Returns:
LoggingContext: new logging context. LoggingContext: new logging context.
""" """
if parent_context is None: if parent_context is not None:
parent_context = LoggingContext.current_context() context = parent_context # type: LoggingContextOrSentinel
else:
context = LoggingContext.current_context()
return LoggingContext( return LoggingContext(
parent_context=parent_context, request=parent_context.request + "-" + suffix parent_context=context, request=str(context.request) + "-" + suffix
) )
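The `nested_logging_context` fix above does two things: fall back to the current context only when no parent is supplied, and stringify the parent's `request` (which may not be a `str`) before appending the suffix. A hypothetical pure-function sketch of just the naming rule, not Synapse's actual API:

```python
from typing import Optional


def child_request_name(
    suffix: str,
    parent_request: Optional[object] = None,
    current_request: object = "sentinel",
) -> str:
    """Sketch of how the nested context derives its 'request' name:
    use the explicit parent if given, otherwise whatever context is
    current, then join with '-'. str() guards against non-str requests."""
    base = parent_request if parent_request is not None else current_request
    return str(base) + "-" + suffix
```

So a context for request `GET-42` spawning a `persist_events` child is logged as `GET-42-persist_events`, keeping log lines traceable to the originating request.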
@ -654,7 +678,10 @@ def make_deferred_yieldable(deferred):
return deferred return deferred
def _set_context_cb(result, context): ResultT = TypeVar("ResultT")
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context""" """A callback function which just sets the logging context"""
LoggingContext.set_current_context(context) LoggingContext.set_current_context(context)
return result return result


@ -17,6 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import UserID from synapse.types import UserID
@ -211,3 +212,21 @@ class ModuleApi(object):
Deferred[object]: result of func Deferred[object]: result of func
""" """
return self._store.db.runInteraction(desc, func, *args, **kwargs) return self._store.db.runInteraction(desc, func, *args, **kwargs)
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
):
"""Complete an SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirecting them to that
URL with a token directly if the URL matches one of the whitelisted clients.
Args:
registered_user_id: The MXID that has been registered as a previous step of
of this SSO login.
request: The request to respond to.
client_redirect_url: The URL to which to offer to redirect the user (or to
redirect them directly if whitelisted).
"""
self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url,
)
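The "whitelisted clients" test mentioned in the docstring above is, per the SSO handler changes later in this diff, implemented by casting the configured list to a tuple once and handing it to `str.startswith`, which accepts a tuple of prefixes. A sketch with an illustrative helper name (not Synapse's actual function):

```python
from typing import Iterable


def is_whitelisted(redirect_url: str, whitelist: Iterable[str]) -> bool:
    """Sketch of the whitelist check behind complete_sso_login:
    str.startswith accepts a tuple of prefixes, so the config list is
    converted to a tuple and matched in one call."""
    prefixes = tuple(whitelist)
    # An empty whitelist matches nothing rather than everything.
    return bool(prefixes) and redirect_url.startswith(prefixes)
```

Only URLs matching a whitelisted prefix skip the confirmation page and receive the login token directly.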


@ -555,10 +555,12 @@ class Mailer(object):
else: else:
# If the reason room doesn't have a name, say who the messages # If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room" # are from explicitly to avoid, "messages in the Bob room"
room_id = reason["room_id"]
sender_ids = list( sender_ids = list(
{ {
notif_events[n["event_id"]].sender notif_events[n["event_id"]].sender
for n in notifs_by_room[reason["room_id"]] for n in notifs_by_room[room_id]
} }
) )


@ -18,7 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import event_type_from_format_version from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
@ -38,6 +38,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
{ {
"events": [{ "events": [{
"event": { .. serialized event .. }, "event": { .. serialized event .. },
"room_version": .., // "1", "2", "3", etc: the version of the room
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. }, "internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field "rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. }, "context": { .. serialized event context .. },
@ -73,6 +76,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_payloads.append( event_payloads.append(
{ {
"event": event.get_pdu_json(), "event": event.get_pdu_json(),
"room_version": event.room_version.identifier,
"event_format_version": event.format_version, "event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(), "internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason, "rejected_reason": event.rejected_reason,
@ -95,12 +99,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts = [] event_and_contexts = []
for event_payload in event_payloads: for event_payload in event_payloads:
event_dict = event_payload["event"] event_dict = event_payload["event"]
format_ver = event_payload["event_format_version"] room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
internal_metadata = event_payload["internal_metadata"] internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"] rejected_reason = event_payload["rejected_reason"]
EventType = event_type_from_format_version(format_ver) event = make_event_from_dict(
event = EventType(event_dict, internal_metadata, rejected_reason) event_dict, room_ver, internal_metadata, rejected_reason
)
context = EventContext.deserialize( context = EventContext.deserialize(
self.storage, event_payload["context"] self.storage, event_payload["context"]


@ -17,7 +17,8 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.events import event_type_from_format_version from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
@ -37,6 +38,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
{ {
"event": { .. serialized event .. }, "event": { .. serialized event .. },
"room_version": .., // "1", "2", "3", etc: the version of the room
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. }, "internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field "rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. }, "context": { .. serialized event context .. },
@ -77,6 +81,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
payload = { payload = {
"event": event.get_pdu_json(), "event": event.get_pdu_json(),
"room_version": event.room_version.identifier,
"event_format_version": event.format_version, "event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(), "internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason, "rejected_reason": event.rejected_reason,
@ -93,12 +98,13 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
event_dict = content["event"] event_dict = content["event"]
format_ver = content["event_format_version"] room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
internal_metadata = content["internal_metadata"] internal_metadata = content["internal_metadata"]
rejected_reason = content["rejected_reason"] rejected_reason = content["rejected_reason"]
EventType = event_type_from_format_version(format_ver) event = make_event_from_dict(
event = EventType(event_dict, internal_metadata, rejected_reason) event_dict, room_ver, internal_metadata, rejected_reason
)
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
context = EventContext.deserialize(self.storage, content["context"]) context = EventContext.deserialize(self.storage, content["context"])


@ -0,0 +1,14 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>SSO redirect confirmation</title>
</head>
<body>
<p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
<p>If you don't recognise this address, you should ignore this and close this tab.</p>
<p>
<a href="{{ redirect_url | e }}">I trust this address</a>
</p>
</body>
</html>
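The `| e` filters in the template above are what make this page safe: both URLs are attacker-influenced, so Jinja2 HTML-escapes them before interpolation. A stdlib stand-in that mirrors only the escaping behaviour (the real page is rendered by Jinja2, and the markup here is abbreviated):

```python
import html


def render_confirmation(display_url: str, server_name: str, redirect_url: str) -> str:
    """Rough sketch of the template's escaping guarantee: user-controlled
    URLs pass through html.escape (Jinja2's `e` filter equivalent) so they
    cannot inject markup into the confirmation page."""
    return (
        f"<p>The application at <b>{html.escape(display_url)}</b> is requesting "
        f"full access to your <b>{html.escape(server_name)}</b> Matrix account.</p>\n"
        f'<a href="{html.escape(redirect_url)}">I trust this address</a>'
    )
```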


@ -211,9 +211,7 @@ class UserRestServletV2(RestServlet):
if target_user == auth_user and not set_admin_to: if target_user == auth_user and not set_admin_to:
raise SynapseError(400, "You may not demote yourself.") raise SynapseError(400, "You may not demote yourself.")
await self.admin_handler.set_user_server_admin( await self.store.set_server_admin(target_user, set_admin_to)
target_user, set_admin_to
)
if "password" in body: if "password" in body:
if ( if (
@ -651,6 +649,6 @@ class UserAdminServlet(RestServlet):
if target_user == auth_user and not set_admin_to: if target_user == auth_user and not set_admin_to:
raise SynapseError(400, "You may not demote yourself.") raise SynapseError(400, "You may not demote yourself.")
await self.store.set_user_server_admin(target_user, set_admin_to) await self.store.set_server_admin(target_user, set_admin_to)
return 200, {} return 200, {}


@ -28,7 +28,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest from synapse.push.mailer import load_jinja2_templates
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
@ -548,6 +548,16 @@ class SSOAuthHandler(object):
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator() self._macaroon_gen = hs.get_macaroon_generator()
# Load the redirect page HTML template
self._template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
)[0]
self._server_name = hs.config.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
async def on_successful_auth( async def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None self, username, request, client_redirect_url, user_display_name=None
): ):
@ -580,36 +590,9 @@ class SSOAuthHandler(object):
localpart=localpart, default_display_name=user_display_name localpart=localpart, default_display_name=user_display_name
) )
self.complete_sso_login(registered_user_id, request, client_redirect_url) self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id:
request:
client_redirect_url:
"""
login_token = self._macaroon_gen.generate_short_term_login_token(
registered_user_id
) )
redirect_url = self._add_login_token_to_redirect_url(
client_redirect_url, login_token
)
# Load page
request.redirect(redirect_url)
finish_request(request)
@staticmethod
def _add_login_token_to_redirect_url(url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
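The helper being removed above (it moves into the shared auth handler) merges a `loginToken` parameter into whatever query string the client supplied. A self-contained, runnable copy of that technique:

```python
import urllib.parse


def add_login_token_to_redirect_url(url: str, token: str) -> str:
    """Standalone version of the helper shown in the diff: parse the URL,
    merge loginToken into its existing query parameters, and reassemble."""
    url_parts = list(urllib.parse.urlparse(url))
    # url_parts[4] is the query component
    query = dict(urllib.parse.parse_qsl(url_parts[4]))
    query.update({"loginToken": token})
    url_parts[4] = urllib.parse.urlencode(query)
    return urllib.parse.urlunparse(url_parts)
```

Merging rather than appending preserves client state such as an OAuth-style `state` parameter already present on the redirect URL.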
def register_servlets(hs, http_server): def register_servlets(hs, http_server):


@ -18,8 +18,6 @@ from typing import Dict, Set
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import ( from synapse.http.server import (
@ -125,8 +123,7 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
@defer.inlineCallbacks async def query_keys(self, request, query, query_remote_on_cache_miss=False):
def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query) logger.info("Handling query for keys %r", query)
store_queries = [] store_queries = []
@ -143,7 +140,7 @@ class RemoteKey(DirectServeResource):
for key_id in key_ids: for key_id in key_ids:
store_queries.append((server_name, key_id, None)) store_queries.append((server_name, key_id, None))
cached = yield self.store.get_server_keys_json(store_queries) cached = await self.store.get_server_keys_json(store_queries)
json_results = set() json_results = set()
@ -215,8 +212,8 @@ class RemoteKey(DirectServeResource):
json_results.add(bytes(result["key_json"])) json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss: if cache_misses and query_remote_on_cache_miss:
yield self.fetcher.get_keys(cache_misses) await self.fetcher.get_keys(cache_misses)
yield self.query_keys(request, query, query_remote_on_cache_miss=False) await self.query_keys(request, query, query_remote_on_cache_miss=False)
else: else:
signed_keys = [] signed_keys = []
for key_json in json_results: for key_json in json_results:


@ -608,6 +608,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end return range_end
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.stream_ordering > ?"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
class EventPushActionsStore(EventPushActionsWorkerStore): class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index" EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@ -735,23 +752,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions return push_actions
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.stream_ordering > ?"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
@defer.inlineCallbacks @defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self): def get_latest_push_action_stream_ordering(self):
def f(txn): def f(txn):


@ -1168,7 +1168,11 @@ class EventsStore(
and original_event.internal_metadata.is_redacted() and original_event.internal_metadata.is_redacted()
): ):
# Redaction was allowed # Redaction was allowed
pruned_json = encode_json(prune_event_dict(original_event.get_dict())) pruned_json = encode_json(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
)
else: else:
# Redaction wasn't allowed # Redaction wasn't allowed
pruned_json = None pruned_json = None
@ -1929,7 +1933,9 @@ class EventsStore(
return return
# Prune the event's dict then convert it to JSON. # Prune the event's dict then convert it to JSON.
pruned_json = encode_json(prune_event_dict(event.get_dict())) pruned_json = encode_json(
prune_event_dict(event.room_version, event.get_dict())
)
# Update the event_json table to replace the event's JSON with the pruned # Update the event_json table to replace the event's JSON with the pruned
# JSON. # JSON.


@ -28,9 +28,12 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError from synapse.api.errors import NotFoundError
from synapse.api.room_versions import EventFormatVersions from synapse.api.room_versions import (
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 KNOWN_ROOM_VERSIONS,
from synapse.events.snapshot import EventContext # noqa: F401 EventFormatVersions,
RoomVersions,
)
from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore):
# of a event format version, so it must be a V1 event. # of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1 format_version = EventFormatVersions.V1
original_ev = event_type_from_format_version(format_version)( room_version_id = row["room_version_id"]
if not room_version_id:
# this should only happen for out-of-band membership events
if not internal_metadata.get("out_of_band_membership"):
logger.warning(
"Room %s for event %s is unknown", d["room_id"], event_id
)
continue
# take a wild stab at the room version based on the event format
if format_version == EventFormatVersions.V1:
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
room_version = RoomVersions.V3
else:
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
logger.error(
"Event %s in room %s has unknown room version %s",
event_id,
d["room_id"],
room_version_id,
)
continue
if room_version.event_format != format_version:
logger.error(
"Event %s in room %s with version %s has wrong format: "
"expected %s, was %s",
event_id,
d["room_id"],
room_version_id,
room_version.event_format,
format_version,
)
continue
original_ev = make_event_from_dict(
event_dict=d, event_dict=d,
room_version=room_version,
internal_metadata_dict=internal_metadata, internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason, rejected_reason=rejected_reason,
) )
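The "wild stab" fallback above, for out-of-band membership events with no stored `room_version_id`, guesses a room version from the event format alone. A sketch of that mapping with plain constants standing in for Synapse's `EventFormatVersions`/`RoomVersions`:

```python
# Illustrative stand-ins for Synapse's EventFormatVersions constants.
EVENT_FORMAT_V1, EVENT_FORMAT_V2, EVENT_FORMAT_V3 = 1, 2, 3


def guess_room_version(event_format: int) -> str:
    """Sketch of the fallback: pick a plausible room version identifier
    for each event format when the room itself is unknown."""
    if event_format == EVENT_FORMAT_V1:
        return "1"  # the oldest format only existed in room version 1 (and 2)
    elif event_format == EVENT_FORMAT_V2:
        return "3"  # format v2 was introduced with room version 3
    else:
        return "5"  # otherwise assume a recent version
```

The guess only has to be close enough for redaction and serialization rules; events whose stored format contradicts a known `room_version_id` are rejected by the checks above instead.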
@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore):
of EventFormatVersions. 'None' means the event predates of EventFormatVersions. 'None' means the event predates
EventFormatVersions (so the event is format V1). EventFormatVersions (so the event is format V1).
* room_version_id (str|None): The version of the room which contains the event.
Hopefully one of RoomVersions.
Due to historical reasons, there may be a few events in the database which
do not have an associated room; in this case None will be returned here.
* rejected_reason (str|None): if the event was rejected, the reason
why.
@@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore):
"""
event_dict = {}
for evs in batch_iter(event_ids, 200):
sql = """\
SELECT
e.event_id,
e.internal_metadata,
e.json,
e.format_version,
r.room_version,
rej.reason
FROM event_json as e
LEFT JOIN rooms r USING (room_id)
LEFT JOIN rejections as rej USING (event_id)
WHERE """
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", evs
@@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore):
"internal_metadata": row[1],
"json": row[2],
"format_version": row[3],
"room_version_id": row[4],
"rejected_reason": row[5],
"redactions": [],
}

View File

@@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def _count_users(txn):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
(count,) = txn.fetchone()
return count

return self.db.runInteraction("count_users", _count_users)
@cached(num_args=0)
def get_monthly_active_count_by_service(self):
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table,
`config.track_appservice_user_ips` must be set to `true` for this
method to return anything other than native matrix users.
Returns:
Deferred[dict]: dict that includes a mapping between app_service_id
and the number of occurrences.
"""
def _count_users_by_service(txn):
sql = """
SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
FROM monthly_active_users
LEFT JOIN users ON monthly_active_users.user_id=users.name
GROUP BY appservice_id;
"""
txn.execute(sql)
result = txn.fetchall()
return dict(result)
return self.db.runInteraction("count_users_by_service", _count_users_by_service)
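The grouping query above buckets users without an `appservice_id` under `'native'` via `COALESCE`. A self-contained sketch of the same query shape against an in-memory SQLite database (the table contents are made up for illustration):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE monthly_active_users (user_id TEXT)")
conn.execute("CREATE TABLE users (name TEXT, appservice_id TEXT)")
conn.executemany(
    "INSERT INTO users VALUES (?, ?)",
    [("@a:hs", None), ("@b:hs", "as1"), ("@c:hs", "as1")],
)
conn.executemany(
    "INSERT INTO monthly_active_users VALUES (?)",
    [("@a:hs",), ("@b:hs",), ("@c:hs",)],
)

# users with no appservice_id fall into the 'native' bucket
rows = conn.execute(
    """
    SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
    FROM monthly_active_users
    LEFT JOIN users ON monthly_active_users.user_id = users.name
    GROUP BY appservice_id
    """
).fetchall()
counts = dict(rows)
```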
@defer.inlineCallbacks
def get_registered_reserved_users(self):
"""Of the reserved threepids defined in config, which are associated
@@ -291,6 +318,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
self._invalidate_cache_and_stream(
txn, self.get_monthly_active_count_by_service, ()
)
self._invalidate_cache_and_stream(
txn, self.user_last_seen_monthly_active, (user_id,)
)

View File

@@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore):
admin (bool): true iff the user is to be a server admin,
false otherwise.
"""

def set_server_admin_txn(txn):
self.db.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user.to_string(),)
)

return self.db.runInteraction("set_server_admin", set_server_admin_txn)

def _query_for_auth(self, txn, token):
sql = (

View File

@@ -15,6 +15,8 @@
import logging

from twisted.internet import defer

from synapse.storage._base import SQLBaseStore

logger = logging.getLogger(__name__)
@@ -56,7 +58,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
return defer.succeed((max_stream_id, []))

def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than # First we calculate the max stream id that will give us less than

View File

@@ -15,9 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time

from time import monotonic as monotonic_time
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple

from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
@@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.stringutils import exception_to_unicode

logger = logging.getLogger(__name__)

# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1

sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database.
"""
@@ -90,7 +80,9 @@ def make_pool(
)


def make_conn(
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> Connection:
"""Make a new connection to the database and return it.

Returns:
@@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
return db_conn


# The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
# that mypy sees the type but the runtime python doesn't.
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
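A minimal sketch of the callback-list pattern this alias types (the names below are illustrative): entries are stored as `(callback, args, kwargs)` tuples and replayed once the transaction completes. The string-quoting of `Callable[..., None]` is only needed on Python 3.5.2; on modern Pythons the alias can be written unquoted:

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple

_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]

fired: List[str] = []
after_callbacks: List[_CallbackListEntry] = []

# e.g. queue a cache invalidation to run after the transaction commits
after_callbacks.append((lambda name: fired.append(name), ("invalidate_cache",), {}))

# on successful completion, each entry is replayed
for cb, args, kwargs in after_callbacks:
    cb(*args, **kwargs)
```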
class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method.

Args:
txn: The database transaction object to wrap.
name: The name of this transaction, for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
should be allowed to be scheduled to run.
@@ -135,46 +134,67 @@ class LoggingTransaction(object):
]

def __init__(
self,
txn: Cursor,
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
self.txn = txn
self.name = name
self.database_engine = database_engine
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))

def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))

def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()

def fetchone(self) -> Tuple:
return self.txn.fetchone()

def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@property
def rowcount(self) -> int:
return self.txn.rowcount
@property
def description(self) -> Any:
return self.txn.description
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch  # type: ignore

self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
self.execute(sql, val)

def execute(self, sql: str, *args: Any):
self._do_execute(self.txn.execute, sql, *args)

def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args)

def _make_sql_one_line(self, sql):
@@ -207,6 +227,9 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
def close(self):
self.txn.close()
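The diff swaps the old `__getattr__`/`__setattr__` catch-all proxying for explicitly delegated methods, which a type checker can verify. A minimal sketch of the same wrapper pattern over a `sqlite3` cursor (class and variable names are illustrative):

```python
import sqlite3
from typing import Any, Iterator, List, Tuple

class LoggingCursor:
    """Explicitly delegates to the wrapped cursor, so every method the
    wrapper exposes is visible to mypy (a __getattr__ catch-all is not)."""

    def __init__(self, txn: sqlite3.Cursor, name: str):
        self.txn = txn
        self.name = name

    def execute(self, sql: str, *args: Any) -> None:
        self.txn.execute(sql, *args)

    def fetchall(self) -> List[Tuple]:
        return self.txn.fetchall()

    def __iter__(self) -> Iterator[Tuple]:
        return self.txn.__iter__()

    def close(self) -> None:
        self.txn.close()

cur = LoggingCursor(sqlite3.connect(":memory:").cursor(), "demo")
cur.execute("SELECT 1")
rows = cur.fetchall()
```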
class PerformanceCounters(object):
def __init__(self):
@@ -251,7 +274,9 @@ class Database(object):
_TXN_ID = 0

def __init__(
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
):
self.hs = hs
self._clock = hs.get_clock()
self._database_config = database_config
@@ -259,9 +284,9 @@ class Database(object):
self.updates = BackgroundUpdater(hs, self)

self._previous_txn_total_time = 0.0
self._current_txn_total_time = 0.0
self._previous_loop_ts = 0.0

# TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +488,23 @@ class Database(object):
sql_txn_timer.labels(desc).observe(duration)

@defer.inlineCallbacks
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
"""Starts a transaction on the database and runs a given function

Arguments:
desc: description of the transaction, for logging and metrics
func: callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
args: positional args to pass to `func`
kwargs: named args to pass to `func`

Returns:
Deferred: The result of func
"""
after_callbacks = []  # type: List[_CallbackListEntry]
exception_callbacks = []  # type: List[_CallbackListEntry]

if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -505,15 +530,15 @@ class Database(object):
return result

@defer.inlineCallbacks
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool.

Arguments:
func: callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
args: positional args to pass to `func`
kwargs: named args to pass to `func`

Returns:
Deferred: The result of func
@@ -800,7 +825,7 @@ class Database(object):
return False

# We didn't find any existing rows, so insert a new one
allvalues = {}  # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
@@ -829,7 +854,7 @@ class Database(object):
Returns:
None
"""
allvalues = {}  # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(insertion_values)
@@ -916,7 +941,7 @@ class Database(object):
Returns:
None
"""
allnames = []  # type: List[str]
allnames.extend(key_names)
allnames.extend(value_names)
@@ -1100,7 +1125,7 @@ class Database(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
results = []  # type: List[Dict[str, Any]]

if not iterable:
return results
@@ -1439,7 +1464,7 @@ class Database(object):
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")

where_clause = "WHERE " if filters or keyvalues else ""
arg_list = []  # type: List[Any]
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())

View File

@@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform

from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine


def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]

if name == "sqlite3":
import sqlite3

return Sqlite3Engine(sqlite3, database_config)

if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
if platform.python_implementation() == "PyPy":
import psycopg2cffi as psycopg2  # type: ignore
else:
import psycopg2  # type: ignore

return PostgresEngine(psycopg2, database_config)

raise RuntimeError("Unsupported database engine '%s'" % (name,))


__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
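The `create_engine` rewrite trades `importlib.import_module` for explicit per-backend imports, so mypy can follow the types. The shape of that dispatch, sketched with only the sqlite backend (the function name here is hypothetical):

```python
def load_engine_module(name: str):
    # Explicit per-backend imports: a type checker sees exactly which
    # module each branch returns, unlike importlib.import_module.
    if name == "sqlite3":
        import sqlite3

        return sqlite3
    raise RuntimeError("Unsupported database engine '%s'" % (name,))

engine = load_engine_module("sqlite3")

try:
    load_engine_module("nosuchdb")
    unsupported_raises = False
except RuntimeError:
    unsupported_raises = True
```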

View File

@@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Generic, TypeVar

from synapse.storage.types import Connection


class IncorrectDatabaseSetup(RuntimeError):
pass
ConnectionType = TypeVar("ConnectionType", bound=Connection)
class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def __init__(self, module, database_config: dict):
self.module = module
@property
@abc.abstractmethod
def single_threaded(self) -> bool:
...
@property
@abc.abstractmethod
def can_native_upsert(self) -> bool:
"""
Do we support native UPSERTs?
"""
...
@property
@abc.abstractmethod
def supports_tuple_comparison(self) -> bool:
"""
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
"""
...
@property
@abc.abstractmethod
def supports_using_any_list(self) -> bool:
"""
Do we support using `a = ANY(?)` and passing a list
"""
...
@abc.abstractmethod
def check_database(
self, db_conn: ConnectionType, allow_outdated_version: bool = False
) -> None:
...
@abc.abstractmethod
def check_new_database(self, txn) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
...
@abc.abstractmethod
def convert_param_style(self, sql: str) -> str:
...
@abc.abstractmethod
def on_new_connection(self, db_conn: ConnectionType) -> None:
...
@abc.abstractmethod
def is_deadlock(self, error: Exception) -> bool:
...
@abc.abstractmethod
def is_connection_closed(self, conn: ConnectionType) -> bool:
...
@abc.abstractmethod
def lock_table(self, txn, table: str) -> None:
...
@abc.abstractmethod
def get_next_state_group_id(self, txn) -> int:
"""Returns an int that can be used as a new state_group ID
"""
...
@property
@abc.abstractmethod
def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'
"""
...
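The `@property` plus `@abc.abstractmethod` stacking used above makes each property mandatory for concrete subclasses. A standalone sketch (class names are illustrative) showing that the abstract base cannot be instantiated until the property is provided:

```python
import abc

class Engine(metaclass=abc.ABCMeta):
    # Same decorator stacking as above: the property itself is abstract.
    @property
    @abc.abstractmethod
    def single_threaded(self) -> bool:
        ...

class SqliteLike(Engine):
    @property
    def single_threaded(self) -> bool:
        return True

try:
    Engine()  # abstract: has an unimplemented abstractmethod
    instantiable = True
except TypeError:
    instantiable = False
```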

View File

@@ -15,16 +15,14 @@
import logging

from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup

logger = logging.getLogger(__name__)


class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)

# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@@ -36,6 +34,10 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None  # unknown as yet

@property
def single_threaded(self) -> bool:
return False

def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and

View File

@@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
import struct
import threading

from synapse.storage.engines import BaseDatabaseEngine


class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
def __init__(self, database_module, database_config):
super().__init__(database_module, database_config)

database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (None, ":memory:",)
@@ -31,6 +31,10 @@ class Sqlite3Engine(object):
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()

@property
def single_threaded(self) -> bool:
return True

@property
def can_native_upsert(self):
"""
@@ -68,7 +72,6 @@ class Sqlite3Engine(object):
return sql

def on_new_connection(self, db_conn):
# We need to import here to avoid an import loop.
from synapse.storage.prepare_database import prepare_database

synapse/storage/types.py (new file)
View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, List, Tuple
from typing_extensions import Protocol
"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
class Cursor(Protocol):
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
...
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
...
def fetchall(self) -> List[Tuple]:
...
def fetchone(self) -> Tuple:
...
@property
def description(self) -> Any:
return None
@property
def rowcount(self) -> int:
return 0
def __iter__(self) -> Iterator[Tuple]:
...
def close(self) -> None:
...
class Connection(Protocol):
def cursor(self) -> Cursor:
...
def close(self) -> None:
...
def commit(self) -> None:
...
def rollback(self, *args, **kwargs) -> None:
...
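These PEP-249 protocols rely on structural subtyping: any object with matching methods satisfies them, no inheritance required. A small sketch using `typing.Protocol` (available from Python 3.8; the file above imports it from `typing_extensions` for Python 3.5 support):

```python
import sqlite3
from typing import Any, Iterable, Protocol

class Cursor(Protocol):
    def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
        ...

    def fetchone(self) -> Any:
        ...

def first_row(cur: Cursor, sql: str) -> Any:
    # sqlite3's cursor satisfies the protocol structurally; it never
    # declares Cursor as a base class.
    cur.execute(sql)
    return cur.fetchone()

row = first_row(sqlite3.connect(":memory:").cursor(), "SELECT 1 + 1")
```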

View File

@@ -23,7 +23,7 @@ import attr

from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64

from synapse.api.errors import Codes, SynapseError

# define a version of typing.Collection that works on python 3.5
if sys.version_info[:3] >= (3, 6, 0):
@@ -166,11 +166,13 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
return self

@classmethod
def from_string(cls, s: str):
"""Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(
400,
"Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL),
Codes.INVALID_PARAM,
)

parts = s[1:].split(":", 1)
@@ -179,6 +181,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
400,
"Expected %s of the form '%slocalname:domain'"
% (cls.__name__, cls.SIGIL),
Codes.INVALID_PARAM,
)

domain = parts[1]
@@ -235,11 +238,13 @@ class GroupID(DomainSpecificString):
def from_string(cls, s):
group_id = super(GroupID, cls).from_string(s)
if not group_id.localpart:
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)

if contains_invalid_mxid_characters(group_id.localpart):
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
Codes.INVALID_PARAM,
)

return group_id
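The `from_string` logic above boils down to a sigil check plus a single `localpart:domain` split. A standalone sketch of that parsing (the function name is hypothetical):

```python
def parse_identifier(s: str, sigil: str):
    # Mirrors from_string: check the sigil, then split localpart:domain once.
    if len(s) < 1 or s[0:1] != sigil:
        raise ValueError("Expected string to start with '%s'" % (sigil,))
    parts = s[1:].split(":", 1)
    if len(parts) != 2:
        raise ValueError("Expected form '%slocalname:domain'" % (sigil,))
    return parts[0], parts[1]

localpart, domain = parse_identifier("@alice:example.com", "@")
```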

View File

@@ -119,6 +119,9 @@ def filter_events_for_client(
the original event if they can see it as normal.
"""

if event.type == "org.matrix.dummy_event":
return None

if not event.is_state() and event.sender in ignore_list:
return None

View File

@@ -27,6 +27,11 @@ class FrontendProxyTests(HomeserverTestCase):
return hs

def default_config(self, name="test"):
c = super().default_config(name)
c["worker_app"] = "synapse.app.frontend_proxy"
return c

def test_listen_http_with_presence_enabled(self):
"""
When presence is on, the stub servlet will not register.

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.utils import (
copy_power_levels_contents,
@ -36,9 +37,9 @@ class PruneEventTestCase(unittest.TestCase):
""" Asserts that a new event constructed with `evdict` will look like """ Asserts that a new event constructed with `evdict` will look like
`matchdict` when it is redacted. """ `matchdict` when it is redacted. """
def run_test(self, evdict, matchdict): def run_test(self, evdict, matchdict, **kwargs):
self.assertEquals( self.assertEquals(
prune_event(make_event_from_dict(evdict)).get_dict(), matchdict prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
) )
def test_minimal(self): def test_minimal(self):
@@ -128,6 +129,36 @@ class PruneEventTestCase(unittest.TestCase):
},
)
def test_alias_event(self):
"""Alias events have special behavior up through room version 6."""
self.run_test(
{
"type": "m.room.aliases",
"event_id": "$test:domain",
"content": {"aliases": ["test"]},
},
{
"type": "m.room.aliases",
"event_id": "$test:domain",
"content": {"aliases": ["test"]},
"signatures": {},
"unsigned": {},
},
)
def test_msc2432_alias_event(self):
"""After MSC2432, alias events have no special behavior."""
self.run_test(
{"type": "m.room.aliases", "content": {"aliases": ["test"]}},
{
"type": "m.room.aliases",
"content": {},
"signatures": {},
"unsigned": {},
},
room_version=RoomVersions.MSC2432_DEV,
)
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
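The two new tests above pin down the room-version-dependent redaction of `m.room.aliases`: up through room version 5 its `aliases` content key survives redaction, while after MSC2432 its content is emptied. An illustrative sketch of that pruning rule (not Synapse's implementation):

```python
def prune_event_dict(evdict: dict, msc2432: bool = False) -> dict:
    """Keep only whitelisted top-level keys; content survival depends on
    event type and room version. Illustrative only, not Synapse's code."""
    allowed_keys = {"type", "event_id", "signatures", "unsigned"}
    pruned = {k: v for k, v in evdict.items() if k in allowed_keys}
    pruned.setdefault("signatures", {})
    pruned.setdefault("unsigned", {})

    content = {}
    # Pre-MSC2432, m.room.aliases keeps its "aliases" key; afterwards the
    # event type has no special treatment and its content is emptied.
    if evdict["type"] == "m.room.aliases" and not msc2432:
        if "aliases" in evdict.get("content", {}):
            content["aliases"] = evdict["content"]["aliases"]
    pruned["content"] = content
    return pruned

ev = {
    "type": "m.room.aliases",
    "event_id": "$test:domain",
    "content": {"aliases": ["test"]},
}
print(prune_event_dict(ev)["content"])                # {'aliases': ['test']}
print(prune_event_dict(ev, msc2432=True)["content"])  # {}
```

This mirrors the expectations in `test_alias_event` and `test_msc2432_alias_event` above.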


@@ -18,6 +18,7 @@ from mock import Mock
from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.api.constants import EventTypes
from synapse.config.room_directory import RoomDirectoryConfig
@@ -87,38 +88,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
def test_delete_alias_not_allowed(self):
room_id = "!8765qwer:test"
self.get_success(
self.store.create_room_alias_association(self.my_room, room_id, ["test"])
)
self.get_failure(
self.handler.delete_association(
create_requester("@user:test"), self.my_room
),
synapse.api.errors.AuthError,
)
def test_delete_alias(self):
room_id = "!8765qwer:test"
user_id = "@user:test"
self.get_success(
self.store.create_room_alias_association(
self.my_room, room_id, ["test"], user_id
)
)
result = self.get_success(
self.handler.delete_association(create_requester(user_id), self.my_room)
)
self.assertEquals(room_id, result)
# The alias should not be found.
self.get_failure(
self.handler.get_association(self.my_room), synapse.api.errors.SynapseError
)
def test_incoming_fed_query(self):
self.get_success(
self.store.create_room_alias_association(
@@ -133,6 +102,119 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
class TestDeleteAlias(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
directory.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.handler = hs.get_handlers().directory_handler
self.state_handler = hs.get_state_handler()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
# Create a test room
self.room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
)
self.test_alias = "#test:test"
self.room_alias = RoomAlias.from_string(self.test_alias)
# Create a test user.
self.test_user = self.register_user("user", "pass", admin=False)
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def _create_alias(self, user):
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
self.room_alias, self.room_id, ["test"], user
)
)
def test_delete_alias_not_allowed(self):
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user)
self.get_failure(
self.handler.delete_association(
create_requester(self.test_user), self.room_alias
),
synapse.api.errors.AuthError,
)
def test_delete_alias_creator(self):
"""An alias creator can delete their own alias."""
# Create an alias from a different user.
self._create_alias(self.test_user)
# Delete the user's alias.
result = self.get_success(
self.handler.delete_association(
create_requester(self.test_user), self.room_alias
)
)
self.assertEquals(self.room_id, result)
# Confirm the alias is gone.
self.get_failure(
self.handler.get_association(self.room_alias),
synapse.api.errors.SynapseError,
)
def test_delete_alias_admin(self):
"""A server admin can delete an alias created by another user."""
# Create an alias from a different user.
self._create_alias(self.test_user)
# Delete the user's alias as the admin.
result = self.get_success(
self.handler.delete_association(
create_requester(self.admin_user), self.room_alias
)
)
self.assertEquals(self.room_id, result)
# Confirm the alias is gone.
self.get_failure(
self.handler.get_association(self.room_alias),
synapse.api.errors.SynapseError,
)
def test_delete_alias_sufficient_power(self):
"""A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user)
# Increase the user's power level.
self.helper.send_state(
self.room_id,
"m.room.power_levels",
{"users": {self.test_user: 100}},
tok=self.admin_user_tok,
)
# They can now delete the alias.
result = self.get_success(
self.handler.delete_association(
create_requester(self.test_user), self.room_alias
)
)
self.assertEquals(self.room_id, result)
# Confirm the alias is gone.
self.get_failure(
self.handler.get_association(self.room_alias),
synapse.api.errors.SynapseError,
)
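The four tests above cover the authorisation paths for alias deletion: the alias creator, a server admin, or a user with sufficient power level may delete; everyone else gets an `AuthError`. A rough sketch of that rule (names and the threshold are illustrative, not Synapse's code):

```python
def can_delete_alias(
    requester: str,
    creator: str,
    is_server_admin: bool,
    user_levels: dict,
    required_level: int = 50,  # assumed threshold for illustration
) -> bool:
    """Return True if `requester` may delete an alias created by `creator`."""
    if is_server_admin or requester == creator:
        return True
    # Otherwise, fall back to the room power-level check.
    return user_levels.get(requester, 0) >= required_level

levels = {"@user:test": 100}
print(can_delete_alias("@user:test", "@admin:test", False, levels))   # True
print(can_delete_alias("@other:test", "@admin:test", False, levels))  # False
print(can_delete_alias("@other:test", "@admin:test", True, levels))   # True
```

`test_delete_alias_sufficient_power` exercises the third branch by raising the user's power level to 100 before deleting.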
class CanonicalAliasTestCase(unittest.HomeserverTestCase):
"""Test modifications of the canonical alias when delete aliases.
"""
@@ -159,30 +241,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
self.test_alias = "#test:test"
self.room_alias = self._add_alias(self.test_alias)
def _add_alias(self, alias: str) -> RoomAlias:
"""Add an alias to the test room."""
room_alias = RoomAlias.from_string(alias)
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
room_alias, self.room_id, ["test"], user
)
)
return room_alias
def _set_canonical_alias(self, content):
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
)
def _get_canonical_alias(self):
"""Get the canonical alias state of the room."""
return self.get_success(
self.state_handler.get_current_state(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
def test_remove_alias(self):
"""Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room
self._set_canonical_alias(
{"alias": self.test_alias, "alt_aliases": [self.test_alias]}
)
data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
@@ -193,11 +287,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
data = self._get_canonical_alias()
self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"])
@@ -205,29 +295,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
"""Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias.
other_test_alias = "#test2:test"
other_room_alias = self._add_alias(other_test_alias)
# Set the alias as the canonical alias for this room.
self._set_canonical_alias(
{
"alias": self.test_alias,
"alt_aliases": [self.test_alias, other_test_alias],
}
)
data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(
data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
@@ -240,11 +318,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
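These tests assert that deleting an alias also rewrites the `m.room.canonical_alias` state content, stripping the removed alias from both the `alias` field and `alt_aliases`. A sketch of that content rewrite (illustrative, not Synapse's code):

```python
def strip_alias(content: dict, removed: str) -> dict:
    """Return canonical-alias content with `removed` stripped from both
    the "alias" field and the "alt_aliases" list."""
    new_content = dict(content)
    if new_content.get("alias") == removed:
        del new_content["alias"]
    if "alt_aliases" in new_content:
        alt = [a for a in new_content["alt_aliases"] if a != removed]
        if alt:
            new_content["alt_aliases"] = alt
        else:
            # Drop the key entirely when no alt aliases remain.
            del new_content["alt_aliases"]
    return new_content

content = {"alias": "#test:test", "alt_aliases": ["#test:test", "#test2:test"]}
print(strip_alias(content, "#test2:test"))
# {'alias': '#test:test', 'alt_aliases': ['#test:test']}
print(strip_alias(content, "#test:test"))
# {'alt_aliases': ['#test2:test']}
```

The first call matches the "remove other alias" case above; removing the canonical alias itself drops the `alias` key, as `test_remove_alias` expects.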


@@ -15,6 +15,7 @@ import logging
from canonicaljson import encode_canonical_json
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
@@ -58,6 +59,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
return super(SlavedEventStoreTestCase, self).setUp()
def prepare(self, *args, **kwargs):
super().prepare(*args, **kwargs)
self.get_success(
self.master_store.store_room(
ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
)
)
def tearDown(self):
[unpatch() for unpatch in self.unpatches]

View File

@@ -16,6 +16,7 @@
import hashlib
import hmac
import json
import urllib.parse
from mock import Mock
@@ -371,22 +372,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.other_user_token = self.login("user", "pass")
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user
)
def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
request, channel = self.make_request(
"GET", url, access_token=self.other_user_token,
)
self.render(request)
@@ -394,7 +397,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("You are not a server admin", channel.json_body["error"])
request, channel = self.make_request(
"PUT", url, access_token=self.other_user_token, content=b"{}",
)
self.render(request)
@@ -417,24 +420,26 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
def test_create_server_admin(self):
"""
Check that a new admin user is created successfully.
"""
self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user (server admin)
body = json.dumps(
{
"password": "abc123",
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
}
)
request, channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
@@ -442,29 +447,85 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"])
# Get user
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
def test_create_user(self):
"""
Check that a new regular user is created successfully.
"""
self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
body = json.dumps(
{
"password": "abc123",
"admin": False,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
}
)
request, channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"])
# Get user
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
def test_set_password(self):
"""
Test setting a new password for another user.
"""
self.hs.config.registration_shared_secret = None
# Change password
body = json.dumps({"password": "hahaha"})
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
@@ -472,41 +533,133 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def test_set_displayname(self):
"""
Test setting the displayname of another user.
"""
self.hs.config.registration_shared_secret = None
# Modify user
body = json.dumps({"displayname": "foobar"})
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
def test_set_threepid(self):
"""
Test setting threepid for an other user.
"""
self.hs.config.registration_shared_secret = None
# Delete old and add new threepid to user
body = json.dumps(
{"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
)
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
def test_deactivate_user(self):
"""
Test deactivating another user.
"""
# Deactivate user
body = json.dumps({"deactivated": True})
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
# the user is deactivated, the threepid will be deleted
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.
"""
self.hs.config.registration_shared_secret = None
# Set a user as an admin
body = json.dumps({"admin": True})
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"])
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"])
def test_accidental_deactivation_prevention(self):
"""
@@ -514,13 +667,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
for the deactivated body parameter
"""
self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
body = json.dumps({"password": "abc123"})
request, channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
@@ -532,7 +686,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Get user
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
@@ -546,7 +700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
@@ -556,7 +710,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Check user is not deactivated
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
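Several of the tests above build the admin API path with `urllib.parse.quote`, since Matrix user IDs contain characters (`@` and `:`) that must be percent-encoded in a URL path segment:

```python
import urllib.parse

# Percent-encode a Matrix user ID before embedding it in the admin API path,
# as the test fixture does for self.url_other_user.
user_id = "@user:test"
url = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id)
print(url)  # /_synapse/admin/v2/users/%40user%3Atest
```

Without the quoting, the literal `@` and `:` would be passed through unescaped in the request path.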


@@ -1,4 +1,7 @@
import json
import urllib.parse
from mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import login
@@ -252,3 +255,111 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
self.base_url = "https://matrix.goodserver.com/"
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
config = self.default_config()
config["cas_config"] = {
"enabled": True,
"server_url": "https://fake.test",
"service_url": "https://matrix.goodserver.com:8448",
}
async def get_raw(uri, args):
"""Return an example response payload from a call to the `/proxyValidate`
endpoint of a CAS server, copied from
https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
This needs to be returned by an async function (as opposed to set as the
mock's return value) because the corresponding Synapse code awaits on it.
"""
return """
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationSuccess>
<cas:user>username</cas:user>
<cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
<cas:proxies>
<cas:proxy>https://proxy2/pgtUrl</cas:proxy>
<cas:proxy>https://proxy1/pgtUrl</cas:proxy>
</cas:proxies>
</cas:authenticationSuccess>
</cas:serviceResponse>
"""
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
self.hs = self.setup_test_homeserver(
config=config, proxied_http_client=mocked_http_client,
)
return self.hs
def test_cas_redirect_confirm(self):
"""Tests that the SSO login flow serves a confirmation page before redirecting a
user to the redirect URL.
"""
base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
redirect_url = "https://dodgy-site.com/"
url_parts = list(urllib.parse.urlparse(base_url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"redirectUrl": redirect_url})
query.update({"ticket": "ticket"})
url_parts[4] = urllib.parse.urlencode(query)
cas_ticket_url = urllib.parse.urlunparse(url_parts)
# Get Synapse to call the fake CAS and serve the template.
request, channel = self.make_request("GET", cas_ticket_url)
self.render(request)
# Test that the response is HTML.
self.assertEqual(channel.code, 200)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
content_type_header_value = header[1].decode("utf8")
self.assertTrue(content_type_header_value.startswith("text/html"))
# Test that the body isn't empty.
self.assertTrue(len(channel.result["body"]) > 0)
# And that it contains our redirect link
self.assertIn(redirect_url, channel.result["body"].decode("UTF-8"))
@override_config(
{
"sso": {
"client_whitelist": [
"https://legit-site.com/",
"https://other-site.com/",
]
}
}
)
def test_cas_redirect_whitelisted(self):
"""Tests that the SSO login flow serves a redirect to a whitelisted url
"""
redirect_url = "https://legit-site.com/"
cas_ticket_url = (
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
% (urllib.parse.quote(redirect_url))
)
# Get Synapse to call the fake CAS and serve the template.
request, channel = self.make_request("GET", cas_ticket_url)
self.render(request)
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
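`test_cas_redirect_confirm` above builds its ticket URL by round-tripping through `urllib.parse`, which percent-encodes the `redirectUrl` query parameter safely. The same construction in isolation:

```python
import urllib.parse

# Build /login/cas/ticket?redirectUrl=...&ticket=ticket with proper encoding.
base_url = "/_matrix/client/r0/login/cas/ticket"
url_parts = list(urllib.parse.urlparse(base_url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))  # existing query, if any
query.update({"redirectUrl": "https://dodgy-site.com/", "ticket": "ticket"})
url_parts[4] = urllib.parse.urlencode(query)  # percent-encodes the values
cas_ticket_url = urllib.parse.urlunparse(url_parts)
print(cas_ticket_url)
# /_matrix/client/r0/login/cas/ticket?redirectUrl=https%3A%2F%2Fdodgy-site.com%2F&ticket=ticket
```

Encoding the redirect URL matters here: the confirmation page introduced by this release must receive the untrusted `redirectUrl` intact so it can display it to the user instead of redirecting blindly.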


@@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
self.room_id = self.helper.create_room_as(
self.room_owner, tok=self.room_owner_tok
)
self.alias = "#alias:test"
self._set_alias_via_directory(self.alias)
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
url = "/_matrix/client/r0/directory/room/" + alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
request, channel = self.make_request(
"GET",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
access_token=self.room_owner_tok,
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
return res
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
request, channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
json.dumps(content),
access_token=self.room_owner_tok,
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
return res
def test_canonical_alias(self):
"""Test a basic alias message."""
# There is no canonical alias to start with.
self._get_canonical_alias(expected_code=404)
# Create an alias.
self._set_canonical_alias({"alias": self.alias})
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias})
# Now remove the alias.
self._set_canonical_alias({})
# There is an alias event, but it is empty.
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_alt_aliases(self):
"""Test a canonical alias message with alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alt_aliases": [self.alias]})
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(res, {"alt_aliases": [self.alias]})
# Now remove the alt_aliases.
self._set_canonical_alias({})
# There is an alias event, but it is empty.
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_alias_alt_aliases(self):
"""Test a canonical alias message with an alias and alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
# Now remove the alias and alt_aliases.
self._set_canonical_alias({})
# There is an alias event, but it is empty.
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_partial_modify(self):
"""Test removing only the alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
# Now remove the alt_aliases.
self._set_canonical_alias({"alias": self.alias})
# There is an alias event, but it is empty.
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias})
def test_add_alias(self):
"""Test removing only the alt_aliases."""
# Create an additional alias.
second_alias = "#second:test"
self._set_alias_via_directory(second_alias)
# Add the canonical alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
# Then add the second alias.
self._set_canonical_alias(
{"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
)
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
)
def test_bad_data(self):
"""Invalid data for alt_aliases should cause errors."""
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
def test_bad_alias(self):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)

View File

@ -303,3 +303,45 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.pump()
self.store.upsert_monthly_active_user.assert_not_called()
def test_get_monthly_active_count_by_service(self):
appservice1_user1 = "@appservice1_user1:example.com"
appservice1_user2 = "@appservice1_user2:example.com"
appservice2_user1 = "@appservice2_user1:example.com"
native_user1 = "@native_user1:example.com"
service1 = "service1"
service2 = "service2"
native = "native"
self.store.register_user(
user_id=appservice1_user1, password_hash=None, appservice_id=service1
)
self.store.register_user(
user_id=appservice1_user2, password_hash=None, appservice_id=service1
)
self.store.register_user(
user_id=appservice2_user1, password_hash=None, appservice_id=service2
)
self.store.register_user(user_id=native_user1, password_hash=None)
self.pump()
count = self.store.get_monthly_active_count_by_service()
self.assertEqual({}, self.get_success(count))
self.store.upsert_monthly_active_user(native_user1)
self.store.upsert_monthly_active_user(appservice1_user1)
self.store.upsert_monthly_active_user(appservice1_user2)
self.store.upsert_monthly_active_user(appservice2_user1)
self.pump()
count = self.store.get_monthly_active_count()
self.assertEqual(4, self.get_success(count))
count = self.store.get_monthly_active_count_by_service()
result = self.get_success(count)
self.assertEqual(2, result[service1])
self.assertEqual(1, result[service2])
self.assertEqual(1, result[native])
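The grouping this new store method performs can be illustrated with a standalone sketch (hypothetical, operating on a plain list of `(user_id, appservice_id)` pairs rather than the real database table):

```python
from collections import Counter

def count_by_service(mau_entries):
    # Bucket monthly-active users by the appservice that registered them;
    # users without an appservice_id count towards the "native" bucket.
    return dict(
        Counter(appservice_id or "native" for _, appservice_id in mau_entries)
    )
```

Applied to the fixture above (two service1 users, one service2 user, one native user), this yields `{"service1": 2, "service2": 1, "native": 1}`, matching the assertions in the test.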

View File

@ -19,6 +19,7 @@ from synapse import event_auth
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.types import get_domain_from_id
class EventAuthTestCase(unittest.TestCase):
@ -51,7 +52,7 @@ class EventAuthTestCase(unittest.TestCase):
_random_state_event(joiner),
auth_events,
do_sig_check=False,
), )
def test_state_default_level(self):
"""
@ -87,6 +88,83 @@ class EventAuthTestCase(unittest.TestCase):
RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False,
)
def test_alias_event(self):
"""Alias events have special behavior up through room version 6."""
creator = "@creator:example.com"
other = "@other:example.com"
auth_events = {
("m.room.create", ""): _create_event(creator),
("m.room.member", creator): _join_event(creator),
}
# creator should be able to send aliases
event_auth.check(
RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False,
)
# Reject an event with no state key.
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.V1,
_alias_event(creator, state_key=""),
auth_events,
do_sig_check=False,
)
# If the domain of the sender does not match the state key, reject.
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.V1,
_alias_event(creator, state_key="test.com"),
auth_events,
do_sig_check=False,
)
# Note that the member does *not* need to be in the room.
event_auth.check(
RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False,
)
def test_msc2432_alias_event(self):
"""After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com"
other = "@other:example.com"
auth_events = {
("m.room.create", ""): _create_event(creator),
("m.room.member", creator): _join_event(creator),
}
# creator should be able to send aliases
event_auth.check(
RoomVersions.MSC2432_DEV,
_alias_event(creator),
auth_events,
do_sig_check=False,
)
# No particular checks are done on the state key.
event_auth.check(
RoomVersions.MSC2432_DEV,
_alias_event(creator, state_key=""),
auth_events,
do_sig_check=False,
)
event_auth.check(
RoomVersions.MSC2432_DEV,
_alias_event(creator, state_key="test.com"),
auth_events,
do_sig_check=False,
)
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.MSC2432_DEV,
_alias_event(other),
auth_events,
do_sig_check=False,
)
# helpers for making events
@ -131,6 +209,19 @@ def _power_levels_event(sender, content):
)
def _alias_event(sender, **kwargs):
data = {
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
"type": "m.room.aliases",
"sender": sender,
"state_key": get_domain_from_id(sender),
"content": {"aliases": []},
}
data.update(**kwargs)
return make_event_from_dict(data)
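The `state_key` default in `_alias_event` relies on `get_domain_from_id`, whose behaviour can be approximated as follows (a simplified re-implementation for illustration — the real helper raises a `SynapseError` rather than a `ValueError`):

```python
def get_domain_from_id_sketch(string):
    # The server name of a Matrix ID is everything after the *first* colon,
    # e.g. "@creator:example.com" -> "example.com".
    idx = string.find(":")
    if idx == -1:
        raise ValueError("Invalid ID: %r" % (string,))
    return string[idx + 1:]
```

This is why the pre-MSC2432 alias auth rules can compare the event's `state_key` against the sender's domain, as `test_alias_event` does.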
def _random_state_event(sender):
    return make_event_from_dict(
        {

View File

@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase):
self.fail("Parsing '%s' should raise exception" % id_string) self.fail("Parsing '%s' should raise exception" % id_string)
except SynapseError as exc: except SynapseError as exc:
self.assertEqual(400, exc.code) self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode) self.assertEqual("M_INVALID_PARAM", exc.errcode)
class MapUsernameTestCase(unittest.TestCase):

View File

@ -168,7 +168,6 @@ commands=
coverage html
[testenv:mypy]
basepython = python3.7
skip_install = True
deps =
{[base]deps}
@ -179,10 +178,14 @@ env =
extras = all
commands = mypy \
synapse/api \
synapse/config/ \ synapse/appservice \
synapse/config \
synapse/events/spamcheck.py \
synapse/federation/federation_base.py \
synapse/federation/federation_client.py \
synapse/federation/sender \
synapse/federation/transport \
synapse/handlers/directory.py \
synapse/handlers/presence.py \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \
@ -192,6 +195,7 @@ commands = mypy \
synapse/rest \
synapse/spam_checker_api \
synapse/storage/engines \
synapse/storage/database.py \
synapse/streams
# To find all folders that pass mypy you run: