Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

pull/8675/head
Richard van der Hoff 2020-05-01 09:26:57 +01:00
commit e9bd4bb388
70 changed files with 1898 additions and 382 deletions

1
changelog.d/6881.misc Normal file
View File

@ -0,0 +1 @@
Return total number of users and profile attributes in admin users endpoint. Contributed by Awesome Technologies Innovationslabor GmbH.

1
changelog.d/7091.doc Normal file
View File

@ -0,0 +1 @@
Improve the documentation of application service configuration files.

1
changelog.d/7146.misc Normal file
View File

@ -0,0 +1 @@
Run replication streamers on workers.

1
changelog.d/7278.misc Normal file
View File

@ -0,0 +1 @@
Add some unit tests for replication.

1
changelog.d/7302.bugfix Normal file
View File

@ -0,0 +1 @@
Persist user interactive authentication sessions across workers and Synapse restarts.

1
changelog.d/7316.bugfix Normal file
View File

@ -0,0 +1 @@
Fixed backwards compatibility logic of the first value of `trusted_third_party_id_servers` being used for `account_threepid_delegates.email`, which occurs when the former, deprecated option is set and the latter is not.

1
changelog.d/7338.misc Normal file
View File

@ -0,0 +1 @@
Convert some federation handler code to async/await.

1
changelog.d/7341.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bad error handling that would cause Synapse to crash if it's provided with a YAML configuration file that's either empty or doesn't parse into a key-value map.

1
changelog.d/7343.feature Normal file
View File

@ -0,0 +1 @@
Support SSO in the user interactive authentication workflow.

1
changelog.d/7344.bugfix Normal file
View File

@ -0,0 +1 @@
Fix incorrect metrics reporting for `renew_attestations` background task.

1
changelog.d/7352.feature Normal file
View File

@ -0,0 +1 @@
Add support for running replication over Redis when using workers.

1
changelog.d/7357.doc Normal file
View File

@ -0,0 +1 @@
Add documentation on monitoring workers with Prometheus.

1
changelog.d/7358.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.

1
changelog.d/7359.misc Normal file
View File

@ -0,0 +1 @@
Fix collation for postgres for unit tests.

1
changelog.d/7361.doc Normal file
View File

@ -0,0 +1 @@
Clarify endpoint usage in the users admin api documentation.

1
changelog.d/7364.misc Normal file
View File

@ -0,0 +1 @@
Add an `instance_name` to `RDATA` and `POSITION` replication commands.

1
changelog.d/7367.bugfix Normal file
View File

@ -0,0 +1 @@
Prevent non-federating rooms from appearing in responses to federated `POST /publicRoom` requests when a filter was included.

1
changelog.d/7378.misc Normal file
View File

@ -0,0 +1 @@
Move catchup of replication streams logic to worker.

View File

@ -33,12 +33,22 @@ with a body of:
including an ``access_token`` of a server admin.
The parameter ``displayname`` is optional and defaults to ``user_id``.
The parameter ``threepids`` is optional.
The parameter ``avatar_url`` is optional.
The parameter ``admin`` is optional and defaults to 'false'.
The parameter ``deactivated`` is optional and defaults to 'false'.
The parameter ``password`` is optional. If provided the user's password is updated and all devices are logged out.
The parameter ``displayname`` is optional and defaults to the value of
``user_id``.
The parameter ``threepids`` is optional and allows setting the third-party IDs
(email, msisdn) belonging to a user.
The parameter ``avatar_url`` is optional. Must be a [MXC
URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris).
The parameter ``admin`` is optional and defaults to ``false``.
The parameter ``deactivated`` is optional and defaults to ``false``.
The parameter ``password`` is optional. If provided, the user's password is
updated and all devices are logged out.
If the user already exists then optional parameters default to the current value.
List Accounts
@ -51,16 +61,25 @@ The api is::
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
including an ``access_token`` of a server admin.
The parameters ``from`` and ``limit`` are required only for pagination.
By default, a ``limit`` of 100 is used.
The parameter ``user_id`` can be used to select only users with user ids that
contain this value.
The parameter ``guests=false`` can be used to exclude guest users,
default is to include guest users.
The parameter ``deactivated=true`` can be used to include deactivated users,
default is to exclude deactivated users.
If the endpoint does not return a ``next_token`` then there are no more users left.
It returns a JSON body like the following:
The parameter ``from`` is optional but used for pagination, denoting the
offset in the returned results. This should be treated as an opaque value and
not explicitly set to anything other than the return value of ``next_token``
from a previous call.
The parameter ``limit`` is optional but is used for pagination, denoting the
maximum number of items to return in this call. Defaults to ``100``.
The parameter ``user_id`` is optional and filters to only users with user IDs
that contain this value.
The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
Defaults to ``true`` to include guest users.
The parameter ``deactivated`` is optional and if ``true`` will **include** deactivated users.
Defaults to ``false`` to exclude deactivated users.
A JSON body is returned with the following shape:
.. code:: json
@ -72,19 +91,29 @@ It returns a JSON body like the following:
"is_guest": 0,
"admin": 0,
"user_type": null,
"deactivated": 0
"deactivated": 0,
"displayname": "<User One>",
"avatar_url": null
}, {
"name": "<user_id2>",
"password_hash": "<password_hash2>",
"is_guest": 0,
"admin": 1,
"user_type": null,
"deactivated": 0
"deactivated": 0,
"displayname": "<User Two>",
"avatar_url": "<avatar_url>"
}
],
"next_token": "100"
"next_token": "100",
"total": 200
}
To paginate, check for ``next_token`` and if present, call the endpoint again
with ``from`` set to the value of ``next_token``. This will return a new page.
If the endpoint does not return a ``next_token`` then there are no more users
to paginate through.
Query Account
=============

View File

@ -23,9 +23,13 @@ namespaces:
users: # List of users we're interested in
- exclusive: <bool>
regex: <regex>
group_id: <group>
- ...
aliases: [] # List of aliases we're interested in
rooms: [] # List of room ids we're interested in
```
`exclusive`: If enabled, only this application service is allowed to register users in its namespace(s).
`group_id`: All users of this application service are dynamically joined to this group. This is useful for e.g user organisation or flairs.
See the [spec](https://matrix.org/docs/spec/application_service/unstable.html) for further details on how application services work.

View File

@ -60,6 +60,31 @@
1. Restart Prometheus.
## Monitoring workers
To monitor a Synapse installation using
[workers](https://github.com/matrix-org/synapse/blob/master/docs/workers.md),
every worker needs to be monitored independently, in addition to
the main homeserver process. This is because workers don't send
their metrics to the main homeserver process, but expose them
directly (if they are configured to do so).
To allow collecting metrics from a worker, you need to add a
`metrics` listener to its configuration, by adding the following
under `worker_listeners`:
```yaml
- type: metrics
bind_address: ''
port: 9101
```
The `bind_address` and `port` parameters should be set so that
the resulting listener can be reached by prometheus, and they
don't clash with an existing worker.
With this example, the worker's metrics would then be available
on `http://127.0.0.1:9101`.
## Renaming of metrics & deprecation of old names in 1.2
Synapse 1.2 updates the Prometheus metrics to match the naming

View File

@ -1518,6 +1518,30 @@ sso:
#
# * server_name: the homeserver's name.
#
# * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'.
#
# When rendering, this template is given the following variables:
# * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * description: the operation which the user is being asked to confirm
#
# * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'.
#
# Note that this page must include the JavaScript which notifies of a successful authentication
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
#
# This template has no additional variables.
#
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
# This template has no additional variables.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#

View File

@ -15,15 +15,17 @@ example flow would be (where '>' indicates master to worker and
> SERVER example.com
< REPLICATE
> POSITION events 53
> RDATA events 54 ["$foo1:bar.com", ...]
> RDATA events 55 ["$foo4:bar.com", ...]
> POSITION events master 53
> RDATA events master 54 ["$foo1:bar.com", ...]
> RDATA events master 55 ["$foo4:bar.com", ...]
The example shows the server accepting a new connection and sending its identity
with the `SERVER` command, followed by the client server to respond with the
position of all streams. The server then periodically sends `RDATA` commands
which have the format `RDATA <stream_name> <token> <row>`, where the format of
`<row>` is defined by the individual streams.
which have the format `RDATA <stream_name> <instance_name> <token> <row>`, where
the format of `<row>` is defined by the individual streams. The
`<instance_name>` is the name of the Synapse process that generated the data
(usually "master").
Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed.
@ -52,7 +54,7 @@ The basic structure of the protocol is line based, where the initial
word of each line specifies the command. The rest of the line is parsed
based on the command. For example, the RDATA command is defined as:
RDATA <stream_name> <token> <row_json>
RDATA <stream_name> <instance_name> <token> <row_json>
(Note that <row_json> may contains spaces, but cannot contain
newlines.)
@ -136,11 +138,11 @@ the wire:
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
> RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
> RDATA events 14 ["$149019767112vOHxz:localhost:8823",
> POSITION events master 1
> POSITION backfill master 1
> POSITION caches master 1
> RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
> RDATA events master 14 ["$149019767112vOHxz:localhost:8823",
"!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
< PING 1490197675618
> ERROR server stopping
@ -151,10 +153,10 @@ position without needing to send data with the `RDATA` command.
An example of a batched set of `RDATA` is:
> RDATA caches batch ["get_user_by_id",["@test:localhost:8823"],1490197670513]
> RDATA caches batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513]
> RDATA caches batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513]
> RDATA caches 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513]
> RDATA caches master batch ["get_user_by_id",["@test:localhost:8823"],1490197670513]
> RDATA caches master batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513]
> RDATA caches master batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513]
> RDATA caches master 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513]
In this case the client shouldn't advance their caches token until it
sees the the last `RDATA`.
@ -178,6 +180,11 @@ client (C):
updates, and if so then fetch them out of band. Sent in response to a
REPLICATE command (but can happen at any time).
The POSITION command includes the source of the stream. Currently all streams
are written by a single process (usually "master"). If fetching missing
updates via HTTP API, rather than via the DB, then processes should make the
request to the appropriate process.
#### ERROR (S, C)
There was an error
@ -234,12 +241,12 @@ Each individual cache invalidation results in a row being sent down
replication, which includes the cache name (the name of the function)
and they key to invalidate. For example:
> RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251]
> RDATA caches master 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251]
Alternatively, an entire cache can be invalidated by sending down a `null`
instead of the key. For example:
> RDATA caches 550953772 ["get_user_by_id", null, 1550574873252]
> RDATA caches master 550953772 ["get_user_by_id", null, 1550574873252]
However, there are times when a number of caches need to be invalidated
at the same time with the same key. To reduce traffic we batch those

View File

@ -270,7 +270,7 @@ def start(hs, listeners=None):
# Start the tracer
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs.config
hs
)
# It is now safe to start your Synapse.
@ -316,7 +316,7 @@ def setup_sentry(hs):
scope.set_tag("matrix_server_name", hs.config.server_name)
app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver"
name = hs.config.worker_name if hs.config.worker_name else "master"
name = hs.get_instance_name()
scope.set_tag("worker_app", app)
scope.set_tag("worker_name", name)

View File

@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
@ -439,6 +440,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
@ -960,17 +962,22 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
ss = GenericWorkerServer(
hs = GenericWorkerServer(
config.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)
setup_logging(ss, config, use_worker_options=True)
setup_logging(hs, config, use_worker_options=True)
hs.setup()
# Ensure the replication streamer is always started in case we write to any
# streams. Will no-op if no streams can be written to by this worker.
hs.get_replication_streamer()
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
"before", "startup", _base.start, hs, config.worker_listeners
)
_base.start_worker_reactor("synapse-generic-worker", config)

View File

@ -657,6 +657,12 @@ def read_config_files(config_files):
for config_file in config_files:
with open(config_file) as file_stream:
yaml_config = yaml.safe_load(file_stream)
if not isinstance(yaml_config, dict):
err = "File %r is empty or doesn't parse into a key-value map. IGNORING."
print(err % (config_file,))
continue
specified_config.update(yaml_config)
if "server_name" not in specified_config:

View File

@ -138,7 +138,7 @@ class DatabaseConfig(Config):
database_path = config.get("database_path")
if multi_database_config and database_config:
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
raise ConfigError("Can't specify both 'database' and 'databases' in config")
if multi_database_config:
if database_path:

View File

@ -108,9 +108,14 @@ class EmailConfig(Config):
if self.trusted_third_party_id_servers:
# XXX: It's a little confusing that account_threepid_delegate_email is modified
# both in RegistrationConfig and here. We should factor this bit out
self.account_threepid_delegate_email = self.trusted_third_party_id_servers[
0
] # type: Optional[str]
first_trusted_identity_server = self.trusted_third_party_id_servers[0]
# trusted_third_party_id_servers does not contain a scheme whereas
# account_threepid_delegate_email is expected to. Presume https
self.account_threepid_delegate_email = (
"https://" + first_trusted_identity_server
) # type: Optional[str]
self.using_identity_server_from_trusted_list = True
else:
raise ConfigError(

View File

@ -113,6 +113,30 @@ class SSOConfig(Config):
#
# * server_name: the homeserver's name.
#
# * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'.
#
# When rendering, this template is given the following variables:
# * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * description: the operation which the user is being asked to confirm
#
# * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'.
#
# Note that this page must include the JavaScript which notifies of a successful authentication
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
#
# This template has no additional variables.
#
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
# This template has no additional variables.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#

View File

@ -37,15 +37,16 @@ An attestation is a signed blob of json that looks like:
import logging
import random
from typing import Tuple
from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import yieldable_gather_results
logger = logging.getLogger(__name__)
@ -162,19 +163,19 @@ class GroupAttestionRenewer(object):
def _start_renew_attestations(self):
return run_as_background_process("renew_attestations", self._renew_attestations)
@defer.inlineCallbacks
def _renew_attestations(self):
async def _renew_attestations(self):
"""Called periodically to check if we need to update any of our attestations
"""
now = self.clock.time_msec()
rows = yield self.store.get_attestations_need_renewals(
rows = await self.store.get_attestations_need_renewals(
now + UPDATE_ATTESTATION_TIME_MS
)
@defer.inlineCallbacks
def _renew_attestation(group_id, user_id):
def _renew_attestation(group_user: Tuple[str, str]):
group_id, user_id = group_user
try:
if not self.is_mine_id(group_id):
destination = get_domain_from_id(group_id)
@ -207,8 +208,6 @@ class GroupAttestionRenewer(object):
"Error renewing attestation of %r in %r", user_id, group_id
)
for row in rows:
group_id = row["group_id"]
user_id = row["user_id"]
run_in_background(_renew_attestation, group_id, user_id)
await yieldable_gather_results(
_renew_attestation, ((row["group_id"], row["user_id"]) for row in rows)
)

View File

@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
# This is not a cache per se, but a store of all current sessions that
# expire after N hours
self.sessions = ExpiringCache(
cache_name="register_sessions",
clock=hs.get_clock(),
expiry_ms=self.SESSION_EXPIRE_MS,
reset_expiry_on_get=True,
)
account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
if hs.config.worker_app is None:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
"expire_old_sessions",
self._expire_old_sessions,
)
# Load the SSO HTML templates.
# The following template is shown to the user during a client login via SSO,
@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
if "session" in authdict:
sid = authdict["session"]
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
method = request.uri.decode("utf-8")
# If there's no session ID, create a new session.
if not sid:
session = self._create_session(
clientdict, (request.uri, request.method, clientdict), description
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
session_id = session["id"]
else:
session = self._get_session_info(sid)
session_id = sid
try:
session = await self.store.get_ui_auth_session(sid)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (sid,))
if not clientdict:
# This was designed to allow the client to omit the parameters
@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
clientdict = session["clientdict"]
clientdict = session.clientdict
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
comparator = (request.uri, request.method, clientdict)
if session["ui_auth"] != comparator:
comparator = (uri, method, clientdict)
if (session.uri, session.method, session.clientdict) != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session_id)
self._auth_dict_for_flows(flows, session.session_id)
)
creds = session["creds"]
# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
await self.store.mark_ui_auth_stage_complete(
session.session_id, login_type, result
)
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
# so that the client can have another go.
errordict = e.error_dict()
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
list(clientdict),
)
return creds, clientdict, session_id
return creds, clientdict, session.session_id
ret = self._auth_dict_for_flows(flows, session_id)
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
sess = self._get_session_info(authdict["session"])
creds = sess["creds"]
result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
)
return True
return False
@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
value: The data to store
"""
sess = self._get_session_info(session_id)
sess["serverdict"][key] = value
self._save_session(sess)
try:
await self.store.set_ui_auth_session_data(session_id, key, value)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
def get_session_data(
async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess["serverdict"].get(key, default)
try:
return await self.store.get_ui_auth_session_data(session_id, key, default)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def _expire_old_sessions(self):
"""
Invalidate any user interactive authentication sessions that have expired.
"""
now = self._clock.time_msec()
expiration_time = now - self.SESSION_EXPIRE_MS
await self.store.delete_old_ui_auth_sessions(expiration_time)
async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
"params": params,
}
def _create_session(
self,
clientdict: Dict[str, Any],
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
description: str,
) -> dict:
"""
Creates a new user interactive authentication session.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
Each session has the following keys:
id:
A unique identifier for this session. Passed back to the client
and returned for each stage.
clientdict:
The dictionary from the client root level, not the 'auth' key.
ui_auth:
A tuple which is checked at each stage of the authentication to
ensure that the asked for operation has not changed.
creds:
A map, which maps each auth-type (str) to the relevant identity
authenticated by that auth-type (mostly str, but for captcha, bool).
serverdict:
A map of data that is stored server-side and cannot be modified
by the client.
description:
A string description of the operation that the current
authentication is authorising.
Returns:
The newly created session.
"""
session_id = None
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
"clientdict": clientdict,
"ui_auth": ui_auth,
"creds": {},
"serverdict": {},
"description": description,
}
return self.sessions[session_id]
def _get_session_info(self, session_id: str) -> dict:
"""
Gets a session given a session ID.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
"""
try:
return self.sessions[session_id]
except KeyError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
await self.store.user_delete_threepid(user_id, medium, address)
return result
def _save_session(self, session: Dict[str, Any]) -> None:
"""Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
else:
return False
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
Returns:
The HTML to render.
"""
session = self._get_session_info(session_id)
try:
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
return self._sso_auth_confirm_template.render(
description=session["description"], redirect_url=redirect_url,
description=session.description, redirect_url=redirect_url,
)
def complete_sso_ui_auth(
async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest,
):
"""Having figured out a mxid for this user, complete the HTTP request
@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
process.
"""
# Mark the stage of the authentication as successful.
sess = self._get_session_info(session_id)
creds = sess["creds"]
# Save the user who authenticated with SSO, this will be used to ensure
# that the account be modified is also the person who logged in.
creds[LoginType.SSO] = registered_user_id
self._save_session(sess)
await self.store.mark_ui_auth_stage_complete(
session_id, LoginType.SSO, registered_user_id
)
# Render the HTML and return.
html_bytes = self._sso_auth_success_template.encode("utf-8")

View File

@ -206,7 +206,7 @@ class CasHandler:
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if session:
self._auth_handler.complete_sso_ui_auth(
await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)

View File

@ -343,7 +343,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(ours.values()) # type: list[StateMap[str]]
state_maps = list(ours.values()) # type: List[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@ -1694,16 +1694,15 @@ class FederationHandler(BaseHandler):
return None
@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event.
"""
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
@ -1714,7 +1713,7 @@ class FederationHandler(BaseHandler):
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
prev_event = yield self.store.get_event(prev_id)
prev_event = await self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
@ -1724,15 +1723,14 @@ class FederationHandler(BaseHandler):
else:
return []
@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@ -1751,49 +1749,50 @@ class FederationHandler(BaseHandler):
else:
return []
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, pdu_list, limit):
in_room = yield self.auth.check_host_in_room(room_id, origin)
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
# Synapse asks for 100 events per backfill request. Do not allow more.
limit = min(limit, 100)
events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = await self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield filter_events_for_server(self.storage, origin, events)
events = await filter_events_for_server(self.storage, origin, events)
return events
@defer.inlineCallbacks
@log_function
def get_persisted_pdu(self, origin, event_id):
async def get_persisted_pdu(
self, origin: str, event_id: str
) -> Optional[EventBase]:
"""Get an event from the database for the given server.
Args:
origin [str]: hostname of server which is requesting the event; we
origin: hostname of server which is requesting the event; we
will check that the server is allowed to see it.
event_id [str]: id of the event being requested
event_id: id of the event being requested
Returns:
Deferred[EventBase|None]: None if we know nothing about the event;
otherwise the (possibly-redacted) event.
None if we know nothing about the event; otherwise the (possibly-redacted) event.
Raises:
AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
if event:
in_room = yield self.auth.check_host_in_room(event.room_id, origin)
in_room = await self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
events = yield filter_events_for_server(self.storage, origin, [event])
events = await filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
@ -2397,7 +2396,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key = (event.type, event.state_key)
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
else:
event_key = None
state_updates = {

View File

@ -91,7 +91,11 @@ class RoomListHandler(BaseHandler):
logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple
limit,
since_token,
search_filter,
network_tuple=network_tuple,
from_federation=from_federation,
)
key = (limit, since_token, network_tuple)

View File

@ -149,7 +149,7 @@ class SamlHandler:
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
self._auth_handler.complete_sso_ui_auth(
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)

View File

@ -171,7 +171,7 @@ import logging
import re
import types
from functools import wraps
from typing import Dict
from typing import TYPE_CHECKING, Dict
from canonicaljson import json
@ -179,6 +179,9 @@ from twisted.internet import defer
from synapse.config import ConfigError
if TYPE_CHECKING:
from synapse.server import HomeServer
# Helper class
@ -297,14 +300,11 @@ def _noop_context_manager(*args, **kwargs):
# Setup
def init_tracer(config):
def init_tracer(hs: "HomeServer"):
"""Set the whitelists and initialise the JaegerClient tracer
Args:
config (HomeserverConfig): The config used by the homeserver
"""
global opentracing
if not config.opentracer_enabled:
if not hs.config.opentracer_enabled:
# We don't have a tracer
opentracing = None
return
@ -315,18 +315,15 @@ def init_tracer(config):
"installed."
)
# Include the worker name
name = config.worker_name if config.worker_name else "master"
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
set_homeserver_whitelist(config.opentracer_whitelist)
set_homeserver_whitelist(hs.config.opentracer_whitelist)
JaegerConfig(
config=config.jaeger_config,
service_name="{} {}".format(config.server_name, name),
scope_manager=LogContextScopeManager(config),
config=hs.config.jaeger_config,
service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
scope_manager=LogContextScopeManager(hs.config),
).initialize_tracer()

View File

@ -220,12 +220,6 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
def add_remote_server_up_callback(self, cb: Callable[[str], None]):
"""Add a callback that will be called when synapse detects a server
has been
"""
self.remote_server_up_callbacks.append(cb)
def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[]
):
@ -544,6 +538,3 @@ class Notifier(object):
# circular dependencies.
if self.federation_sender:
self.federation_sender.wake_destination(server)
for cb in self.remote_server_up_callbacks:
cb(server)

View File

@ -95,7 +95,7 @@ class RdataCommand(Command):
Format::
RDATA <stream_name> <token> <row_json>
RDATA <stream_name> <instance_name> <token> <row_json>
The `<token>` may either be a numeric stream id OR "batch". The latter case
is used to support sending multiple updates with the same stream ID. This
@ -105,33 +105,40 @@ class RdataCommand(Command):
The client should batch all incoming RDATA with a token of "batch" (per
stream_name) until it sees an RDATA with a numeric stream ID.
The `<instance_name>` is the source of the new data (usually "master").
`<token>` of "batch" maps to the instance variable `token` being None.
An example of a batched series of RDATA::
RDATA presence batch ["@foo:example.com", "online", ...]
RDATA presence batch ["@bar:example.com", "online", ...]
RDATA presence 59 ["@baz:example.com", "online", ...]
RDATA presence master batch ["@foo:example.com", "online", ...]
RDATA presence master batch ["@bar:example.com", "online", ...]
RDATA presence master 59 ["@baz:example.com", "online", ...]
"""
NAME = "RDATA"
def __init__(self, stream_name, token, row):
def __init__(self, stream_name, instance_name, token, row):
self.stream_name = stream_name
self.instance_name = instance_name
self.token = token
self.row = row
@classmethod
def from_line(cls, line):
stream_name, token, row_json = line.split(" ", 2)
stream_name, instance_name, token, row_json = line.split(" ", 3)
return cls(
stream_name, None if token == "batch" else int(token), json.loads(row_json)
stream_name,
instance_name,
None if token == "batch" else int(token),
json.loads(row_json),
)
def to_line(self):
return " ".join(
(
self.stream_name,
self.instance_name,
str(self.token) if self.token is not None else "batch",
_json_encoder.encode(self.row),
)
@ -145,23 +152,31 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
Format::
POSITION <stream_name> <instance_name> <token>
On receipt of a POSITION command clients should check if they have missed
any updates, and if so then fetch them out of band.
The `<instance_name>` is the process that sent the command and is the source
of the stream.
"""
NAME = "POSITION"
def __init__(self, stream_name, token):
def __init__(self, stream_name, instance_name, token):
self.stream_name = stream_name
self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
stream_name, token = line.split(" ", 1)
return cls(stream_name, int(token))
stream_name, instance_name, token = line.split(" ", 2)
return cls(stream_name, instance_name, int(token))
def to_line(self):
return " ".join((self.stream_name, str(self.token)))
return " ".join((self.stream_name, self.instance_name, str(self.token)))
class ErrorCommand(_SimpleCommand):

View File

@ -79,6 +79,7 @@ class ReplicationCommandHandler:
self._notifier = hs.get_notifier()
self._clock = hs.get_clock()
self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name()
# Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str]
@ -87,7 +88,9 @@ class ReplicationCommandHandler:
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
self._position_linearizer = Linearizer("replication_position")
self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
@ -115,7 +118,6 @@ class ReplicationCommandHandler:
self._server_notices_sender = None
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@ -155,13 +157,13 @@ class ReplicationCommandHandler:
hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
)
else:
client_name = hs.config.worker_name
client_name = hs.get_instance_name()
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
async def on_REPLICATE(self, cmd: ReplicateCommand):
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
@ -169,9 +171,11 @@ class ReplicationCommandHandler:
for stream_name, stream in self._streams.items():
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
self.send_command(
PositionCommand(stream_name, self._instance_name, current_token)
)
async def on_USER_SYNC(self, cmd: UserSyncCommand):
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
user_sync_counter.inc()
if self._is_master:
@ -179,17 +183,23 @@ class ReplicationCommandHandler:
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
async def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
):
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
async def on_FEDERATION_ACK(
self, conn: AbstractConnection, cmd: FederationAckCommand
):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
async def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
):
remove_pusher_counter.inc()
if self._is_master:
@ -199,7 +209,9 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
async def on_INVALIDATE_CACHE(
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
):
invalidate_cache_counter.inc()
if self._is_master:
@ -209,7 +221,7 @@ class ReplicationCommandHandler:
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, cmd: UserIpCommand):
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self._is_master:
@ -225,7 +237,11 @@ class ReplicationCommandHandler:
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, cmd: RdataCommand):
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@ -276,7 +292,11 @@ class ReplicationCommandHandler:
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
async def on_POSITION(self, cmd: PositionCommand):
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
stream = self._streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
@ -330,12 +350,30 @@ class ReplicationCommandHandler:
self._streams_connected.add(cmd.stream_name)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
if self._is_master:
self._notifier.notify_remote_server_up(cmd.data)
self._notifier.notify_remote_server_up(cmd.data)
# We relay to all other connections to ensure every instance gets the
# notification.
#
# When configured to use redis we'll always only have one connection and
# so this is a no-op (all instances will have already received the same
# REMOTE_SERVER_UP command).
#
# For direct TCP connections this will relay to all other connections
# connected to us. When on master this will correctly fan out to all
# other direct TCP clients and on workers there'll only be the one
# connection to master.
#
# (The logic here should also be sound if we have a mix of Redis and
# direct TCP connections so long as there is only one traffic route
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: AbstractConnection):
"""Called when we have a new connection.
@ -380,11 +418,21 @@ class ReplicationCommandHandler:
"""
return bool(self._connections)
def send_command(self, cmd: Command):
def send_command(
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
):
"""Send a command to all connected connections.
Args:
cmd
ignore_conn: If set don't send command to the given connection.
Used when relaying commands from one connection to all others.
"""
if self._connections:
for connection in self._connections:
if connection == ignore_conn:
continue
try:
connection.send_command(cmd)
except Exception:
@ -448,7 +496,7 @@ class ReplicationCommandHandler:
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))
self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
UpdateToken = TypeVar("UpdateToken")

View File

@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
await cmd_func(self, cmd)
handled = True
if not handled:

View File

@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
await cmd_func(self, cmd)
handled = True
if not handled:

View File

@ -17,9 +17,7 @@
import logging
import random
from typing import Dict
from six import itervalues
from typing import Dict, List
from prometheus_client import Counter
@ -71,29 +69,28 @@ class ReplicationStreamer(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self._server_notices_sender = hs.get_server_notices_sender()
self._replication_torture_level = hs.config.replication_torture_level
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending hase been
# disabled on the master.
self.streams = [
stream(hs)
for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
# Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream]
if hs.config.worker_app is None:
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
# hase been disabled on the master.
continue
self.streams.append(stream(hs))
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
# Only bother registering the notifier callback if we have streams to
# publish.
if self.streams:
self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False

View File

@ -176,10 +176,9 @@ def db_query_to_update_function(
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
assert len(updates) <= limit
return updates, upto_token, limited

View File

@ -170,22 +170,16 @@ class EventsStream(Stream):
limited = False
upper_limit = current_token
# next up is the state delta table
state_rows = await self._store.get_all_updated_current_state_deltas(
# next up is the state delta table.
(
state_rows,
upper_limit,
state_rows_limited,
) = await self._store.get_all_updated_current_state_deltas(
from_token, upper_limit, target_row_count
) # type: List[Tuple]
)
# again, if we've hit the limit there, we'll need to limit the other sources
assert len(state_rows) < target_row_count
if len(state_rows) == target_row_count:
assert state_rows[-1][0] <= upper_limit
upper_limit = state_rows[-1][0]
limited = True
# FIXME: is it a given that there is only one row per stream_id in the
# state_deltas table (so that we can be sure that we have got all of the
# rows for upper_limit)?
limited = limited or state_rows_limited
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.

View File

@ -94,10 +94,10 @@ class UsersRestServletV2(RestServlet):
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
users = await self.store.get_users_paginate(
users, total = await self.store.get_users_paginate(
start, limit, user_id, guests, deactivated
)
ret = {"users": users}
ret = {"users": users, "total": total}
if len(users) >= limit:
ret["next_token"] = str(start + len(users))
@ -199,7 +199,7 @@ class UserRestServletV2(RestServlet):
user_id, threepid["medium"], threepid["address"], current_time
)
if "avatar_url" in body:
if "avatar_url" in body and type(body["avatar_url"]) == str:
await self.profile_handler.set_avatar_url(
target_user, requester, body["avatar_url"], True
)
@ -276,7 +276,7 @@ class UserRestServletV2(RestServlet):
user_id, threepid["medium"], threepid["address"], current_time
)
if "avatar_url" in body:
if "avatar_url" in body and type(body["avatar_url"]) == str:
await self.profile_handler.set_avatar_url(
user_id, requester, body["avatar_url"], True
)

View File

@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
def on_GET(self, request, stagetype):
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet):
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
else:
raise SynapseError(404, "Unknown auth stage type")

View File

@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet):
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = self.auth_handler.get_session_data(
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet):
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)

View File

@ -234,7 +234,8 @@ class HomeServer(object):
self._listening_services = []
self.start_time = None
self.instance_id = random_string(5)
self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master"
self.clock = Clock(reactor)
self.distributor = Distributor()
@ -254,7 +255,15 @@ class HomeServer(object):
This is used to distinguish running instances in worker-based
deployments.
"""
return self.instance_id
return self._instance_id
def get_instance_name(self) -> str:
"""A unique name for this synapse process.
Used to identify the process over replication and in config. Does not
change over restarts.
"""
return self._instance_name
def setup(self):
logger.info("Setting up.")

View File

@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
class HomeServer(object):
@property
@ -121,3 +122,9 @@ class HomeServer(object):
pass
def get_instance_id(self) -> str:
pass
def get_instance_name(self) -> str:
pass
def get_event_builder_factory(self) -> EventBuilderFactory:
pass
def get_storage(self) -> synapse.storage.Storage:
pass

View File

@ -66,6 +66,7 @@ from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@ -112,6 +113,7 @@ class DataStore(
StatsStore,
RelationsStore,
CacheInvalidationStore,
UIAuthStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
@ -503,7 +505,8 @@ class DataStore(
self, start, limit, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users.
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Args:
start (int): start number to begin the query from
@ -512,35 +515,44 @@ class DataStore(
guests (bool): whether to in include guest users
deactivated (bool): whether to include deactivated users
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
defer.Deferred: resolves to list[dict[str, Any]], int
"""
name_filter = {}
if name:
name_filter["name"] = "%" + name + "%"
attr_filter = {}
if not guests:
attr_filter["is_guest"] = 0
if not deactivated:
attr_filter["deactivated"] = 0
def get_users_paginate_txn(txn):
filters = []
args = []
return self.db.simple_select_list_paginate(
desc="get_users_paginate",
table="users",
orderby="name",
start=start,
limit=limit,
filters=name_filter,
keyvalues=attr_filter,
retcols=[
"name",
"password_hash",
"is_guest",
"admin",
"user_type",
"deactivated",
],
)
if name:
filters.append("name LIKE ?")
args.append("%" + name + "%")
if not guests:
filters.append("is_guest = 0")
if not deactivated:
filters.append("deactivated = 0")
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
txn.execute(sql, args)
count = txn.fetchone()[0]
args = [self.hs.config.server_name] + args + [limit, start]
sql = """
SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
txn.execute(sql, args)
users = self.db.cursor_to_dict(txn)
return users, count
return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
def search_users(self, term):
"""Function to search users list for one or more users with

View File

@ -19,7 +19,7 @@ import itertools
import logging
import threading
from collections import namedtuple
from typing import List, Optional
from typing import List, Optional, Tuple
from canonicaljson import json
from constantly import NamedConstant, Names
@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore):
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
async def get_all_updated_current_state_deltas(
self, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
from_token: The previous stream token. Updates from this stream id will
be excluded.
to_token: The current stream token (ie the upper limit). Updates up to this
stream id will be included (modulo the 'limit' param)
target_row_count: The number of rows to try to return. If more rows are
available, we will set 'limited' in the result. In the event of a large
batch, we may return more rows than this.
Returns:
A triplet `(updates, new_last_token, limited)`, where:
* `updates` is a list of database tuples.
* `new_last_token` is the new position in stream.
* `limited` is whether there are more updates to fetch.
"""
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
txn.execute(sql, (from_token, to_token, target_row_count))
return txn.fetchall()
return self.db.runInteraction(
def get_deltas_for_stream_id_txn(txn, stream_id):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
return txn.fetchall()
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
rows = await self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
) # type: List[Tuple]
# if we've got fewer rows than the limit, we're good
if len(rows) < target_row_count:
return rows, to_token, False
# we hit the limit, so reduce the upper limit so that we exclude the stream id
# of the last row in the result.
assert rows[-1][0] <= to_token
to_token = rows[-1][0] - 1
# search backwards through the list for the point to truncate
for idx in range(len(rows) - 1, 0, -1):
if rows[idx - 1][0] <= to_token:
return rows[:idx], to_token, True
# bother. We didn't get a full set of changes for even a single
# stream id. let's run the query again, without a row limit, but for
# just one stream id.
to_token += 1
rows = await self.db.runInteraction(
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
)
return rows, to_token, True

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS ui_auth_sessions(
session_id TEXT NOT NULL, -- The session ID passed to the client.
creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds).
serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse.
clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client.
uri TEXT NOT NULL, -- The URI the UI authentication session is using.
method TEXT NOT NULL, -- The HTTP method the UI authentication session is using.
-- The clientdict, uri, and method make up an tuple that must be immutable
-- throughout the lifetime of the UI Auth session.
description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur.
UNIQUE (session_id)
);
CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
session_id TEXT NOT NULL, -- The corresponding UI Auth session.
stage_type TEXT NOT NULL, -- The stage type.
result TEXT NOT NULL, -- The result of the stage verification, stored as JSON.
UNIQUE (session_id, stage_type),
FOREIGN KEY (session_id)
REFERENCES ui_auth_sessions (session_id)
);

View File

@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict, Optional, Union
import attr
import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.types import JsonDict
@attr.s
class UIAuthSessionData:
session_id = attr.ib(type=str)
# The dictionary from the client root level, not the 'auth' key.
clientdict = attr.ib(type=JsonDict)
# The URI and method the session was intiatied with. These are checked at
# each stage of the authentication to ensure that the asked for operation
# has not changed.
uri = attr.ib(type=str)
method = attr.ib(type=str)
# A string description of the operation that the current authentication is
# authorising.
description = attr.ib(type=str)
class UIAuthWorkerStore(SQLBaseStore):
"""
Manage user interactive authentication sessions.
"""
async def create_ui_auth_session(
self, clientdict: JsonDict, uri: str, method: str, description: str,
) -> UIAuthSessionData:
"""
Creates a new user interactive authentication session.
The session can be used to track the stages necessary to authenticate a
user across multiple HTTP requests.
Args:
clientdict:
The dictionary from the client root level, not the 'auth' key.
uri:
The URI this session was initiated with, this is checked at each
stage of the authentication to ensure that the asked for
operation has not changed.
method:
The method this session was initiated with, this is checked at each
stage of the authentication to ensure that the asked for
operation has not changed.
description:
A string description of the operation that the current
authentication is authorising.
Returns:
The newly created session.
Raises:
StoreError if a unique session ID cannot be generated.
"""
# The clientdict gets stored as JSON.
clientdict_json = json.dumps(clientdict)
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
while attempts < 5:
session_id = stringutils.random_string(24)
try:
await self.db.simple_insert(
table="ui_auth_sessions",
values={
"session_id": session_id,
"clientdict": clientdict_json,
"uri": uri,
"method": method,
"description": description,
"serverdict": "{}",
"creation_time": self.hs.get_clock().time_msec(),
},
desc="create_ui_auth_session",
)
return UIAuthSessionData(
session_id, clientdict, uri, method, description
)
except self.db.engine.module.IntegrityError:
attempts += 1
raise StoreError(500, "Couldn't generate a session ID.")
async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
"""Retrieve a UI auth session.
Args:
session_id: The ID of the session.
Returns:
A dict containing the device information.
Raises:
StoreError if the session is not found.
"""
result = await self.db.simple_select_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("clientdict", "uri", "method", "description"),
desc="get_ui_auth_session",
)
result["clientdict"] = json.loads(result["clientdict"])
return UIAuthSessionData(session_id, **result)
async def mark_ui_auth_stage_complete(
self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
):
"""
Mark a session stage as completed.
Args:
session_id: The ID of the corresponding session.
stage_type: The completed stage type.
result: The result of the stage verification.
Raises:
StoreError if the session cannot be found.
"""
# Add (or update) the results of the current stage to the database.
#
# Note that we need to allow for the same stage to complete multiple
# times here so that registration is idempotent.
try:
await self.db.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
values={"result": json.dumps(result)},
desc="mark_ui_auth_stage_complete",
)
except self.db.engine.module.IntegrityError:
raise StoreError(400, "Unknown session ID: %s" % (session_id,))
async def get_completed_ui_auth_stages(
self, session_id: str
) -> Dict[str, Union[str, bool, JsonDict]]:
"""
Retrieve the completed stages of a UI authentication session.
Args:
session_id: The ID of the session.
Returns:
The completed stages mapped to the result of the verification of
that auth-type.
"""
results = {}
for row in await self.db.simple_select_list(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id},
retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages",
):
results[row["stage_type"]] = json.loads(row["result"])
return results
async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
Args:
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
value: The data to store
Raises:
StoreError if the session cannot be found.
"""
await self.db.runInteraction(
"set_ui_auth_session_data",
self._set_ui_auth_session_data_txn,
session_id,
key,
value,
)
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
# Get the current value.
result = self.db.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
)
# Update it and add it back to the database.
serverdict = json.loads(result["serverdict"])
serverdict[key] = value
self.db.simple_update_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
updatevalues={"serverdict": json.dumps(serverdict)},
)
async def get_ui_auth_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
Retrieve data stored with set_session_data
Args:
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
default: Value to return if the key has not been set
Raises:
StoreError if the session cannot be found.
"""
result = await self.db.simple_select_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
desc="get_ui_auth_session_data",
)
serverdict = json.loads(result["serverdict"])
return serverdict.get(key, default)
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
"""
Remove sessions which were last used earlier than the expiration time.
Args:
expiration_time: The latest time that is still considered valid.
This is an epoch time in milliseconds.
"""
return self.db.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,
)
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
# Delete the corresponding completed credentials.
self.db.simple_delete_many_txn(
txn,
table="ui_auth_sessions_credentials",
column="session_id",
iterable=session_ids,
keyvalues={},
)
# Finally, delete the sessions.
self.db.simple_delete_many_txn(
txn,
table="ui_auth_sessions",
column="session_id",
iterable=session_ids,
keyvalues={},
)

View File

@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
db_conn.execute("PRAGMA foreign_keys = ON;")
def is_deadlock(self, error):
return False

View File

@ -57,6 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
# We now do some gut wrenching so that we have a client that is based
# off of the slave store rather than the main store.
self.replication_handler = ReplicationCommandHandler(self.hs)
self.replication_handler._instance_name = "worker"
self.replication_handler._replication_data_handler = ReplicationDataHandler(
self.slaved_store
)

View File

@ -13,38 +13,72 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
import logging
from typing import Any, Dict, List, Optional, Tuple
import attr
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
def make_homeserver(self, reactor, clock):
self.test_handler = Mock(wraps=TestReplicationDataHandler())
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(None)
repl_handler = ReplicationCommandHandler(hs)
repl_handler.handler = self.test_handler
# Make a new HomeServer object for the worker
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
config=config,
reactor=self.reactor,
)
# Since we use sqlite in memory databases we need to make sure the
# databases objects are the same.
self.worker_hs.get_datastore().db = hs.get_datastore().db
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
hs, "client", "test", clock, repl_handler,
self.worker_hs, "client", "test", clock, repl_handler,
)
self._client_transport = None
self._server_transport = None
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore())
def reconnect(self):
if self._client_transport:
self.client.close()
@ -74,24 +108,204 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
def handle_http_replication_attempt(self) -> SynapseRequest:
"""Asserts that a connection attempt was made to the master HS on the
HTTP replication port, then proxies it to the master HS object to be
handled.
class TestReplicationDataHandler:
Returns:
The request object received by master HS.
"""
# We should have an outbound connection attempt.
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self.site
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
)
channel.makeConnection(server_to_client_transport)
# The request will now be processed by `self.site` and the response
# streamed back.
self.reactor.advance(0)
# We tear down the connection so it doesn't get reused without our
# knowledge.
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
return request_factory.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
):
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
self.assertRegex(
request.path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
self.assertEqual(request.method, b"GET")
class TestReplicationDataHandler(ReplicationDataHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self):
self.streams = set()
self._received_rdata_rows = []
def __init__(self, store: BaseSlavedStore):
super().__init__(store)
# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
def get_streams_to_replicate(self):
positions = {s: 0 for s in self.streams}
for stream, token, _ in self._received_rdata_rows:
if stream in self.streams:
positions[stream] = max(token, positions.get(stream, 0))
return positions
return self.stream_positions
async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
for r in rows:
self._received_rdata_rows.append((stream_name, token, r))
self.received_rdata_rows.append((stream_name, token, r))
async def on_position(self, stream_name, token):
pass
if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token
@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""
request = attr.ib(default=None)
def __call__(self, *args, **kwargs):
assert self.request is None
self.request = SynapseRequest(*args, **kwargs)
return self.request
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
This is a hack to get around the fact that HTTPChannel transparently wraps a
pull producer (which is what Synapse uses to reply to requests) with
`_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
uses the standard reactor rather than letting us use our test reactor, which
makes it very hard to test.
"""
def __init__(self, reactor: IReactorTime):
super().__init__()
self.reactor = reactor
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
def registerProducer(self, producer, streaming):
# Convert pull producers to push producer.
if not streaming:
self._pull_to_push_producer = _PullToPushProducer(
self.reactor, producer, self
)
producer = self._pull_to_push_producer
super().registerProducer(producer, True)
def unregisterProducer(self):
if self._pull_to_push_producer:
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
class _PullToPushProducer:
"""A push producer that wraps a pull producer.
"""
def __init__(
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
):
self._clock = Clock(reactor)
self._producer = producer
self._consumer = consumer
# While running we use a looping call with a zero delay to call
# resumeProducing on given producer.
self._looping_call = None # type: Optional[LoopingCall]
# We start writing next reactor tick.
self._start_loop()
def _start_loop(self):
"""Start the looping call to
"""
if not self._looping_call:
# Start a looping call which runs every tick.
self._looping_call = self._clock.looping_call(self._run_once, 0)
def stop(self):
"""Stops calling resumeProducing.
"""
if self._looping_call:
self._looping_call.stop()
self._looping_call = None
def pauseProducing(self):
"""Implements IPushProducer
"""
self.stop()
def resumeProducing(self):
"""Implements IPushProducer
"""
self._start_loop()
def stopProducing(self):
"""Implements IPushProducer
"""
self.stop()
self._producer.stopProducing()
def _run_once(self):
"""Calls resumeProducing on producer once.
"""
try:
self._producer.resumeProducing()
except Exception:
logger.exception("Failed to call resumeProducing")
try:
self._consumer.unregisterProducer()
except Exception:
pass
self.stopProducing()

View File

@ -0,0 +1,417 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests.replication.tcp.streams._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
class EventsStreamTestCase(BaseStreamTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
self.reconnect()
self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
def test_update_function_event_row_limit(self):
"""Test replication with many non-state events
Checks that all events are correctly replicated when there are lots of
event rows to be replicated.
"""
# disconnect, so that we can stack up some changes
self.disconnect()
# generate lots of non-state events. We inject them using inject_event
# so that they are not send out over replication until we call self.replicate().
events = [
self._inject_test_event()
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
]
# also one state event
state_event = self._inject_state_event()
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state_event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state_event.event_id)
self.assertEqual([], received_rows)
def test_update_function_huge_state_change(self):
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
state change rows to be replicated.
"""
# we want to generate lots of state changes at a single stream ID.
#
# We do this by having two branches in the DAG. On one, we have a moderator
# which that generates lots of state; on the other, we de-op the moderator,
# thus invalidating all the state.
OTHER_USER = "@other_user:localhost"
# have the user join
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
pls["users"][OTHER_USER] = 50
self.helper.send_state(
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
fork_point = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
events = [
self._inject_state_event(sender=OTHER_USER)
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
]
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
# disconnect, so that we can stack up the changes
self.disconnect()
self.test_handler.received_rdata_rows.clear()
# a state event which doesn't get rolled back, to check that the state
# before the huge update comes through ok
state1 = self._inject_state_event()
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
pl_event = inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
room_id=self.room_id,
content=pls,
)
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# now we should have received all the expected rows in the right order.
#
# we expect:
#
# - two rows for state1
# - the PL event row, plus state rows for the PL event and each
# of the states that got reverted.
# - two rows for state2
received_rows = self.test_handler.received_rdata_rows
# first check the first two rows, which should be state1
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state1.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state1.event_id)
# now the last two rows, which should be state2
stream_name, token, row = received_rows.pop(-2)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state2.event_id)
stream_name, token, row = received_rows.pop(-1)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state2.event_id)
# that should leave us with the rows for the PL event
self.assertEqual(len(received_rows), len(events) + 2)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_event.event_id)
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
for stream_name, token, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
state_rows.append(row.data)
state_rows.sort(key=lambda r: r.state_key)
sr = state_rows.pop(0)
self.assertEqual(sr.type, EventTypes.PowerLevels)
self.assertEqual(sr.event_id, pl_event.event_id)
for sr in state_rows:
self.assertEqual(sr.type, "test_state_event")
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self):
"""Test replication with many state events over several stream ids.
"""
# we want to generate lots of state changes, but for this test, we want to
# spread out the state changes over a few stream IDs.
#
# We do this by having two branches in the DAG. On one, we have four moderators,
# each of which that generates lots of state; on the other, we de-op the users,
# thus invalidating all the state.
NUM_USERS = 4
STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1
user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]
# have the users join
for u in user_ids:
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
pls["users"].update({u: 50 for u in user_ids})
self.helper.send_state(
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
fork_point = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
events = [] # type: List[EventBase]
for user in user_ids:
events.extend(
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
)
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
# disconnect, so that we can stack up the changes
self.disconnect()
self.test_handler.received_rdata_rows.clear()
# now roll back all that state by de-modding the users
prev_events = fork_point
pl_events = []
for u in user_ids:
pls["users"][u] = 0
e = inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
room_id=self.room_id,
content=pls,
)
prev_events = [e.event_id]
pl_events.append(e)
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for
# the PL event and each of the states that got reverted.
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_events[i].event_id)
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
for j in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
state_rows.append(row.data)
state_rows.sort(key=lambda r: r.state_key)
sr = state_rows.pop(0)
self.assertEqual(sr.type, EventTypes.PowerLevels)
self.assertEqual(sr.event_id, pl_events[i].event_id)
for sr in state_rows:
self.assertEqual(sr.type, "test_state_event")
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
self.assertEqual([], received_rows)
event_count = 0
def _inject_test_event(
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
) -> EventBase:
if sender is None:
sender = self.user_id
if body is None:
body = "event %i" % (self.event_count,)
self.event_count += 1
return inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
type="test_event",
content={"body": body},
**kwargs
)
def _inject_state_event(
self,
body: Optional[str] = None,
state_key: Optional[str] = None,
sender: Optional[str] = None,
) -> EventBase:
if sender is None:
sender = self.user_id
if state_key is None:
state_key = "state_%i" % (self.event_count,)
self.event_count += 1
if body is None:
body = "state event %s" % (state_key,)
return inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
type="test_state_event",
state_key=state_key,
content={"body": body},
)

View File

@ -12,6 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# type: ignore
from mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
@ -20,11 +25,14 @@ USER_ID = "@feeling:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.streams.add("receipts")
self.test_handler.stream_positions.update({"receipts": 0})
# tell the master to send a new receipt
self.get_success(

View File

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
class TypingStreamTestCase(BaseStreamTestCase):
servlets = [
streams.register_servlets,
]
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
def test_typing(self):
typing = self.hs.get_typing_handler()
room_id = "!bar:blue"
self.reconnect()
# make the client subscribe to the typing stream
self.test_handler.stream_positions.update({"typing": 0})
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
self.reactor.advance(0)
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
self.assertEqual(room_id, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Now let's disconnect and insert some data.
self.disconnect()
self.test_handler.on_rdata.reset_mock()
typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called()
self.reconnect()
self.pump(0.1)
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
# The from token should be the token from the last RDATA we got.
self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
self.assertEqual(room_id, row.room_id)
self.assertEqual([], row.user_ids)

View File

@ -28,15 +28,17 @@ class ParseCommandTestCase(TestCase):
self.assertIsInstance(cmd, ReplicateCommand)
def test_parse_rdata(self):
line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)
self.assertIsInstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "events")
self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863)
def test_parse_rdata_batch(self):
line = 'RDATA presence batch ["@foo:example.com", "online"]'
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line)
self.assertIsInstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "presence")
self.assertEqual(cmd.instance_name, "master")
self.assertIsNone(cmd.token)

View File

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from twisted.internet.interfaces import IProtocol
from twisted.test.proto_helpers import StringTransport
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from tests.unittest import HomeserverTestCase
class RemoteServerUpTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.factory = ReplicationStreamProtocolFactory(hs)
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
"""Create a new direct TCP replication connection
"""
proto = self.factory.buildProtocol(("127.0.0.1", 0))
transport = StringTransport()
proto.makeConnection(transport)
# We can safely ignore the commands received during connection.
self.pump()
transport.clear()
return proto, transport
def test_relay(self):
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
other connections, but not the one that sent it.
"""
proto1, transport1 = self._make_client()
# We shouldn't receive an echo.
proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
self.pump()
self.assertEqual(transport1.value(), b"")
# But we should see an echo if we connect another client
proto2, transport2 = self._make_client()
proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
self.pump()
self.assertEqual(transport1.value(), b"")
self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n")

View File

@ -360,6 +360,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
class UserRestTestCase(unittest.HomeserverTestCase):
@ -434,6 +435,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
"avatar_url": None,
}
)

View File

@ -39,7 +39,7 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
def create_room_as(self, room_creator, is_public=True, tok=None):
def create_room_as(self, room_creator=None, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"

View File

@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, 403)
def test_complete_operation_unknown_session(self):
"""
Attempting to mark an invalid session as complete should error.
"""
# Make the initial request to register. (Later on a different password
# will be used.)
request, channel = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
)
self.render(request)
# Returns a 401 as per the spec
self.assertEqual(request.code, 401)
# Grab the session
session = channel.json_body["session"]
# Assert our configured public key is being given
self.assertEqual(
channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
)
request, channel = self.make_request(
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
)
self.render(request)
self.assertEqual(request.code, 200)
# Attempt to complete an unknown session, which should return an error.
unknown_session = session + "unknown"
request, channel = self.make_request(
"POST",
"auth/m.login.recaptcha/fallback/web?session="
+ unknown_session
+ "&g-recaptcha-response=a",
)
self.render(request)
self.assertEqual(request.code, 400)

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Awesome Technologies Innovationslabor GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.types import UserID
from tests import unittest
from tests.utils import setup_test_homeserver
class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
self.user = UserID.from_string("@abcde:test")
self.displayname = "Frank"
@defer.inlineCallbacks
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
yield self.store.create_profile(self.user.localpart)
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
users, total = yield self.store.get_users_paginate(
0, 10, name="bc", guests=False
)
self.assertEquals(1, total)
self.assertEquals(self.displayname, users.pop()["displayname"])

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,3 +17,22 @@
"""
Utilities for running the unit tests
"""
from typing import Awaitable, TypeVar
TV = TypeVar("TV")
def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
"""Get the result from an Awaitable which should have completed
Asserts that the given awaitable has a result ready, and returns its value
"""
i = awaitable.__await__()
try:
next(i)
except StopIteration as e:
# awaitable returned a result
return e.value
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")

View File

@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import Collection
from tests.test_utils import get_awaitable_result
"""
Utility functions for poking events into the storage of the server under test.
"""
def inject_member_event(
hs: synapse.server.HomeServer,
room_id: str,
sender: str,
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
**kwargs
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
target = sender
content = {"membership": membership}
if extra_content:
content.update(extra_content)
return inject_event(
hs,
room_id=room_id,
type=EventTypes.Member,
sender=sender,
state_key=target,
content=content,
**kwargs
)
def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
Args:
hs: the homeserver under test
room_version: the version of the room we're inserting into.
if not specified, will be looked up
prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created
"""
test_reactor = hs.get_reactor()
if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0)
room_version = get_awaitable_result(d)
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
d = hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
test_reactor.advance(0)
event, context = get_awaitable_result(d)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event

View File

@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server
@ -55,6 +54,7 @@ from tests.server import (
render,
setup_test_homeserver,
)
from tests.test_utils import event_injection
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase):
"""
Inject a membership event into a room.
Deprecated: use event_injection.inject_room_member directly
Args:
room: Room ID to inject the event into.
user: MXID of the user to inject the membership for.
membership: The membership type.
"""
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
room_version = self.get_success(
self.hs.get_datastore().get_room_version_id(room)
)
builder = event_builder_factory.for_room_version(
KNOWN_ROOM_VERSIONS[room_version],
{
"type": EventTypes.Member,
"sender": user,
"state_key": user,
"room_id": room,
"content": {"membership": membership},
},
)
event, context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
self.get_success(
self.hs.get_storage().persistence.persist_event(event, context)
)
event_injection.inject_member_event(self.hs, room, user, membership)
class FederatingHomeserverTestCase(HomeserverTestCase):

View File

@ -74,7 +74,10 @@ def setupdb():
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
cur.execute(
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
"template=template0;" % (POSTGRES_BASE_DB,)
)
cur.close()
db_conn.close()
@ -509,8 +512,8 @@ class MockClock(object):
return t
def looping_call(self, function, interval):
self.loopers.append([function, interval / 1000.0, self.now])
def looping_call(self, function, interval, *args, **kwargs):
self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@ -540,9 +543,9 @@ class MockClock(object):
self.timers.append(t)
for looped in self.loopers:
func, interval, last = looped
func, interval, last, args, kwargs = looped
if last + interval < self.now:
func()
func(*args, **kwargs)
looped[2] = self.now
def advance_time_msec(self, ms):

View File

@ -200,10 +200,13 @@ commands = mypy \
synapse/replication \
synapse/rest \
synapse/spam_checker_api \
synapse/storage/engines \
synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \
tests/test_utils \
tests/util/test_stream_change_cache.py
# To find all folders that pass mypy you run: