Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit
e9bd4bb388
|
@ -0,0 +1 @@
|
|||
Return total number of users and profile attributes in admin users endpoint. Contributed by Awesome Technologies Innovationslabor GmbH.
|
|
@ -0,0 +1 @@
|
|||
Improve the documentation of application service configuration files.
|
|
@ -0,0 +1 @@
|
|||
Run replication streamers on workers.
|
|
@ -0,0 +1 @@
|
|||
Add some unit tests for replication.
|
|
@ -0,0 +1 @@
|
|||
Persist user interactive authentication sessions across workers and Synapse restarts.
|
|
@ -0,0 +1 @@
|
|||
Fixed backwards compatibility logic of the first value of `trusted_third_party_id_servers` being used for `account_threepid_delegates.email`, which occurs when the former, deprecated option is set and the latter is not.
|
|
@ -0,0 +1 @@
|
|||
Convert some federation handler code to async/await.
|
|
@ -0,0 +1 @@
|
|||
Fix bad error handling that would cause Synapse to crash if it's provided with a YAML configuration file that's either empty or doesn't parse into a key-value map.
|
|
@ -0,0 +1 @@
|
|||
Support SSO in the user interactive authentication workflow.
|
|
@ -0,0 +1 @@
|
|||
Fix incorrect metrics reporting for `renew_attestations` background task.
|
|
@ -0,0 +1 @@
|
|||
Add support for running replication over Redis when using workers.
|
|
@ -0,0 +1 @@
|
|||
Add documentation on monitoring workers with Prometheus.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
|
|
@ -0,0 +1 @@
|
|||
Fix collation for postgres for unit tests.
|
|
@ -0,0 +1 @@
|
|||
Clarify endpoint usage in the users admin api documentation.
|
|
@ -0,0 +1 @@
|
|||
Add an `instance_name` to `RDATA` and `POSITION` replication commands.
|
|
@ -0,0 +1 @@
|
|||
Prevent non-federating rooms from appearing in responses to federated `POST /publicRoom` requests when a filter was included.
|
|
@ -0,0 +1 @@
|
|||
Move catchup of replication streams logic to worker.
|
|
@ -33,12 +33,22 @@ with a body of:
|
|||
|
||||
including an ``access_token`` of a server admin.
|
||||
|
||||
The parameter ``displayname`` is optional and defaults to ``user_id``.
|
||||
The parameter ``threepids`` is optional.
|
||||
The parameter ``avatar_url`` is optional.
|
||||
The parameter ``admin`` is optional and defaults to 'false'.
|
||||
The parameter ``deactivated`` is optional and defaults to 'false'.
|
||||
The parameter ``password`` is optional. If provided the user's password is updated and all devices are logged out.
|
||||
The parameter ``displayname`` is optional and defaults to the value of
|
||||
``user_id``.
|
||||
|
||||
The parameter ``threepids`` is optional and allows setting the third-party IDs
|
||||
(email, msisdn) belonging to a user.
|
||||
|
||||
The parameter ``avatar_url`` is optional. Must be a [MXC
|
||||
URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris).
|
||||
|
||||
The parameter ``admin`` is optional and defaults to ``false``.
|
||||
|
||||
The parameter ``deactivated`` is optional and defaults to ``false``.
|
||||
|
||||
The parameter ``password`` is optional. If provided, the user's password is
|
||||
updated and all devices are logged out.
|
||||
|
||||
If the user already exists then optional parameters default to the current value.
|
||||
|
||||
List Accounts
|
||||
|
@ -51,16 +61,25 @@ The api is::
|
|||
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
|
||||
|
||||
including an ``access_token`` of a server admin.
|
||||
The parameters ``from`` and ``limit`` are required only for pagination.
|
||||
By default, a ``limit`` of 100 is used.
|
||||
The parameter ``user_id`` can be used to select only users with user ids that
|
||||
contain this value.
|
||||
The parameter ``guests=false`` can be used to exclude guest users,
|
||||
default is to include guest users.
|
||||
The parameter ``deactivated=true`` can be used to include deactivated users,
|
||||
default is to exclude deactivated users.
|
||||
If the endpoint does not return a ``next_token`` then there are no more users left.
|
||||
It returns a JSON body like the following:
|
||||
|
||||
The parameter ``from`` is optional but used for pagination, denoting the
|
||||
offset in the returned results. This should be treated as an opaque value and
|
||||
not explicitly set to anything other than the return value of ``next_token``
|
||||
from a previous call.
|
||||
|
||||
The parameter ``limit`` is optional but is used for pagination, denoting the
|
||||
maximum number of items to return in this call. Defaults to ``100``.
|
||||
|
||||
The parameter ``user_id`` is optional and filters to only users with user IDs
|
||||
that contain this value.
|
||||
|
||||
The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
|
||||
Defaults to ``true`` to include guest users.
|
||||
|
||||
The parameter ``deactivated`` is optional and if ``true`` will **include** deactivated users.
|
||||
Defaults to ``false`` to exclude deactivated users.
|
||||
|
||||
A JSON body is returned with the following shape:
|
||||
|
||||
.. code:: json
|
||||
|
||||
|
@ -72,19 +91,29 @@ It returns a JSON body like the following:
|
|||
"is_guest": 0,
|
||||
"admin": 0,
|
||||
"user_type": null,
|
||||
"deactivated": 0
|
||||
"deactivated": 0,
|
||||
"displayname": "<User One>",
|
||||
"avatar_url": null
|
||||
}, {
|
||||
"name": "<user_id2>",
|
||||
"password_hash": "<password_hash2>",
|
||||
"is_guest": 0,
|
||||
"admin": 1,
|
||||
"user_type": null,
|
||||
"deactivated": 0
|
||||
"deactivated": 0,
|
||||
"displayname": "<User Two>",
|
||||
"avatar_url": "<avatar_url>"
|
||||
}
|
||||
],
|
||||
"next_token": "100"
|
||||
"next_token": "100",
|
||||
"total": 200
|
||||
}
|
||||
|
||||
To paginate, check for ``next_token`` and if present, call the endpoint again
|
||||
with ``from`` set to the value of ``next_token``. This will return a new page.
|
||||
|
||||
If the endpoint does not return a ``next_token`` then there are no more users
|
||||
to paginate through.
|
||||
|
||||
Query Account
|
||||
=============
|
||||
|
|
|
@ -23,9 +23,13 @@ namespaces:
|
|||
users: # List of users we're interested in
|
||||
- exclusive: <bool>
|
||||
regex: <regex>
|
||||
group_id: <group>
|
||||
- ...
|
||||
aliases: [] # List of aliases we're interested in
|
||||
rooms: [] # List of room ids we're interested in
|
||||
```
|
||||
|
||||
`exclusive`: If enabled, only this application service is allowed to register users in its namespace(s).
|
||||
`group_id`: All users of this application service are dynamically joined to this group. This is useful for e.g user organisation or flairs.
|
||||
|
||||
See the [spec](https://matrix.org/docs/spec/application_service/unstable.html) for further details on how application services work.
|
||||
|
|
|
@ -60,6 +60,31 @@
|
|||
|
||||
1. Restart Prometheus.
|
||||
|
||||
## Monitoring workers
|
||||
|
||||
To monitor a Synapse installation using
|
||||
[workers](https://github.com/matrix-org/synapse/blob/master/docs/workers.md),
|
||||
every worker needs to be monitored independently, in addition to
|
||||
the main homeserver process. This is because workers don't send
|
||||
their metrics to the main homeserver process, but expose them
|
||||
directly (if they are configured to do so).
|
||||
|
||||
To allow collecting metrics from a worker, you need to add a
|
||||
`metrics` listener to its configuration, by adding the following
|
||||
under `worker_listeners`:
|
||||
|
||||
```yaml
|
||||
- type: metrics
|
||||
bind_address: ''
|
||||
port: 9101
|
||||
```
|
||||
|
||||
The `bind_address` and `port` parameters should be set so that
|
||||
the resulting listener can be reached by prometheus, and they
|
||||
don't clash with an existing worker.
|
||||
With this example, the worker's metrics would then be available
|
||||
on `http://127.0.0.1:9101`.
|
||||
|
||||
## Renaming of metrics & deprecation of old names in 1.2
|
||||
|
||||
Synapse 1.2 updates the Prometheus metrics to match the naming
|
||||
|
|
|
@ -1518,6 +1518,30 @@ sso:
|
|||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
# * HTML page which notifies the user that they are authenticating to confirm
|
||||
# an operation on their account during the user interactive authentication
|
||||
# process: 'sso_auth_confirm.html'.
|
||||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
#
|
||||
# * description: the operation which the user is being asked to confirm
|
||||
#
|
||||
# * HTML page shown after a successful user interactive authentication session:
|
||||
# 'sso_auth_success.html'.
|
||||
#
|
||||
# Note that this page must include the JavaScript which notifies of a successful authentication
|
||||
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
|
||||
#
|
||||
# This template has no additional variables.
|
||||
#
|
||||
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
|
||||
# attempts to login: 'sso_account_deactivated.html'.
|
||||
#
|
||||
# This template has no additional variables.
|
||||
#
|
||||
# You can see the default templates at:
|
||||
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
|
||||
#
|
||||
|
|
|
@ -15,15 +15,17 @@ example flow would be (where '>' indicates master to worker and
|
|||
|
||||
> SERVER example.com
|
||||
< REPLICATE
|
||||
> POSITION events 53
|
||||
> RDATA events 54 ["$foo1:bar.com", ...]
|
||||
> RDATA events 55 ["$foo4:bar.com", ...]
|
||||
> POSITION events master 53
|
||||
> RDATA events master 54 ["$foo1:bar.com", ...]
|
||||
> RDATA events master 55 ["$foo4:bar.com", ...]
|
||||
|
||||
The example shows the server accepting a new connection and sending its identity
|
||||
with the `SERVER` command, followed by the client server to respond with the
|
||||
position of all streams. The server then periodically sends `RDATA` commands
|
||||
which have the format `RDATA <stream_name> <token> <row>`, where the format of
|
||||
`<row>` is defined by the individual streams.
|
||||
which have the format `RDATA <stream_name> <instance_name> <token> <row>`, where
|
||||
the format of `<row>` is defined by the individual streams. The
|
||||
`<instance_name>` is the name of the Synapse process that generated the data
|
||||
(usually "master").
|
||||
|
||||
Error reporting happens by either the client or server sending an ERROR
|
||||
command, and usually the connection will be closed.
|
||||
|
@ -52,7 +54,7 @@ The basic structure of the protocol is line based, where the initial
|
|||
word of each line specifies the command. The rest of the line is parsed
|
||||
based on the command. For example, the RDATA command is defined as:
|
||||
|
||||
RDATA <stream_name> <token> <row_json>
|
||||
RDATA <stream_name> <instance_name> <token> <row_json>
|
||||
|
||||
(Note that <row_json> may contains spaces, but cannot contain
|
||||
newlines.)
|
||||
|
@ -136,11 +138,11 @@ the wire:
|
|||
< NAME synapse.app.appservice
|
||||
< PING 1490197665618
|
||||
< REPLICATE
|
||||
> POSITION events 1
|
||||
> POSITION backfill 1
|
||||
> POSITION caches 1
|
||||
> RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
|
||||
> RDATA events 14 ["$149019767112vOHxz:localhost:8823",
|
||||
> POSITION events master 1
|
||||
> POSITION backfill master 1
|
||||
> POSITION caches master 1
|
||||
> RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
|
||||
> RDATA events master 14 ["$149019767112vOHxz:localhost:8823",
|
||||
"!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
|
||||
< PING 1490197675618
|
||||
> ERROR server stopping
|
||||
|
@ -151,10 +153,10 @@ position without needing to send data with the `RDATA` command.
|
|||
|
||||
An example of a batched set of `RDATA` is:
|
||||
|
||||
> RDATA caches batch ["get_user_by_id",["@test:localhost:8823"],1490197670513]
|
||||
> RDATA caches batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513]
|
||||
> RDATA caches batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513]
|
||||
> RDATA caches 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513]
|
||||
> RDATA caches master batch ["get_user_by_id",["@test:localhost:8823"],1490197670513]
|
||||
> RDATA caches master batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513]
|
||||
> RDATA caches master batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513]
|
||||
> RDATA caches master 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513]
|
||||
|
||||
In this case the client shouldn't advance their caches token until it
|
||||
sees the the last `RDATA`.
|
||||
|
@ -178,6 +180,11 @@ client (C):
|
|||
updates, and if so then fetch them out of band. Sent in response to a
|
||||
REPLICATE command (but can happen at any time).
|
||||
|
||||
The POSITION command includes the source of the stream. Currently all streams
|
||||
are written by a single process (usually "master"). If fetching missing
|
||||
updates via HTTP API, rather than via the DB, then processes should make the
|
||||
request to the appropriate process.
|
||||
|
||||
#### ERROR (S, C)
|
||||
|
||||
There was an error
|
||||
|
@ -234,12 +241,12 @@ Each individual cache invalidation results in a row being sent down
|
|||
replication, which includes the cache name (the name of the function)
|
||||
and they key to invalidate. For example:
|
||||
|
||||
> RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251]
|
||||
> RDATA caches master 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251]
|
||||
|
||||
Alternatively, an entire cache can be invalidated by sending down a `null`
|
||||
instead of the key. For example:
|
||||
|
||||
> RDATA caches 550953772 ["get_user_by_id", null, 1550574873252]
|
||||
> RDATA caches master 550953772 ["get_user_by_id", null, 1550574873252]
|
||||
|
||||
However, there are times when a number of caches need to be invalidated
|
||||
at the same time with the same key. To reduce traffic we batch those
|
||||
|
|
|
@ -270,7 +270,7 @@ def start(hs, listeners=None):
|
|||
|
||||
# Start the tracer
|
||||
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
|
||||
hs.config
|
||||
hs
|
||||
)
|
||||
|
||||
# It is now safe to start your Synapse.
|
||||
|
@ -316,7 +316,7 @@ def setup_sentry(hs):
|
|||
scope.set_tag("matrix_server_name", hs.config.server_name)
|
||||
|
||||
app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver"
|
||||
name = hs.config.worker_name if hs.config.worker_name else "master"
|
||||
name = hs.get_instance_name()
|
||||
scope.set_tag("worker_app", app)
|
||||
scope.set_tag("worker_name", name)
|
||||
|
||||
|
|
|
@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
|
|||
MonthlyActiveUsersWorkerStore,
|
||||
)
|
||||
from synapse.storage.data_stores.main.presence import UserPresenceState
|
||||
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
|
||||
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
|
||||
from synapse.types import ReadReceipt
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
@ -439,6 +440,7 @@ class GenericWorkerSlavedStore(
|
|||
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
|
||||
# rather than going via the correct worker.
|
||||
UserDirectoryStore,
|
||||
UIAuthWorkerStore,
|
||||
SlavedDeviceInboxStore,
|
||||
SlavedDeviceStore,
|
||||
SlavedReceiptsStore,
|
||||
|
@ -960,17 +962,22 @@ def start(config_options):
|
|||
|
||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
ss = GenericWorkerServer(
|
||||
hs = GenericWorkerServer(
|
||||
config.server_name,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
)
|
||||
|
||||
setup_logging(ss, config, use_worker_options=True)
|
||||
setup_logging(hs, config, use_worker_options=True)
|
||||
|
||||
hs.setup()
|
||||
|
||||
# Ensure the replication streamer is always started in case we write to any
|
||||
# streams. Will no-op if no streams can be written to by this worker.
|
||||
hs.get_replication_streamer()
|
||||
|
||||
ss.setup()
|
||||
reactor.addSystemEventTrigger(
|
||||
"before", "startup", _base.start, ss, config.worker_listeners
|
||||
"before", "startup", _base.start, hs, config.worker_listeners
|
||||
)
|
||||
|
||||
_base.start_worker_reactor("synapse-generic-worker", config)
|
||||
|
|
|
@ -657,6 +657,12 @@ def read_config_files(config_files):
|
|||
for config_file in config_files:
|
||||
with open(config_file) as file_stream:
|
||||
yaml_config = yaml.safe_load(file_stream)
|
||||
|
||||
if not isinstance(yaml_config, dict):
|
||||
err = "File %r is empty or doesn't parse into a key-value map. IGNORING."
|
||||
print(err % (config_file,))
|
||||
continue
|
||||
|
||||
specified_config.update(yaml_config)
|
||||
|
||||
if "server_name" not in specified_config:
|
||||
|
|
|
@ -138,7 +138,7 @@ class DatabaseConfig(Config):
|
|||
database_path = config.get("database_path")
|
||||
|
||||
if multi_database_config and database_config:
|
||||
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
|
||||
raise ConfigError("Can't specify both 'database' and 'databases' in config")
|
||||
|
||||
if multi_database_config:
|
||||
if database_path:
|
||||
|
|
|
@ -108,9 +108,14 @@ class EmailConfig(Config):
|
|||
if self.trusted_third_party_id_servers:
|
||||
# XXX: It's a little confusing that account_threepid_delegate_email is modified
|
||||
# both in RegistrationConfig and here. We should factor this bit out
|
||||
self.account_threepid_delegate_email = self.trusted_third_party_id_servers[
|
||||
0
|
||||
] # type: Optional[str]
|
||||
|
||||
first_trusted_identity_server = self.trusted_third_party_id_servers[0]
|
||||
|
||||
# trusted_third_party_id_servers does not contain a scheme whereas
|
||||
# account_threepid_delegate_email is expected to. Presume https
|
||||
self.account_threepid_delegate_email = (
|
||||
"https://" + first_trusted_identity_server
|
||||
) # type: Optional[str]
|
||||
self.using_identity_server_from_trusted_list = True
|
||||
else:
|
||||
raise ConfigError(
|
||||
|
|
|
@ -113,6 +113,30 @@ class SSOConfig(Config):
|
|||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
# * HTML page which notifies the user that they are authenticating to confirm
|
||||
# an operation on their account during the user interactive authentication
|
||||
# process: 'sso_auth_confirm.html'.
|
||||
#
|
||||
# When rendering, this template is given the following variables:
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
#
|
||||
# * description: the operation which the user is being asked to confirm
|
||||
#
|
||||
# * HTML page shown after a successful user interactive authentication session:
|
||||
# 'sso_auth_success.html'.
|
||||
#
|
||||
# Note that this page must include the JavaScript which notifies of a successful authentication
|
||||
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
|
||||
#
|
||||
# This template has no additional variables.
|
||||
#
|
||||
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
|
||||
# attempts to login: 'sso_account_deactivated.html'.
|
||||
#
|
||||
# This template has no additional variables.
|
||||
#
|
||||
# You can see the default templates at:
|
||||
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
|
||||
#
|
||||
|
|
|
@ -37,15 +37,16 @@ An attestation is a signed blob of json that looks like:
|
|||
|
||||
import logging
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.async_helpers import yieldable_gather_results
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -162,19 +163,19 @@ class GroupAttestionRenewer(object):
|
|||
def _start_renew_attestations(self):
|
||||
return run_as_background_process("renew_attestations", self._renew_attestations)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestations(self):
|
||||
async def _renew_attestations(self):
|
||||
"""Called periodically to check if we need to update any of our attestations
|
||||
"""
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
rows = yield self.store.get_attestations_need_renewals(
|
||||
rows = await self.store.get_attestations_need_renewals(
|
||||
now + UPDATE_ATTESTATION_TIME_MS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestation(group_id, user_id):
|
||||
def _renew_attestation(group_user: Tuple[str, str]):
|
||||
group_id, user_id = group_user
|
||||
try:
|
||||
if not self.is_mine_id(group_id):
|
||||
destination = get_domain_from_id(group_id)
|
||||
|
@ -207,8 +208,6 @@ class GroupAttestionRenewer(object):
|
|||
"Error renewing attestation of %r in %r", user_id, group_id
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
group_id = row["group_id"]
|
||||
user_id = row["user_id"]
|
||||
|
||||
run_in_background(_renew_attestation, group_id, user_id)
|
||||
await yieldable_gather_results(
|
||||
_renew_attestation, ((row["group_id"], row["user_id"]) for row in rows)
|
||||
)
|
||||
|
|
|
@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
|||
from synapse.http.server import finish_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||
|
||||
# This is not a cache per se, but a store of all current sessions that
|
||||
# expire after N hours
|
||||
self.sessions = ExpiringCache(
|
||||
cache_name="register_sessions",
|
||||
clock=hs.get_clock(),
|
||||
expiry_ms=self.SESSION_EXPIRE_MS,
|
||||
reset_expiry_on_get=True,
|
||||
)
|
||||
|
||||
account_handler = ModuleApi(hs, self)
|
||||
self.password_providers = [
|
||||
module(config=config, account_handler=account_handler)
|
||||
|
@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self._clock = self.hs.get_clock()
|
||||
|
||||
# Expire old UI auth sessions after a period of time.
|
||||
if hs.config.worker_app is None:
|
||||
self._clock.looping_call(
|
||||
run_as_background_process,
|
||||
5 * 60 * 1000,
|
||||
"expire_old_sessions",
|
||||
self._expire_old_sessions,
|
||||
)
|
||||
|
||||
# Load the SSO HTML templates.
|
||||
|
||||
# The following template is shown to the user during a client login via SSO,
|
||||
|
@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
|
|||
if "session" in authdict:
|
||||
sid = authdict["session"]
|
||||
|
||||
# Convert the URI and method to strings.
|
||||
uri = request.uri.decode("utf-8")
|
||||
method = request.uri.decode("utf-8")
|
||||
|
||||
# If there's no session ID, create a new session.
|
||||
if not sid:
|
||||
session = self._create_session(
|
||||
clientdict, (request.uri, request.method, clientdict), description
|
||||
session = await self.store.create_ui_auth_session(
|
||||
clientdict, uri, method, description
|
||||
)
|
||||
session_id = session["id"]
|
||||
|
||||
else:
|
||||
session = self._get_session_info(sid)
|
||||
session_id = sid
|
||||
try:
|
||||
session = await self.store.get_ui_auth_session(sid)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (sid,))
|
||||
|
||||
if not clientdict:
|
||||
# This was designed to allow the client to omit the parameters
|
||||
|
@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
|
|||
# on a homeserver.
|
||||
# Revisit: Assuming the REST APIs do sensible validation, the data
|
||||
# isn't arbitrary.
|
||||
clientdict = session["clientdict"]
|
||||
clientdict = session.clientdict
|
||||
|
||||
# Ensure that the queried operation does not vary between stages of
|
||||
# the UI authentication session. This is done by generating a stable
|
||||
# comparator based on the URI, method, and body (minus the auth dict)
|
||||
# and storing it during the initial query. Subsequent queries ensure
|
||||
# that this comparator has not changed.
|
||||
comparator = (request.uri, request.method, clientdict)
|
||||
if session["ui_auth"] != comparator:
|
||||
comparator = (uri, method, clientdict)
|
||||
if (session.uri, session.method, session.clientdict) != comparator:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Requested operation has changed during the UI authentication session.",
|
||||
|
@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
if not authdict:
|
||||
raise InteractiveAuthIncompleteError(
|
||||
self._auth_dict_for_flows(flows, session_id)
|
||||
self._auth_dict_for_flows(flows, session.session_id)
|
||||
)
|
||||
|
||||
creds = session["creds"]
|
||||
|
||||
# check auth type currently being presented
|
||||
errordict = {} # type: Dict[str, Any]
|
||||
if "type" in authdict:
|
||||
|
@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
|
|||
try:
|
||||
result = await self._check_auth_dict(authdict, clientip)
|
||||
if result:
|
||||
creds[login_type] = result
|
||||
self._save_session(session)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
session.session_id, login_type, result
|
||||
)
|
||||
except LoginError as e:
|
||||
if login_type == LoginType.EMAIL_IDENTITY:
|
||||
# riot used to have a bug where it would request a new
|
||||
|
@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
|
|||
# so that the client can have another go.
|
||||
errordict = e.error_dict()
|
||||
|
||||
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
|
||||
for f in flows:
|
||||
if len(set(f) - set(creds)) == 0:
|
||||
# it's very useful to know what args are stored, but this can
|
||||
|
@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
|
|||
list(clientdict),
|
||||
)
|
||||
|
||||
return creds, clientdict, session_id
|
||||
return creds, clientdict, session.session_id
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, session_id)
|
||||
ret = self._auth_dict_for_flows(flows, session.session_id)
|
||||
ret["completed"] = list(creds)
|
||||
ret.update(errordict)
|
||||
raise InteractiveAuthIncompleteError(ret)
|
||||
|
@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
|
|||
if "session" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
sess = self._get_session_info(authdict["session"])
|
||||
creds = sess["creds"]
|
||||
|
||||
result = await self.checkers[stagetype].check_auth(authdict, clientip)
|
||||
if result:
|
||||
creds[stagetype] = result
|
||||
self._save_session(sess)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
authdict["session"], stagetype, result
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
|
|||
sid = authdict["session"]
|
||||
return sid
|
||||
|
||||
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||
async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||
"""
|
||||
Store a key-value pair into the sessions data associated with this
|
||||
request. This data is stored server-side and cannot be modified by
|
||||
|
@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
|
|||
key: The key to store the data under
|
||||
value: The data to store
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
sess["serverdict"][key] = value
|
||||
self._save_session(sess)
|
||||
try:
|
||||
await self.store.set_ui_auth_session_data(session_id, key, value)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
def get_session_data(
|
||||
async def get_session_data(
|
||||
self, session_id: str, key: str, default: Optional[Any] = None
|
||||
) -> Any:
|
||||
"""
|
||||
|
@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
|
|||
key: The key to store the data under
|
||||
default: Value to return if the key has not been set
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
return sess["serverdict"].get(key, default)
|
||||
try:
|
||||
return await self.store.get_ui_auth_session_data(session_id, key, default)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def _expire_old_sessions(self):
|
||||
"""
|
||||
Invalidate any user interactive authentication sessions that have expired.
|
||||
"""
|
||||
now = self._clock.time_msec()
|
||||
expiration_time = now - self.SESSION_EXPIRE_MS
|
||||
await self.store.delete_old_ui_auth_sessions(expiration_time)
|
||||
|
||||
async def _check_auth_dict(
|
||||
self, authdict: Dict[str, Any], clientip: str
|
||||
|
@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
|
|||
"params": params,
|
||||
}
|
||||
|
||||
def _create_session(
|
||||
self,
|
||||
clientdict: Dict[str, Any],
|
||||
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
|
||||
description: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Creates a new user interactive authentication session.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
|
||||
Each session has the following keys:
|
||||
|
||||
id:
|
||||
A unique identifier for this session. Passed back to the client
|
||||
and returned for each stage.
|
||||
clientdict:
|
||||
The dictionary from the client root level, not the 'auth' key.
|
||||
ui_auth:
|
||||
A tuple which is checked at each stage of the authentication to
|
||||
ensure that the asked for operation has not changed.
|
||||
creds:
|
||||
A map, which maps each auth-type (str) to the relevant identity
|
||||
authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||
serverdict:
|
||||
A map of data that is stored server-side and cannot be modified
|
||||
by the client.
|
||||
description:
|
||||
A string description of the operation that the current
|
||||
authentication is authorising.
|
||||
Returns:
|
||||
The newly created session.
|
||||
"""
|
||||
session_id = None
|
||||
while session_id is None or session_id in self.sessions:
|
||||
session_id = stringutils.random_string(24)
|
||||
|
||||
self.sessions[session_id] = {
|
||||
"id": session_id,
|
||||
"clientdict": clientdict,
|
||||
"ui_auth": ui_auth,
|
||||
"creds": {},
|
||||
"serverdict": {},
|
||||
"description": description,
|
||||
}
|
||||
|
||||
return self.sessions[session_id]
|
||||
|
||||
def _get_session_info(self, session_id: str) -> dict:
|
||||
"""
|
||||
Gets a session given a session ID.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
"""
|
||||
try:
|
||||
return self.sessions[session_id]
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def get_access_token_for_user_id(
|
||||
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
||||
):
|
||||
|
@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
|
|||
await self.store.user_delete_threepid(user_id, medium, address)
|
||||
return result
|
||||
|
||||
def _save_session(self, session: Dict[str, Any]) -> None:
|
||||
"""Update the last used time on the session to now and add it back to the session store."""
|
||||
# TODO: Persistent storage
|
||||
logger.debug("Saving session %s", session)
|
||||
session["last_used"] = self.hs.get_clock().time_msec()
|
||||
self.sessions[session["id"]] = session
|
||||
|
||||
async def hash(self, password: str) -> str:
|
||||
"""Computes a secure hash of password.
|
||||
|
||||
|
@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
|
|||
else:
|
||||
return False
|
||||
|
||||
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
||||
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
||||
"""
|
||||
Get the HTML for the SSO redirect confirmation page.
|
||||
|
||||
|
@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
|
|||
Returns:
|
||||
The HTML to render.
|
||||
"""
|
||||
session = self._get_session_info(session_id)
|
||||
try:
|
||||
session = await self.store.get_ui_auth_session(session_id)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
return self._sso_auth_confirm_template.render(
|
||||
description=session["description"], redirect_url=redirect_url,
|
||||
description=session.description, redirect_url=redirect_url,
|
||||
)
|
||||
|
||||
def complete_sso_ui_auth(
|
||||
async def complete_sso_ui_auth(
|
||||
self, registered_user_id: str, session_id: str, request: SynapseRequest,
|
||||
):
|
||||
"""Having figured out a mxid for this user, complete the HTTP request
|
||||
|
@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
|
|||
process.
|
||||
"""
|
||||
# Mark the stage of the authentication as successful.
|
||||
sess = self._get_session_info(session_id)
|
||||
creds = sess["creds"]
|
||||
|
||||
# Save the user who authenticated with SSO, this will be used to ensure
|
||||
# that the account be modified is also the person who logged in.
|
||||
creds[LoginType.SSO] = registered_user_id
|
||||
self._save_session(sess)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
session_id, LoginType.SSO, registered_user_id
|
||||
)
|
||||
|
||||
# Render the HTML and return.
|
||||
html_bytes = self._sso_auth_success_template.encode("utf-8")
|
||||
|
|
|
@ -206,7 +206,7 @@ class CasHandler:
|
|||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||
|
||||
if session:
|
||||
self._auth_handler.complete_sso_ui_auth(
|
||||
await self._auth_handler.complete_sso_ui_auth(
|
||||
registered_user_id, session, request,
|
||||
)
|
||||
|
||||
|
|
|
@ -343,7 +343,7 @@ class FederationHandler(BaseHandler):
|
|||
ours = await self.state_store.get_state_groups_ids(room_id, seen)
|
||||
|
||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||
state_maps = list(ours.values()) # type: list[StateMap[str]]
|
||||
state_maps = list(ours.values()) # type: List[StateMap[str]]
|
||||
|
||||
# we don't need this any more, let's delete it.
|
||||
del ours
|
||||
|
@ -1694,16 +1694,15 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_pdu(self, room_id, event_id):
|
||||
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
|
||||
event = yield self.store.get_event(
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
|
||||
state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
|
||||
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
|
||||
|
||||
if state_groups:
|
||||
_, state = list(iteritems(state_groups)).pop()
|
||||
|
@ -1714,7 +1713,7 @@ class FederationHandler(BaseHandler):
|
|||
if "replaces_state" in event.unsigned:
|
||||
prev_id = event.unsigned["replaces_state"]
|
||||
if prev_id != event.event_id:
|
||||
prev_event = yield self.store.get_event(prev_id)
|
||||
prev_event = await self.store.get_event(prev_id)
|
||||
results[(event.type, event.state_key)] = prev_event
|
||||
else:
|
||||
del results[(event.type, event.state_key)]
|
||||
|
@ -1724,15 +1723,14 @@ class FederationHandler(BaseHandler):
|
|||
else:
|
||||
return []
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_pdu(self, room_id, event_id):
|
||||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
event = yield self.store.get_event(
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
|
||||
state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||
|
||||
if state_groups:
|
||||
_, state = list(state_groups.items()).pop()
|
||||
|
@ -1751,49 +1749,50 @@ class FederationHandler(BaseHandler):
|
|||
else:
|
||||
return []
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_backfill_request(self, origin, room_id, pdu_list, limit):
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
async def on_backfill_request(
|
||||
self, origin: str, room_id: str, pdu_list: List[str], limit: int
|
||||
) -> List[EventBase]:
|
||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
# Synapse asks for 100 events per backfill request. Do not allow more.
|
||||
limit = min(limit, 100)
|
||||
|
||||
events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
|
||||
events = await self.store.get_backfill_events(room_id, pdu_list, limit)
|
||||
|
||||
events = yield filter_events_for_server(self.storage, origin, events)
|
||||
events = await filter_events_for_server(self.storage, origin, events)
|
||||
|
||||
return events
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_persisted_pdu(self, origin, event_id):
|
||||
async def get_persisted_pdu(
|
||||
self, origin: str, event_id: str
|
||||
) -> Optional[EventBase]:
|
||||
"""Get an event from the database for the given server.
|
||||
|
||||
Args:
|
||||
origin [str]: hostname of server which is requesting the event; we
|
||||
origin: hostname of server which is requesting the event; we
|
||||
will check that the server is allowed to see it.
|
||||
event_id [str]: id of the event being requested
|
||||
event_id: id of the event being requested
|
||||
|
||||
Returns:
|
||||
Deferred[EventBase|None]: None if we know nothing about the event;
|
||||
otherwise the (possibly-redacted) event.
|
||||
None if we know nothing about the event; otherwise the (possibly-redacted) event.
|
||||
|
||||
Raises:
|
||||
AuthError if the server is not currently in the room
|
||||
"""
|
||||
event = yield self.store.get_event(
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=True, allow_rejected=True
|
||||
)
|
||||
|
||||
if event:
|
||||
in_room = yield self.auth.check_host_in_room(event.room_id, origin)
|
||||
in_room = await self.auth.check_host_in_room(event.room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
events = yield filter_events_for_server(self.storage, origin, [event])
|
||||
events = await filter_events_for_server(self.storage, origin, [event])
|
||||
event = events[0]
|
||||
return event
|
||||
else:
|
||||
|
@ -2397,7 +2396,7 @@ class FederationHandler(BaseHandler):
|
|||
"""
|
||||
# exclude the state key of the new event from the current_state in the context.
|
||||
if event.is_state():
|
||||
event_key = (event.type, event.state_key)
|
||||
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
|
||||
else:
|
||||
event_key = None
|
||||
state_updates = {
|
||||
|
|
|
@ -91,7 +91,11 @@ class RoomListHandler(BaseHandler):
|
|||
logger.info("Bypassing cache as search request.")
|
||||
|
||||
return self._get_public_room_list(
|
||||
limit, since_token, search_filter, network_tuple=network_tuple
|
||||
limit,
|
||||
since_token,
|
||||
search_filter,
|
||||
network_tuple=network_tuple,
|
||||
from_federation=from_federation,
|
||||
)
|
||||
|
||||
key = (limit, since_token, network_tuple)
|
||||
|
|
|
@ -149,7 +149,7 @@ class SamlHandler:
|
|||
|
||||
# Complete the interactive auth session or the login.
|
||||
if current_session and current_session.ui_auth_session_id:
|
||||
self._auth_handler.complete_sso_ui_auth(
|
||||
await self._auth_handler.complete_sso_ui_auth(
|
||||
user_id, current_session.ui_auth_session_id, request
|
||||
)
|
||||
|
||||
|
|
|
@ -171,7 +171,7 @@ import logging
|
|||
import re
|
||||
import types
|
||||
from functools import wraps
|
||||
from typing import Dict
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
|
@ -179,6 +179,9 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.config import ConfigError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
# Helper class
|
||||
|
||||
|
||||
|
@ -297,14 +300,11 @@ def _noop_context_manager(*args, **kwargs):
|
|||
# Setup
|
||||
|
||||
|
||||
def init_tracer(config):
|
||||
def init_tracer(hs: "HomeServer"):
|
||||
"""Set the whitelists and initialise the JaegerClient tracer
|
||||
|
||||
Args:
|
||||
config (HomeserverConfig): The config used by the homeserver
|
||||
"""
|
||||
global opentracing
|
||||
if not config.opentracer_enabled:
|
||||
if not hs.config.opentracer_enabled:
|
||||
# We don't have a tracer
|
||||
opentracing = None
|
||||
return
|
||||
|
@ -315,18 +315,15 @@ def init_tracer(config):
|
|||
"installed."
|
||||
)
|
||||
|
||||
# Include the worker name
|
||||
name = config.worker_name if config.worker_name else "master"
|
||||
|
||||
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
|
||||
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
|
||||
|
||||
set_homeserver_whitelist(config.opentracer_whitelist)
|
||||
set_homeserver_whitelist(hs.config.opentracer_whitelist)
|
||||
|
||||
JaegerConfig(
|
||||
config=config.jaeger_config,
|
||||
service_name="{} {}".format(config.server_name, name),
|
||||
scope_manager=LogContextScopeManager(config),
|
||||
config=hs.config.jaeger_config,
|
||||
service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
|
||||
scope_manager=LogContextScopeManager(hs.config),
|
||||
).initialize_tracer()
|
||||
|
||||
|
||||
|
|
|
@ -220,12 +220,6 @@ class Notifier(object):
|
|||
"""
|
||||
self.replication_callbacks.append(cb)
|
||||
|
||||
def add_remote_server_up_callback(self, cb: Callable[[str], None]):
|
||||
"""Add a callback that will be called when synapse detects a server
|
||||
has been
|
||||
"""
|
||||
self.remote_server_up_callbacks.append(cb)
|
||||
|
||||
def on_new_room_event(
|
||||
self, event, room_stream_id, max_room_stream_id, extra_users=[]
|
||||
):
|
||||
|
@ -544,6 +538,3 @@ class Notifier(object):
|
|||
# circular dependencies.
|
||||
if self.federation_sender:
|
||||
self.federation_sender.wake_destination(server)
|
||||
|
||||
for cb in self.remote_server_up_callbacks:
|
||||
cb(server)
|
||||
|
|
|
@ -95,7 +95,7 @@ class RdataCommand(Command):
|
|||
|
||||
Format::
|
||||
|
||||
RDATA <stream_name> <token> <row_json>
|
||||
RDATA <stream_name> <instance_name> <token> <row_json>
|
||||
|
||||
The `<token>` may either be a numeric stream id OR "batch". The latter case
|
||||
is used to support sending multiple updates with the same stream ID. This
|
||||
|
@ -105,33 +105,40 @@ class RdataCommand(Command):
|
|||
The client should batch all incoming RDATA with a token of "batch" (per
|
||||
stream_name) until it sees an RDATA with a numeric stream ID.
|
||||
|
||||
The `<instance_name>` is the source of the new data (usually "master").
|
||||
|
||||
`<token>` of "batch" maps to the instance variable `token` being None.
|
||||
|
||||
An example of a batched series of RDATA::
|
||||
|
||||
RDATA presence batch ["@foo:example.com", "online", ...]
|
||||
RDATA presence batch ["@bar:example.com", "online", ...]
|
||||
RDATA presence 59 ["@baz:example.com", "online", ...]
|
||||
RDATA presence master batch ["@foo:example.com", "online", ...]
|
||||
RDATA presence master batch ["@bar:example.com", "online", ...]
|
||||
RDATA presence master 59 ["@baz:example.com", "online", ...]
|
||||
"""
|
||||
|
||||
NAME = "RDATA"
|
||||
|
||||
def __init__(self, stream_name, token, row):
|
||||
def __init__(self, stream_name, instance_name, token, row):
|
||||
self.stream_name = stream_name
|
||||
self.instance_name = instance_name
|
||||
self.token = token
|
||||
self.row = row
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
stream_name, token, row_json = line.split(" ", 2)
|
||||
stream_name, instance_name, token, row_json = line.split(" ", 3)
|
||||
return cls(
|
||||
stream_name, None if token == "batch" else int(token), json.loads(row_json)
|
||||
stream_name,
|
||||
instance_name,
|
||||
None if token == "batch" else int(token),
|
||||
json.loads(row_json),
|
||||
)
|
||||
|
||||
def to_line(self):
|
||||
return " ".join(
|
||||
(
|
||||
self.stream_name,
|
||||
self.instance_name,
|
||||
str(self.token) if self.token is not None else "batch",
|
||||
_json_encoder.encode(self.row),
|
||||
)
|
||||
|
@ -145,23 +152,31 @@ class PositionCommand(Command):
|
|||
"""Sent by the server to tell the client the stream postition without
|
||||
needing to send an RDATA.
|
||||
|
||||
Format::
|
||||
|
||||
POSITION <stream_name> <instance_name> <token>
|
||||
|
||||
On receipt of a POSITION command clients should check if they have missed
|
||||
any updates, and if so then fetch them out of band.
|
||||
|
||||
The `<instance_name>` is the process that sent the command and is the source
|
||||
of the stream.
|
||||
"""
|
||||
|
||||
NAME = "POSITION"
|
||||
|
||||
def __init__(self, stream_name, token):
|
||||
def __init__(self, stream_name, instance_name, token):
|
||||
self.stream_name = stream_name
|
||||
self.instance_name = instance_name
|
||||
self.token = token
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
stream_name, token = line.split(" ", 1)
|
||||
return cls(stream_name, int(token))
|
||||
stream_name, instance_name, token = line.split(" ", 2)
|
||||
return cls(stream_name, instance_name, int(token))
|
||||
|
||||
def to_line(self):
|
||||
return " ".join((self.stream_name, str(self.token)))
|
||||
return " ".join((self.stream_name, self.instance_name, str(self.token)))
|
||||
|
||||
|
||||
class ErrorCommand(_SimpleCommand):
|
||||
|
|
|
@ -79,6 +79,7 @@ class ReplicationCommandHandler:
|
|||
self._notifier = hs.get_notifier()
|
||||
self._clock = hs.get_clock()
|
||||
self._instance_id = hs.get_instance_id()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
# Set of streams that we've caught up with.
|
||||
self._streams_connected = set() # type: Set[str]
|
||||
|
@ -87,7 +88,9 @@ class ReplicationCommandHandler:
|
|||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||
} # type: Dict[str, Stream]
|
||||
|
||||
self._position_linearizer = Linearizer("replication_position")
|
||||
self._position_linearizer = Linearizer(
|
||||
"replication_position", clock=self._clock
|
||||
)
|
||||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
|
@ -115,7 +118,6 @@ class ReplicationCommandHandler:
|
|||
self._server_notices_sender = None
|
||||
if self._is_master:
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
|
@ -155,13 +157,13 @@ class ReplicationCommandHandler:
|
|||
hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
|
||||
)
|
||||
else:
|
||||
client_name = hs.config.worker_name
|
||||
client_name = hs.get_instance_name()
|
||||
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self._factory)
|
||||
|
||||
async def on_REPLICATE(self, cmd: ReplicateCommand):
|
||||
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||
# We only want to announce positions by the writer of the streams.
|
||||
# Currently this is just the master process.
|
||||
if not self._is_master:
|
||||
|
@ -169,9 +171,11 @@ class ReplicationCommandHandler:
|
|||
|
||||
for stream_name, stream in self._streams.items():
|
||||
current_token = stream.current_token()
|
||||
self.send_command(PositionCommand(stream_name, current_token))
|
||||
self.send_command(
|
||||
PositionCommand(stream_name, self._instance_name, current_token)
|
||||
)
|
||||
|
||||
async def on_USER_SYNC(self, cmd: UserSyncCommand):
|
||||
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
|
||||
user_sync_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
|
@ -179,17 +183,23 @@ class ReplicationCommandHandler:
|
|||
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
|
||||
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
|
||||
async def on_CLEAR_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
|
||||
):
|
||||
if self._is_master:
|
||||
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
|
||||
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
|
||||
async def on_FEDERATION_ACK(
|
||||
self, conn: AbstractConnection, cmd: FederationAckCommand
|
||||
):
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self._federation_sender:
|
||||
self._federation_sender.federation_ack(cmd.token)
|
||||
|
||||
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
|
||||
async def on_REMOVE_PUSHER(
|
||||
self, conn: AbstractConnection, cmd: RemovePusherCommand
|
||||
):
|
||||
remove_pusher_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
|
@ -199,7 +209,9 @@ class ReplicationCommandHandler:
|
|||
|
||||
self._notifier.on_new_replication_data()
|
||||
|
||||
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
|
||||
async def on_INVALIDATE_CACHE(
|
||||
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
|
||||
):
|
||||
invalidate_cache_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
|
@ -209,7 +221,7 @@ class ReplicationCommandHandler:
|
|||
cmd.cache_func, tuple(cmd.keys)
|
||||
)
|
||||
|
||||
async def on_USER_IP(self, cmd: UserIpCommand):
|
||||
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
|
@ -225,7 +237,11 @@ class ReplicationCommandHandler:
|
|||
if self._server_notices_sender:
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
async def on_RDATA(self, cmd: RdataCommand):
|
||||
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore RDATA that are just our own echoes
|
||||
return
|
||||
|
||||
stream_name = cmd.stream_name
|
||||
inbound_rdata_count.labels(stream_name).inc()
|
||||
|
||||
|
@ -276,7 +292,11 @@ class ReplicationCommandHandler:
|
|||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
await self._replication_data_handler.on_rdata(stream_name, token, rows)
|
||||
|
||||
async def on_POSITION(self, cmd: PositionCommand):
|
||||
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore POSITION that are just our own echoes
|
||||
return
|
||||
|
||||
stream = self._streams.get(cmd.stream_name)
|
||||
if not stream:
|
||||
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
||||
|
@ -330,12 +350,30 @@ class ReplicationCommandHandler:
|
|||
|
||||
self._streams_connected.add(cmd.stream_name)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
async def on_REMOTE_SERVER_UP(
|
||||
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
|
||||
):
|
||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||
|
||||
if self._is_master:
|
||||
self._notifier.notify_remote_server_up(cmd.data)
|
||||
self._notifier.notify_remote_server_up(cmd.data)
|
||||
|
||||
# We relay to all other connections to ensure every instance gets the
|
||||
# notification.
|
||||
#
|
||||
# When configured to use redis we'll always only have one connection and
|
||||
# so this is a no-op (all instances will have already received the same
|
||||
# REMOTE_SERVER_UP command).
|
||||
#
|
||||
# For direct TCP connections this will relay to all other connections
|
||||
# connected to us. When on master this will correctly fan out to all
|
||||
# other direct TCP clients and on workers there'll only be the one
|
||||
# connection to master.
|
||||
#
|
||||
# (The logic here should also be sound if we have a mix of Redis and
|
||||
# direct TCP connections so long as there is only one traffic route
|
||||
# between two instances, but that is not currently supported).
|
||||
self.send_command(cmd, ignore_conn=conn)
|
||||
|
||||
def new_connection(self, connection: AbstractConnection):
|
||||
"""Called when we have a new connection.
|
||||
|
@ -380,11 +418,21 @@ class ReplicationCommandHandler:
|
|||
"""
|
||||
return bool(self._connections)
|
||||
|
||||
def send_command(self, cmd: Command):
|
||||
def send_command(
|
||||
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
|
||||
):
|
||||
"""Send a command to all connected connections.
|
||||
|
||||
Args:
|
||||
cmd
|
||||
ignore_conn: If set don't send command to the given connection.
|
||||
Used when relaying commands from one connection to all others.
|
||||
"""
|
||||
if self._connections:
|
||||
for connection in self._connections:
|
||||
if connection == ignore_conn:
|
||||
continue
|
||||
|
||||
try:
|
||||
connection.send_command(cmd)
|
||||
except Exception:
|
||||
|
@ -448,7 +496,7 @@ class ReplicationCommandHandler:
|
|||
|
||||
We need to check if the client is interested in the stream or not
|
||||
"""
|
||||
self.send_command(RdataCommand(stream_name, token, data))
|
||||
self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
|
||||
|
||||
|
||||
UpdateToken = TypeVar("UpdateToken")
|
||||
|
|
|
@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
await cmd_func(self, cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
|
|
|
@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
await cmd_func(self, cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
|
|
|
@ -17,9 +17,7 @@
|
|||
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict
|
||||
|
||||
from six import itervalues
|
||||
from typing import Dict, List
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
|
@ -71,29 +69,28 @@ class ReplicationStreamer(object):
|
|||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
|
||||
self._replication_torture_level = hs.config.replication_torture_level
|
||||
|
||||
# List of streams that clients can subscribe to.
|
||||
# We only support federation stream if federation sending hase been
|
||||
# disabled on the master.
|
||||
self.streams = [
|
||||
stream(hs)
|
||||
for stream in itervalues(STREAMS_MAP)
|
||||
if stream != FederationStream or not hs.config.send_federation
|
||||
]
|
||||
# Work out list of streams that this instance is the source of.
|
||||
self.streams = [] # type: List[Stream]
|
||||
if hs.config.worker_app is None:
|
||||
for stream in STREAMS_MAP.values():
|
||||
if stream == FederationStream and hs.config.send_federation:
|
||||
# We only support federation stream if federation sending
|
||||
# hase been disabled on the master.
|
||||
continue
|
||||
|
||||
self.streams.append(stream(hs))
|
||||
|
||||
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
|
||||
|
||||
self.federation_sender = None
|
||||
if not hs.config.send_federation:
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.notifier.add_replication_callback(self.on_notifier_poke)
|
||||
# Only bother registering the notifier callback if we have streams to
|
||||
# publish.
|
||||
if self.streams:
|
||||
self.notifier.add_replication_callback(self.on_notifier_poke)
|
||||
|
||||
# Keeps track of whether we are currently checking for updates
|
||||
self.is_looping = False
|
||||
|
|
|
@ -176,10 +176,9 @@ def db_query_to_update_function(
|
|||
rows = await query_function(from_token, upto_token, limit)
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
if len(updates) == limit:
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
assert len(updates) <= limit
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
|
|
|
@ -170,22 +170,16 @@ class EventsStream(Stream):
|
|||
limited = False
|
||||
upper_limit = current_token
|
||||
|
||||
# next up is the state delta table
|
||||
|
||||
state_rows = await self._store.get_all_updated_current_state_deltas(
|
||||
# next up is the state delta table.
|
||||
(
|
||||
state_rows,
|
||||
upper_limit,
|
||||
state_rows_limited,
|
||||
) = await self._store.get_all_updated_current_state_deltas(
|
||||
from_token, upper_limit, target_row_count
|
||||
) # type: List[Tuple]
|
||||
)
|
||||
|
||||
# again, if we've hit the limit there, we'll need to limit the other sources
|
||||
assert len(state_rows) < target_row_count
|
||||
if len(state_rows) == target_row_count:
|
||||
assert state_rows[-1][0] <= upper_limit
|
||||
upper_limit = state_rows[-1][0]
|
||||
limited = True
|
||||
|
||||
# FIXME: is it a given that there is only one row per stream_id in the
|
||||
# state_deltas table (so that we can be sure that we have got all of the
|
||||
# rows for upper_limit)?
|
||||
limited = limited or state_rows_limited
|
||||
|
||||
# finally, fetch the ex-outliers rows. We assume there are few enough of these
|
||||
# not to bother with the limit.
|
||||
|
|
|
@ -94,10 +94,10 @@ class UsersRestServletV2(RestServlet):
|
|||
guests = parse_boolean(request, "guests", default=True)
|
||||
deactivated = parse_boolean(request, "deactivated", default=False)
|
||||
|
||||
users = await self.store.get_users_paginate(
|
||||
users, total = await self.store.get_users_paginate(
|
||||
start, limit, user_id, guests, deactivated
|
||||
)
|
||||
ret = {"users": users}
|
||||
ret = {"users": users, "total": total}
|
||||
if len(users) >= limit:
|
||||
ret["next_token"] = str(start + len(users))
|
||||
|
||||
|
@ -199,7 +199,7 @@ class UserRestServletV2(RestServlet):
|
|||
user_id, threepid["medium"], threepid["address"], current_time
|
||||
)
|
||||
|
||||
if "avatar_url" in body:
|
||||
if "avatar_url" in body and type(body["avatar_url"]) == str:
|
||||
await self.profile_handler.set_avatar_url(
|
||||
target_user, requester, body["avatar_url"], True
|
||||
)
|
||||
|
@ -276,7 +276,7 @@ class UserRestServletV2(RestServlet):
|
|||
user_id, threepid["medium"], threepid["address"], current_time
|
||||
)
|
||||
|
||||
if "avatar_url" in body:
|
||||
if "avatar_url" in body and type(body["avatar_url"]) == str:
|
||||
await self.profile_handler.set_avatar_url(
|
||||
user_id, requester, body["avatar_url"], True
|
||||
)
|
||||
|
|
|
@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet):
|
|||
self._cas_server_url = hs.config.cas_server_url
|
||||
self._cas_service_url = hs.config.cas_service_url
|
||||
|
||||
def on_GET(self, request, stagetype):
|
||||
async def on_GET(self, request, stagetype):
|
||||
session = parse_string(request, "session")
|
||||
if not session:
|
||||
raise SynapseError(400, "No session supplied")
|
||||
|
@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet):
|
|||
else:
|
||||
raise SynapseError(400, "Homeserver not configured for SSO.")
|
||||
|
||||
html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
|
||||
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
|
||||
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
|
|
@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet):
|
|||
# registered a user for this session, so we could just return the
|
||||
# user here. We carry on and go through the auth checks though,
|
||||
# for paranoia.
|
||||
registered_user_id = self.auth_handler.get_session_data(
|
||||
registered_user_id = await self.auth_handler.get_session_data(
|
||||
session_id, "registered_user_id", None
|
||||
)
|
||||
|
||||
|
@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
# remember that we've now registered that user account, and with
|
||||
# what user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
await self.auth_handler.set_session_data(
|
||||
session_id, "registered_user_id", registered_user_id
|
||||
)
|
||||
|
||||
|
|
|
@ -234,7 +234,8 @@ class HomeServer(object):
|
|||
self._listening_services = []
|
||||
self.start_time = None
|
||||
|
||||
self.instance_id = random_string(5)
|
||||
self._instance_id = random_string(5)
|
||||
self._instance_name = config.worker_name or "master"
|
||||
|
||||
self.clock = Clock(reactor)
|
||||
self.distributor = Distributor()
|
||||
|
@ -254,7 +255,15 @@ class HomeServer(object):
|
|||
This is used to distinguish running instances in worker-based
|
||||
deployments.
|
||||
"""
|
||||
return self.instance_id
|
||||
return self._instance_id
|
||||
|
||||
def get_instance_name(self) -> str:
|
||||
"""A unique name for this synapse process.
|
||||
|
||||
Used to identify the process over replication and in config. Does not
|
||||
change over restarts.
|
||||
"""
|
||||
return self._instance_name
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
|
|
|
@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
|
|||
import synapse.server_notices.server_notices_sender
|
||||
import synapse.state
|
||||
import synapse.storage
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
|
||||
class HomeServer(object):
|
||||
@property
|
||||
|
@ -121,3 +122,9 @@ class HomeServer(object):
|
|||
pass
|
||||
def get_instance_id(self) -> str:
|
||||
pass
|
||||
def get_instance_name(self) -> str:
|
||||
pass
|
||||
def get_event_builder_factory(self) -> EventBuilderFactory:
|
||||
pass
|
||||
def get_storage(self) -> synapse.storage.Storage:
|
||||
pass
|
||||
|
|
|
@ -66,6 +66,7 @@ from .stats import StatsStore
|
|||
from .stream import StreamStore
|
||||
from .tags import TagsStore
|
||||
from .transactions import TransactionStore
|
||||
from .ui_auth import UIAuthStore
|
||||
from .user_directory import UserDirectoryStore
|
||||
from .user_erasure_store import UserErasureStore
|
||||
|
||||
|
@ -112,6 +113,7 @@ class DataStore(
|
|||
StatsStore,
|
||||
RelationsStore,
|
||||
CacheInvalidationStore,
|
||||
UIAuthStore,
|
||||
):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
self.hs = hs
|
||||
|
@ -503,7 +505,8 @@ class DataStore(
|
|||
self, start, limit, name=None, guests=True, deactivated=False
|
||||
):
|
||||
"""Function to retrieve a paginated list of users from
|
||||
users list. This will return a json list of users.
|
||||
users list. This will return a json list of users and the
|
||||
total number of users matching the filter criteria.
|
||||
|
||||
Args:
|
||||
start (int): start number to begin the query from
|
||||
|
@ -512,35 +515,44 @@ class DataStore(
|
|||
guests (bool): whether to in include guest users
|
||||
deactivated (bool): whether to include deactivated users
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
defer.Deferred: resolves to list[dict[str, Any]], int
|
||||
"""
|
||||
name_filter = {}
|
||||
if name:
|
||||
name_filter["name"] = "%" + name + "%"
|
||||
|
||||
attr_filter = {}
|
||||
if not guests:
|
||||
attr_filter["is_guest"] = 0
|
||||
if not deactivated:
|
||||
attr_filter["deactivated"] = 0
|
||||
def get_users_paginate_txn(txn):
|
||||
filters = []
|
||||
args = []
|
||||
|
||||
return self.db.simple_select_list_paginate(
|
||||
desc="get_users_paginate",
|
||||
table="users",
|
||||
orderby="name",
|
||||
start=start,
|
||||
limit=limit,
|
||||
filters=name_filter,
|
||||
keyvalues=attr_filter,
|
||||
retcols=[
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"admin",
|
||||
"user_type",
|
||||
"deactivated",
|
||||
],
|
||||
)
|
||||
if name:
|
||||
filters.append("name LIKE ?")
|
||||
args.append("%" + name + "%")
|
||||
|
||||
if not guests:
|
||||
filters.append("is_guest = 0")
|
||||
|
||||
if not deactivated:
|
||||
filters.append("deactivated = 0")
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
|
||||
|
||||
sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
|
||||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
|
||||
args = [self.hs.config.server_name] + args + [limit, start]
|
||||
sql = """
|
||||
SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
|
||||
FROM users as u
|
||||
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
|
||||
{}
|
||||
ORDER BY u.name LIMIT ? OFFSET ?
|
||||
""".format(
|
||||
where_clause
|
||||
)
|
||||
txn.execute(sql, args)
|
||||
users = self.db.cursor_to_dict(txn)
|
||||
return users, count
|
||||
|
||||
return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
|
||||
|
||||
def search_users(self, term):
|
||||
"""Function to search users list for one or more users with
|
||||
|
|
|
@ -19,7 +19,7 @@ import itertools
|
|||
import logging
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
from constantly import NamedConstant, Names
|
||||
|
@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
|
||||
)
|
||||
|
||||
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
|
||||
async def get_all_updated_current_state_deltas(
|
||||
self, from_token: int, to_token: int, target_row_count: int
|
||||
) -> Tuple[List[Tuple], int, bool]:
|
||||
"""Fetch updates from current_state_delta_stream
|
||||
|
||||
Args:
|
||||
from_token: The previous stream token. Updates from this stream id will
|
||||
be excluded.
|
||||
|
||||
to_token: The current stream token (ie the upper limit). Updates up to this
|
||||
stream id will be included (modulo the 'limit' param)
|
||||
|
||||
target_row_count: The number of rows to try to return. If more rows are
|
||||
available, we will set 'limited' in the result. In the event of a large
|
||||
batch, we may return more rows than this.
|
||||
Returns:
|
||||
A triplet `(updates, new_last_token, limited)`, where:
|
||||
* `updates` is a list of database tuples.
|
||||
* `new_last_token` is the new position in stream.
|
||||
* `limited` is whether there are more updates to fetch.
|
||||
"""
|
||||
|
||||
def get_all_updated_current_state_deltas_txn(txn):
|
||||
sql = """
|
||||
SELECT stream_id, room_id, type, state_key, event_id
|
||||
|
@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (from_token, to_token, limit))
|
||||
txn.execute(sql, (from_token, to_token, target_row_count))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
def get_deltas_for_stream_id_txn(txn, stream_id):
|
||||
sql = """
|
||||
SELECT stream_id, room_id, type, state_key, event_id
|
||||
FROM current_state_delta_stream
|
||||
WHERE stream_id = ?
|
||||
"""
|
||||
txn.execute(sql, [stream_id])
|
||||
return txn.fetchall()
|
||||
|
||||
# we need to make sure that, for every stream id in the results, we get *all*
|
||||
# the rows with that stream id.
|
||||
|
||||
rows = await self.db.runInteraction(
|
||||
"get_all_updated_current_state_deltas",
|
||||
get_all_updated_current_state_deltas_txn,
|
||||
) # type: List[Tuple]
|
||||
|
||||
# if we've got fewer rows than the limit, we're good
|
||||
if len(rows) < target_row_count:
|
||||
return rows, to_token, False
|
||||
|
||||
# we hit the limit, so reduce the upper limit so that we exclude the stream id
|
||||
# of the last row in the result.
|
||||
assert rows[-1][0] <= to_token
|
||||
to_token = rows[-1][0] - 1
|
||||
|
||||
# search backwards through the list for the point to truncate
|
||||
for idx in range(len(rows) - 1, 0, -1):
|
||||
if rows[idx - 1][0] <= to_token:
|
||||
return rows[:idx], to_token, True
|
||||
|
||||
# bother. We didn't get a full set of changes for even a single
|
||||
# stream id. let's run the query again, without a row limit, but for
|
||||
# just one stream id.
|
||||
to_token += 1
|
||||
rows = await self.db.runInteraction(
|
||||
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
|
||||
)
|
||||
return rows, to_token, True
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ui_auth_sessions(
|
||||
session_id TEXT NOT NULL, -- The session ID passed to the client.
|
||||
creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds).
|
||||
serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse.
|
||||
clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client.
|
||||
uri TEXT NOT NULL, -- The URI the UI authentication session is using.
|
||||
method TEXT NOT NULL, -- The HTTP method the UI authentication session is using.
|
||||
-- The clientdict, uri, and method make up an tuple that must be immutable
|
||||
-- throughout the lifetime of the UI Auth session.
|
||||
description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur.
|
||||
UNIQUE (session_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
|
||||
session_id TEXT NOT NULL, -- The corresponding UI Auth session.
|
||||
stage_type TEXT NOT NULL, -- The stage type.
|
||||
result TEXT NOT NULL, -- The result of the stage verification, stored as JSON.
|
||||
UNIQUE (session_id, stage_type),
|
||||
FOREIGN KEY (session_id)
|
||||
REFERENCES ui_auth_sessions (session_id)
|
||||
);
|
|
@ -0,0 +1,279 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import attr
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.types import JsonDict
|
||||
|
||||
|
||||
@attr.s
|
||||
class UIAuthSessionData:
|
||||
session_id = attr.ib(type=str)
|
||||
# The dictionary from the client root level, not the 'auth' key.
|
||||
clientdict = attr.ib(type=JsonDict)
|
||||
# The URI and method the session was intiatied with. These are checked at
|
||||
# each stage of the authentication to ensure that the asked for operation
|
||||
# has not changed.
|
||||
uri = attr.ib(type=str)
|
||||
method = attr.ib(type=str)
|
||||
# A string description of the operation that the current authentication is
|
||||
# authorising.
|
||||
description = attr.ib(type=str)
|
||||
|
||||
|
||||
class UIAuthWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
Manage user interactive authentication sessions.
|
||||
"""
|
||||
|
||||
async def create_ui_auth_session(
|
||||
self, clientdict: JsonDict, uri: str, method: str, description: str,
|
||||
) -> UIAuthSessionData:
|
||||
"""
|
||||
Creates a new user interactive authentication session.
|
||||
|
||||
The session can be used to track the stages necessary to authenticate a
|
||||
user across multiple HTTP requests.
|
||||
|
||||
Args:
|
||||
clientdict:
|
||||
The dictionary from the client root level, not the 'auth' key.
|
||||
uri:
|
||||
The URI this session was initiated with, this is checked at each
|
||||
stage of the authentication to ensure that the asked for
|
||||
operation has not changed.
|
||||
method:
|
||||
The method this session was initiated with, this is checked at each
|
||||
stage of the authentication to ensure that the asked for
|
||||
operation has not changed.
|
||||
description:
|
||||
A string description of the operation that the current
|
||||
authentication is authorising.
|
||||
Returns:
|
||||
The newly created session.
|
||||
Raises:
|
||||
StoreError if a unique session ID cannot be generated.
|
||||
"""
|
||||
# The clientdict gets stored as JSON.
|
||||
clientdict_json = json.dumps(clientdict)
|
||||
|
||||
# autogen a session ID and try to create it. We may clash, so just
|
||||
# try a few times till one goes through, giving up eventually.
|
||||
attempts = 0
|
||||
while attempts < 5:
|
||||
session_id = stringutils.random_string(24)
|
||||
|
||||
try:
|
||||
await self.db.simple_insert(
|
||||
table="ui_auth_sessions",
|
||||
values={
|
||||
"session_id": session_id,
|
||||
"clientdict": clientdict_json,
|
||||
"uri": uri,
|
||||
"method": method,
|
||||
"description": description,
|
||||
"serverdict": "{}",
|
||||
"creation_time": self.hs.get_clock().time_msec(),
|
||||
},
|
||||
desc="create_ui_auth_session",
|
||||
)
|
||||
return UIAuthSessionData(
|
||||
session_id, clientdict, uri, method, description
|
||||
)
|
||||
except self.db.engine.module.IntegrityError:
|
||||
attempts += 1
|
||||
raise StoreError(500, "Couldn't generate a session ID.")
|
||||
|
||||
async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
|
||||
"""Retrieve a UI auth session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session.
|
||||
Returns:
|
||||
A dict containing the device information.
|
||||
Raises:
|
||||
StoreError if the session is not found.
|
||||
"""
|
||||
result = await self.db.simple_select_one(
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("clientdict", "uri", "method", "description"),
|
||||
desc="get_ui_auth_session",
|
||||
)
|
||||
|
||||
result["clientdict"] = json.loads(result["clientdict"])
|
||||
|
||||
return UIAuthSessionData(session_id, **result)
|
||||
|
||||
async def mark_ui_auth_stage_complete(
|
||||
self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
|
||||
):
|
||||
"""
|
||||
Mark a session stage as completed.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the corresponding session.
|
||||
stage_type: The completed stage type.
|
||||
result: The result of the stage verification.
|
||||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
# Add (or update) the results of the current stage to the database.
|
||||
#
|
||||
# Note that we need to allow for the same stage to complete multiple
|
||||
# times here so that registration is idempotent.
|
||||
try:
|
||||
await self.db.simple_upsert(
|
||||
table="ui_auth_sessions_credentials",
|
||||
keyvalues={"session_id": session_id, "stage_type": stage_type},
|
||||
values={"result": json.dumps(result)},
|
||||
desc="mark_ui_auth_stage_complete",
|
||||
)
|
||||
except self.db.engine.module.IntegrityError:
|
||||
raise StoreError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def get_completed_ui_auth_stages(
|
||||
self, session_id: str
|
||||
) -> Dict[str, Union[str, bool, JsonDict]]:
|
||||
"""
|
||||
Retrieve the completed stages of a UI authentication session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session.
|
||||
Returns:
|
||||
The completed stages mapped to the result of the verification of
|
||||
that auth-type.
|
||||
"""
|
||||
results = {}
|
||||
for row in await self.db.simple_select_list(
|
||||
table="ui_auth_sessions_credentials",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("stage_type", "result"),
|
||||
desc="get_completed_ui_auth_stages",
|
||||
):
|
||||
results[row["stage_type"]] = json.loads(row["result"])
|
||||
|
||||
return results
|
||||
|
||||
async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
|
||||
"""
|
||||
Store a key-value pair into the sessions data associated with this
|
||||
request. This data is stored server-side and cannot be modified by
|
||||
the client.
|
||||
|
||||
Args:
|
||||
session_id: The ID of this session as returned from check_auth
|
||||
key: The key to store the data under
|
||||
value: The data to store
|
||||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
await self.db.runInteraction(
|
||||
"set_ui_auth_session_data",
|
||||
self._set_ui_auth_session_data_txn,
|
||||
session_id,
|
||||
key,
|
||||
value,
|
||||
)
|
||||
|
||||
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
|
||||
# Get the current value.
|
||||
result = self.db.simple_select_one_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
)
|
||||
|
||||
# Update it and add it back to the database.
|
||||
serverdict = json.loads(result["serverdict"])
|
||||
serverdict[key] = value
|
||||
|
||||
self.db.simple_update_one_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
updatevalues={"serverdict": json.dumps(serverdict)},
|
||||
)
|
||||
|
||||
async def get_ui_auth_session_data(
|
||||
self, session_id: str, key: str, default: Optional[Any] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve data stored with set_session_data
|
||||
|
||||
Args:
|
||||
session_id: The ID of this session as returned from check_auth
|
||||
key: The key to store the data under
|
||||
default: Value to return if the key has not been set
|
||||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
result = await self.db.simple_select_one(
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
desc="get_ui_auth_session_data",
|
||||
)
|
||||
|
||||
serverdict = json.loads(result["serverdict"])
|
||||
|
||||
return serverdict.get(key, default)
|
||||
|
||||
|
||||
class UIAuthStore(UIAuthWorkerStore):
|
||||
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
||||
"""
|
||||
Remove sessions which were last used earlier than the expiration time.
|
||||
|
||||
Args:
|
||||
expiration_time: The latest time that is still considered valid.
|
||||
This is an epoch time in milliseconds.
|
||||
|
||||
"""
|
||||
return self.db.runInteraction(
|
||||
"delete_old_ui_auth_sessions",
|
||||
self._delete_old_ui_auth_sessions_txn,
|
||||
expiration_time,
|
||||
)
|
||||
|
||||
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
|
||||
# Get the expired sessions.
|
||||
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
|
||||
txn.execute(sql, [expiration_time])
|
||||
session_ids = [r[0] for r in txn.fetchall()]
|
||||
|
||||
# Delete the corresponding completed credentials.
|
||||
self.db.simple_delete_many_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions_credentials",
|
||||
column="session_id",
|
||||
iterable=session_ids,
|
||||
keyvalues={},
|
||||
)
|
||||
|
||||
# Finally, delete the sessions.
|
||||
self.db.simple_delete_many_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
column="session_id",
|
||||
iterable=session_ids,
|
||||
keyvalues={},
|
||||
)
|
|
@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
|
|||
prepare_database(db_conn, self, config=None)
|
||||
|
||||
db_conn.create_function("rank", 1, _rank)
|
||||
db_conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
def is_deadlock(self, error):
|
||||
return False
|
||||
|
|
|
@ -57,6 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
|||
# We now do some gut wrenching so that we have a client that is based
|
||||
# off of the slave store rather than the main store.
|
||||
self.replication_handler = ReplicationCommandHandler(self.hs)
|
||||
self.replication_handler._instance_name = "worker"
|
||||
self.replication_handler._replication_data_handler = ReplicationDataHandler(
|
||||
self.slaved_store
|
||||
)
|
||||
|
|
|
@ -13,38 +13,72 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
"""Base class for tests of the replication streams"""
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.test_handler = Mock(wraps=TestReplicationDataHandler())
|
||||
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
self.server = server_factory.buildProtocol(None)
|
||||
|
||||
repl_handler = ReplicationCommandHandler(hs)
|
||||
repl_handler.handler = self.test_handler
|
||||
# Make a new HomeServer object for the worker
|
||||
config = self.default_config()
|
||||
config["worker_app"] = "synapse.app.generic_worker"
|
||||
config["worker_replication_host"] = "testserv"
|
||||
config["worker_replication_http_port"] = "8765"
|
||||
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
self.worker_hs = self.setup_test_homeserver(
|
||||
http_client=None,
|
||||
homeserverToUse=GenericWorkerServer,
|
||||
config=config,
|
||||
reactor=self.reactor,
|
||||
)
|
||||
|
||||
# Since we use sqlite in memory databases we need to make sure the
|
||||
# databases objects are the same.
|
||||
self.worker_hs.get_datastore().db = hs.get_datastore().db
|
||||
|
||||
self.test_handler = self._build_replication_data_handler()
|
||||
self.worker_hs.replication_data_handler = self.test_handler
|
||||
|
||||
repl_handler = ReplicationCommandHandler(self.worker_hs)
|
||||
self.client = ClientReplicationStreamProtocol(
|
||||
hs, "client", "test", clock, repl_handler,
|
||||
self.worker_hs, "client", "test", clock, repl_handler,
|
||||
)
|
||||
|
||||
self._client_transport = None
|
||||
self._server_transport = None
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
return TestReplicationDataHandler(self.worker_hs.get_datastore())
|
||||
|
||||
def reconnect(self):
|
||||
if self._client_transport:
|
||||
self.client.close()
|
||||
|
@ -74,24 +108,204 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self.streamer.on_notifier_poke()
|
||||
self.pump(0.1)
|
||||
|
||||
def handle_http_replication_attempt(self) -> SynapseRequest:
|
||||
"""Asserts that a connection attempt was made to the master HS on the
|
||||
HTTP replication port, then proxies it to the master HS object to be
|
||||
handled.
|
||||
|
||||
class TestReplicationDataHandler:
|
||||
Returns:
|
||||
The request object received by master HS.
|
||||
"""
|
||||
|
||||
# We should have an outbound connection attempt.
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, "1.2.3.4")
|
||||
self.assertEqual(port, 8765)
|
||||
|
||||
# Set up client side protocol
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
request_factory = OneShotRequestFactory()
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor)
|
||||
channel.requestFactory = request_factory
|
||||
channel.site = self.site
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
channel, self.reactor, client_protocol
|
||||
)
|
||||
client_protocol.makeConnection(client_to_server_transport)
|
||||
|
||||
server_to_client_transport = FakeTransport(
|
||||
client_protocol, self.reactor, channel
|
||||
)
|
||||
channel.makeConnection(server_to_client_transport)
|
||||
|
||||
# The request will now be processed by `self.site` and the response
|
||||
# streamed back.
|
||||
self.reactor.advance(0)
|
||||
|
||||
# We tear down the connection so it doesn't get reused without our
|
||||
# knowledge.
|
||||
server_to_client_transport.loseConnection()
|
||||
client_to_server_transport.loseConnection()
|
||||
|
||||
return request_factory.request
|
||||
|
||||
def assert_request_is_get_repl_stream_updates(
|
||||
self, request: SynapseRequest, stream_name: str
|
||||
):
|
||||
"""Asserts that the given request is a HTTP replication request for
|
||||
fetching updates for given stream.
|
||||
"""
|
||||
|
||||
self.assertRegex(
|
||||
request.path,
|
||||
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
|
||||
% (stream_name.encode("ascii"),),
|
||||
)
|
||||
|
||||
self.assertEqual(request.method, b"GET")
|
||||
|
||||
|
||||
class TestReplicationDataHandler(ReplicationDataHandler):
|
||||
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||
|
||||
def __init__(self):
|
||||
self.streams = set()
|
||||
self._received_rdata_rows = []
|
||||
def __init__(self, store: BaseSlavedStore):
|
||||
super().__init__(store)
|
||||
|
||||
# streams to subscribe to: map from stream id to position
|
||||
self.stream_positions = {} # type: Dict[str, int]
|
||||
|
||||
# list of received (stream_name, token, row) tuples
|
||||
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
positions = {s: 0 for s in self.streams}
|
||||
for stream, token, _ in self._received_rdata_rows:
|
||||
if stream in self.streams:
|
||||
positions[stream] = max(token, positions.get(stream, 0))
|
||||
return positions
|
||||
return self.stream_positions
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
await super().on_rdata(stream_name, token, rows)
|
||||
for r in rows:
|
||||
self._received_rdata_rows.append((stream_name, token, r))
|
||||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
async def on_position(self, stream_name, token):
|
||||
pass
|
||||
if (
|
||||
stream_name in self.stream_positions
|
||||
and token > self.stream_positions[stream_name]
|
||||
):
|
||||
self.stream_positions[stream_name] = token
|
||||
|
||||
|
||||
@attr.s()
|
||||
class OneShotRequestFactory:
|
||||
"""A simple request factory that generates a single `SynapseRequest` and
|
||||
stores it for future use. Can only be used once.
|
||||
"""
|
||||
|
||||
request = attr.ib(default=None)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert self.request is None
|
||||
|
||||
self.request = SynapseRequest(*args, **kwargs)
|
||||
return self.request
|
||||
|
||||
|
||||
class _PushHTTPChannel(HTTPChannel):
|
||||
"""A HTTPChannel that wraps pull producers to push producers.
|
||||
|
||||
This is a hack to get around the fact that HTTPChannel transparently wraps a
|
||||
pull producer (which is what Synapse uses to reply to requests) with
|
||||
`_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
|
||||
uses the standard reactor rather than letting us use our test reactor, which
|
||||
makes it very hard to test.
|
||||
"""
|
||||
|
||||
def __init__(self, reactor: IReactorTime):
|
||||
super().__init__()
|
||||
self.reactor = reactor
|
||||
|
||||
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
# Convert pull producers to push producer.
|
||||
if not streaming:
|
||||
self._pull_to_push_producer = _PullToPushProducer(
|
||||
self.reactor, producer, self
|
||||
)
|
||||
producer = self._pull_to_push_producer
|
||||
|
||||
super().registerProducer(producer, True)
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self._pull_to_push_producer:
|
||||
# We need to manually stop the _PullToPushProducer.
|
||||
self._pull_to_push_producer.stop()
|
||||
|
||||
|
||||
class _PullToPushProducer:
|
||||
"""A push producer that wraps a pull producer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
|
||||
):
|
||||
self._clock = Clock(reactor)
|
||||
self._producer = producer
|
||||
self._consumer = consumer
|
||||
|
||||
# While running we use a looping call with a zero delay to call
|
||||
# resumeProducing on given producer.
|
||||
self._looping_call = None # type: Optional[LoopingCall]
|
||||
|
||||
# We start writing next reactor tick.
|
||||
self._start_loop()
|
||||
|
||||
def _start_loop(self):
|
||||
"""Start the looping call to
|
||||
"""
|
||||
|
||||
if not self._looping_call:
|
||||
# Start a looping call which runs every tick.
|
||||
self._looping_call = self._clock.looping_call(self._run_once, 0)
|
||||
|
||||
def stop(self):
|
||||
"""Stops calling resumeProducing.
|
||||
"""
|
||||
if self._looping_call:
|
||||
self._looping_call.stop()
|
||||
self._looping_call = None
|
||||
|
||||
def pauseProducing(self):
|
||||
"""Implements IPushProducer
|
||||
"""
|
||||
self.stop()
|
||||
|
||||
def resumeProducing(self):
|
||||
"""Implements IPushProducer
|
||||
"""
|
||||
self._start_loop()
|
||||
|
||||
def stopProducing(self):
|
||||
"""Implements IPushProducer
|
||||
"""
|
||||
self.stop()
|
||||
self._producer.stopProducing()
|
||||
|
||||
def _run_once(self):
|
||||
"""Calls resumeProducing on producer once.
|
||||
"""
|
||||
|
||||
try:
|
||||
self._producer.resumeProducing()
|
||||
except Exception:
|
||||
logger.exception("Failed to call resumeProducing")
|
||||
try:
|
||||
self._consumer.unregisterProducer()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.stopProducing()
|
||||
|
|
|
@ -0,0 +1,417 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStreamCurrentStateRow,
|
||||
EventsStreamEventRow,
|
||||
EventsStreamRow,
|
||||
)
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import login, room
|
||||
|
||||
from tests.replication.tcp.streams._base import BaseStreamTestCase
|
||||
from tests.test_utils.event_injection import inject_event, inject_member_event
|
||||
|
||||
|
||||
class EventsStreamTestCase(BaseStreamTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
super().prepare(reactor, clock, hs)
|
||||
self.user_id = self.register_user("u1", "pass")
|
||||
self.user_tok = self.login("u1", "pass")
|
||||
|
||||
self.reconnect()
|
||||
self.test_handler.stream_positions["events"] = 0
|
||||
|
||||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
||||
def test_update_function_event_row_limit(self):
|
||||
"""Test replication with many non-state events
|
||||
|
||||
Checks that all events are correctly replicated when there are lots of
|
||||
event rows to be replicated.
|
||||
"""
|
||||
# disconnect, so that we can stack up some changes
|
||||
self.disconnect()
|
||||
|
||||
# generate lots of non-state events. We inject them using inject_event
|
||||
# so that they are not send out over replication until we call self.replicate().
|
||||
events = [
|
||||
self._inject_test_event()
|
||||
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
|
||||
]
|
||||
|
||||
# also one state event
|
||||
state_event = self._inject_state_event()
|
||||
|
||||
# tell the notifier to catch up to avoid duplicate rows.
|
||||
# workaround for https://github.com/matrix-org/synapse/issues/7360
|
||||
# FIXME remove this when the above is fixed
|
||||
self.replicate()
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# received
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now reconnect to pull the updates
|
||||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
for event in events:
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, event.event_id)
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, state_event.event_id)
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
self.assertEqual(row.data.event_id, state_event.event_id)
|
||||
|
||||
self.assertEqual([], received_rows)
|
||||
|
||||
def test_update_function_huge_state_change(self):
|
||||
"""Test replication with many state events
|
||||
|
||||
Ensures that all events are correctly replicated when there are lots of
|
||||
state change rows to be replicated.
|
||||
"""
|
||||
|
||||
# we want to generate lots of state changes at a single stream ID.
|
||||
#
|
||||
# We do this by having two branches in the DAG. On one, we have a moderator
|
||||
# which that generates lots of state; on the other, we de-op the moderator,
|
||||
# thus invalidating all the state.
|
||||
|
||||
OTHER_USER = "@other_user:localhost"
|
||||
|
||||
# have the user join
|
||||
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
|
||||
|
||||
# Update existing power levels with mod at PL50
|
||||
pls = self.helper.get_state(
|
||||
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
|
||||
)
|
||||
pls["users"][OTHER_USER] = 50
|
||||
self.helper.send_state(
|
||||
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
|
||||
)
|
||||
|
||||
# this is the point in the DAG where we make a fork
|
||||
fork_point = self.get_success(
|
||||
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
|
||||
) # type: List[str]
|
||||
|
||||
events = [
|
||||
self._inject_state_event(sender=OTHER_USER)
|
||||
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
|
||||
]
|
||||
|
||||
self.replicate()
|
||||
# all those events and state changes should have landed
|
||||
self.assertGreaterEqual(
|
||||
len(self.test_handler.received_rdata_rows), 2 * len(events)
|
||||
)
|
||||
|
||||
# disconnect, so that we can stack up the changes
|
||||
self.disconnect()
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
||||
# a state event which doesn't get rolled back, to check that the state
|
||||
# before the huge update comes through ok
|
||||
state1 = self._inject_state_event()
|
||||
|
||||
# roll back all the state by de-modding the user
|
||||
prev_events = fork_point
|
||||
pls["users"][OTHER_USER] = 0
|
||||
pl_event = inject_event(
|
||||
self.hs,
|
||||
prev_event_ids=prev_events,
|
||||
type=EventTypes.PowerLevels,
|
||||
state_key="",
|
||||
sender=self.user_id,
|
||||
room_id=self.room_id,
|
||||
content=pls,
|
||||
)
|
||||
|
||||
# one more bit of state that doesn't get rolled back
|
||||
state2 = self._inject_state_event()
|
||||
|
||||
# tell the notifier to catch up to avoid duplicate rows.
|
||||
# workaround for https://github.com/matrix-org/synapse/issues/7360
|
||||
# FIXME remove this when the above is fixed
|
||||
self.replicate()
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# received
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now reconnect to pull the updates
|
||||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# now we should have received all the expected rows in the right order.
|
||||
#
|
||||
# we expect:
|
||||
#
|
||||
# - two rows for state1
|
||||
# - the PL event row, plus state rows for the PL event and each
|
||||
# of the states that got reverted.
|
||||
# - two rows for state2
|
||||
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
|
||||
# first check the first two rows, which should be state1
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, state1.event_id)
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
self.assertEqual(row.data.event_id, state1.event_id)
|
||||
|
||||
# now the last two rows, which should be state2
|
||||
stream_name, token, row = received_rows.pop(-2)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, state2.event_id)
|
||||
|
||||
stream_name, token, row = received_rows.pop(-1)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
self.assertEqual(row.data.event_id, state2.event_id)
|
||||
|
||||
# that should leave us with the rows for the PL event
|
||||
self.assertEqual(len(received_rows), len(events) + 2)
|
||||
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, pl_event.event_id)
|
||||
|
||||
# the state rows are unsorted
|
||||
state_rows = [] # type: List[EventsStreamCurrentStateRow]
|
||||
for stream_name, token, row in received_rows:
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
state_rows.append(row.data)
|
||||
|
||||
state_rows.sort(key=lambda r: r.state_key)
|
||||
|
||||
sr = state_rows.pop(0)
|
||||
self.assertEqual(sr.type, EventTypes.PowerLevels)
|
||||
self.assertEqual(sr.event_id, pl_event.event_id)
|
||||
for sr in state_rows:
|
||||
self.assertEqual(sr.type, "test_state_event")
|
||||
# "None" indicates the state has been deleted
|
||||
self.assertIsNone(sr.event_id)
|
||||
|
||||
def test_update_function_state_row_limit(self):
|
||||
"""Test replication with many state events over several stream ids.
|
||||
"""
|
||||
|
||||
# we want to generate lots of state changes, but for this test, we want to
|
||||
# spread out the state changes over a few stream IDs.
|
||||
#
|
||||
# We do this by having two branches in the DAG. On one, we have four moderators,
|
||||
# each of which that generates lots of state; on the other, we de-op the users,
|
||||
# thus invalidating all the state.
|
||||
|
||||
NUM_USERS = 4
|
||||
STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1
|
||||
|
||||
user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]
|
||||
|
||||
# have the users join
|
||||
for u in user_ids:
|
||||
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
|
||||
|
||||
# Update existing power levels with mod at PL50
|
||||
pls = self.helper.get_state(
|
||||
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
|
||||
)
|
||||
pls["users"].update({u: 50 for u in user_ids})
|
||||
self.helper.send_state(
|
||||
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
|
||||
)
|
||||
|
||||
# this is the point in the DAG where we make a fork
|
||||
fork_point = self.get_success(
|
||||
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
|
||||
) # type: List[str]
|
||||
|
||||
events = [] # type: List[EventBase]
|
||||
for user in user_ids:
|
||||
events.extend(
|
||||
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
|
||||
)
|
||||
|
||||
self.replicate()
|
||||
|
||||
# all those events and state changes should have landed
|
||||
self.assertGreaterEqual(
|
||||
len(self.test_handler.received_rdata_rows), 2 * len(events)
|
||||
)
|
||||
|
||||
# disconnect, so that we can stack up the changes
|
||||
self.disconnect()
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
||||
# now roll back all that state by de-modding the users
|
||||
prev_events = fork_point
|
||||
pl_events = []
|
||||
for u in user_ids:
|
||||
pls["users"][u] = 0
|
||||
e = inject_event(
|
||||
self.hs,
|
||||
prev_event_ids=prev_events,
|
||||
type=EventTypes.PowerLevels,
|
||||
state_key="",
|
||||
sender=self.user_id,
|
||||
room_id=self.room_id,
|
||||
content=pls,
|
||||
)
|
||||
prev_events = [e.event_id]
|
||||
pl_events.append(e)
|
||||
|
||||
# tell the notifier to catch up to avoid duplicate rows.
|
||||
# workaround for https://github.com/matrix-org/synapse/issues/7360
|
||||
# FIXME remove this when the above is fixed
|
||||
self.replicate()
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# received
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now reconnect to pull the updates
|
||||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
self.assertGreaterEqual(len(received_rows), len(events))
|
||||
for i in range(NUM_USERS):
|
||||
# for each user, we expect the PL event row, followed by state rows for
|
||||
# the PL event and each of the states that got reverted.
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "ev")
|
||||
self.assertIsInstance(row.data, EventsStreamEventRow)
|
||||
self.assertEqual(row.data.event_id, pl_events[i].event_id)
|
||||
|
||||
# the state rows are unsorted
|
||||
state_rows = [] # type: List[EventsStreamCurrentStateRow]
|
||||
for j in range(STATES_PER_USER + 1):
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
self.assertIsInstance(row, EventsStreamRow)
|
||||
self.assertEqual(row.type, "state")
|
||||
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
|
||||
state_rows.append(row.data)
|
||||
|
||||
state_rows.sort(key=lambda r: r.state_key)
|
||||
|
||||
sr = state_rows.pop(0)
|
||||
self.assertEqual(sr.type, EventTypes.PowerLevels)
|
||||
self.assertEqual(sr.event_id, pl_events[i].event_id)
|
||||
for sr in state_rows:
|
||||
self.assertEqual(sr.type, "test_state_event")
|
||||
# "None" indicates the state has been deleted
|
||||
self.assertIsNone(sr.event_id)
|
||||
|
||||
self.assertEqual([], received_rows)
|
||||
|
||||
event_count = 0
|
||||
|
||||
def _inject_test_event(
|
||||
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
|
||||
) -> EventBase:
|
||||
if sender is None:
|
||||
sender = self.user_id
|
||||
|
||||
if body is None:
|
||||
body = "event %i" % (self.event_count,)
|
||||
self.event_count += 1
|
||||
|
||||
return inject_event(
|
||||
self.hs,
|
||||
room_id=self.room_id,
|
||||
sender=sender,
|
||||
type="test_event",
|
||||
content={"body": body},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _inject_state_event(
|
||||
self,
|
||||
body: Optional[str] = None,
|
||||
state_key: Optional[str] = None,
|
||||
sender: Optional[str] = None,
|
||||
) -> EventBase:
|
||||
if sender is None:
|
||||
sender = self.user_id
|
||||
|
||||
if state_key is None:
|
||||
state_key = "state_%i" % (self.event_count,)
|
||||
self.event_count += 1
|
||||
|
||||
if body is None:
|
||||
body = "state event %s" % (state_key,)
|
||||
|
||||
return inject_event(
|
||||
self.hs,
|
||||
room_id=self.room_id,
|
||||
sender=sender,
|
||||
type="test_state_event",
|
||||
state_key=state_key,
|
||||
content={"body": body},
|
||||
)
|
|
@ -12,6 +12,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# type: ignore
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.replication.tcp.streams._base import ReceiptsStream
|
||||
|
||||
from tests.replication.tcp.streams._base import BaseStreamTestCase
|
||||
|
@ -20,11 +25,14 @@ USER_ID = "@feeling:blue"
|
|||
|
||||
|
||||
class ReceiptsStreamTestCase(BaseStreamTestCase):
|
||||
def _build_replication_data_handler(self):
|
||||
return Mock(wraps=super()._build_replication_data_handler())
|
||||
|
||||
def test_receipt(self):
|
||||
self.reconnect()
|
||||
|
||||
# make the client subscribe to the receipts stream
|
||||
self.test_handler.streams.add("receipts")
|
||||
self.test_handler.stream_positions.update({"receipts": 0})
|
||||
|
||||
# tell the master to send a new receipt
|
||||
self.get_success(
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from mock import Mock
|
||||
|
||||
from synapse.handlers.typing import RoomMember
|
||||
from synapse.replication.http import streams
|
||||
from synapse.replication.tcp.streams import TypingStream
|
||||
|
||||
from tests.replication.tcp.streams._base import BaseStreamTestCase
|
||||
|
||||
USER_ID = "@feeling:blue"
|
||||
|
||||
|
||||
class TypingStreamTestCase(BaseStreamTestCase):
|
||||
servlets = [
|
||||
streams.register_servlets,
|
||||
]
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
return Mock(wraps=super()._build_replication_data_handler())
|
||||
|
||||
def test_typing(self):
|
||||
typing = self.hs.get_typing_handler()
|
||||
|
||||
room_id = "!bar:blue"
|
||||
|
||||
self.reconnect()
|
||||
|
||||
# make the client subscribe to the typing stream
|
||||
self.test_handler.stream_positions.update({"typing": 0})
|
||||
|
||||
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
# We should now see an attempt to connect to the master
|
||||
request = self.handle_http_replication_attempt()
|
||||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
|
||||
self.assertEqual(room_id, row.room_id)
|
||||
self.assertEqual([USER_ID], row.user_ids)
|
||||
|
||||
# Now let's disconnect and insert some data.
|
||||
self.disconnect()
|
||||
|
||||
self.test_handler.on_rdata.reset_mock()
|
||||
|
||||
typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
|
||||
|
||||
self.test_handler.on_rdata.assert_not_called()
|
||||
|
||||
self.reconnect()
|
||||
self.pump(0.1)
|
||||
|
||||
# We should now see an attempt to connect to the master
|
||||
request = self.handle_http_replication_attempt()
|
||||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||
|
||||
# The from token should be the token from the last RDATA we got.
|
||||
self.assertEqual(int(request.args[b"from_token"][0]), token)
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0]
|
||||
self.assertEqual(room_id, row.room_id)
|
||||
self.assertEqual([], row.user_ids)
|
|
@ -28,15 +28,17 @@ class ParseCommandTestCase(TestCase):
|
|||
self.assertIsInstance(cmd, ReplicateCommand)
|
||||
|
||||
def test_parse_rdata(self):
|
||||
line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
|
||||
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
|
||||
cmd = parse_command_from_line(line)
|
||||
self.assertIsInstance(cmd, RdataCommand)
|
||||
self.assertEqual(cmd.stream_name, "events")
|
||||
self.assertEqual(cmd.instance_name, "master")
|
||||
self.assertEqual(cmd.token, 6287863)
|
||||
|
||||
def test_parse_rdata_batch(self):
|
||||
line = 'RDATA presence batch ["@foo:example.com", "online"]'
|
||||
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
|
||||
cmd = parse_command_from_line(line)
|
||||
self.assertIsInstance(cmd, RdataCommand)
|
||||
self.assertEqual(cmd.stream_name, "presence")
|
||||
self.assertEqual(cmd.instance_name, "master")
|
||||
self.assertIsNone(cmd.token)
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from twisted.internet.interfaces import IProtocol
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class RemoteServerUpTestCase(HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.factory = ReplicationStreamProtocolFactory(hs)
|
||||
|
||||
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
|
||||
"""Create a new direct TCP replication connection
|
||||
"""
|
||||
|
||||
proto = self.factory.buildProtocol(("127.0.0.1", 0))
|
||||
transport = StringTransport()
|
||||
proto.makeConnection(transport)
|
||||
|
||||
# We can safely ignore the commands received during connection.
|
||||
self.pump()
|
||||
transport.clear()
|
||||
|
||||
return proto, transport
|
||||
|
||||
def test_relay(self):
|
||||
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
|
||||
other connections, but not the one that sent it.
|
||||
"""
|
||||
|
||||
proto1, transport1 = self._make_client()
|
||||
|
||||
# We shouldn't receive an echo.
|
||||
proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
|
||||
self.pump()
|
||||
self.assertEqual(transport1.value(), b"")
|
||||
|
||||
# But we should see an echo if we connect another client
|
||||
proto2, transport2 = self._make_client()
|
||||
proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
|
||||
|
||||
self.pump()
|
||||
self.assertEqual(transport1.value(), b"")
|
||||
self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n")
|
|
@ -360,6 +360,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(3, len(channel.json_body["users"]))
|
||||
self.assertEqual(3, channel.json_body["total"])
|
||||
|
||||
|
||||
class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
|
@ -434,6 +435,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
"admin": True,
|
||||
"displayname": "Bob's name",
|
||||
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
|
||||
"avatar_url": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ class RestHelper(object):
|
|||
resource = attr.ib()
|
||||
auth_user_id = attr.ib()
|
||||
|
||||
def create_room_as(self, room_creator, is_public=True, tok=None):
|
||||
def create_room_as(self, room_creator=None, is_public=True, tok=None):
|
||||
temp_id = self.auth_user_id
|
||||
self.auth_user_id = room_creator
|
||||
path = "/_matrix/client/r0/createRoom"
|
||||
|
|
|
@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 403)
|
||||
|
||||
def test_complete_operation_unknown_session(self):
|
||||
"""
|
||||
Attempting to mark an invalid session as complete should error.
|
||||
"""
|
||||
|
||||
# Make the initial request to register. (Later on a different password
|
||||
# will be used.)
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{"username": "user", "type": "m.login.password", "password": "bar"},
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
# Returns a 401 as per the spec
|
||||
self.assertEqual(request.code, 401)
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
# Assert our configured public key is being given
|
||||
self.assertEqual(
|
||||
channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(request.code, 200)
|
||||
|
||||
# Attempt to complete an unknown session, which should return an error.
|
||||
unknown_session = session + "unknown"
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"auth/m.login.recaptcha/fallback/web?session="
|
||||
+ unknown_session
|
||||
+ "&g-recaptcha-response=a",
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(request.code, 400)
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 Awesome Technologies Innovationslabor GmbH
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
class DataStoreTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield setup_test_homeserver(self.addCleanup)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.user = UserID.from_string("@abcde:test")
|
||||
self.displayname = "Frank"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_users_paginate(self):
|
||||
yield self.store.register_user(self.user.to_string(), "pass")
|
||||
yield self.store.create_profile(self.user.localpart)
|
||||
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
|
||||
|
||||
users, total = yield self.store.get_users_paginate(
|
||||
0, 10, name="bc", guests=False
|
||||
)
|
||||
|
||||
self.assertEquals(1, total)
|
||||
self.assertEquals(self.displayname, users.pop()["displayname"])
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -16,3 +17,22 @@
|
|||
"""
|
||||
Utilities for running the unit tests
|
||||
"""
|
||||
from typing import Awaitable, TypeVar
|
||||
|
||||
TV = TypeVar("TV")
|
||||
|
||||
|
||||
def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
|
||||
"""Get the result from an Awaitable which should have completed
|
||||
|
||||
Asserts that the given awaitable has a result ready, and returns its value
|
||||
"""
|
||||
i = awaitable.__await__()
|
||||
try:
|
||||
next(i)
|
||||
except StopIteration as e:
|
||||
# awaitable returned a result
|
||||
return e.value
|
||||
|
||||
# if next didn't raise, the awaitable hasn't completed.
|
||||
raise Exception("awaitable has not yet completed")
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import synapse.server
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import Collection
|
||||
|
||||
from tests.test_utils import get_awaitable_result
|
||||
|
||||
|
||||
"""
|
||||
Utility functions for poking events into the storage of the server under test.
|
||||
"""
|
||||
|
||||
|
||||
def inject_member_event(
|
||||
hs: synapse.server.HomeServer,
|
||||
room_id: str,
|
||||
sender: str,
|
||||
membership: str,
|
||||
target: Optional[str] = None,
|
||||
extra_content: Optional[dict] = None,
|
||||
**kwargs
|
||||
) -> EventBase:
|
||||
"""Inject a membership event into a room."""
|
||||
if target is None:
|
||||
target = sender
|
||||
|
||||
content = {"membership": membership}
|
||||
if extra_content:
|
||||
content.update(extra_content)
|
||||
|
||||
return inject_event(
|
||||
hs,
|
||||
room_id=room_id,
|
||||
type=EventTypes.Member,
|
||||
sender=sender,
|
||||
state_key=target,
|
||||
content=content,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def inject_event(
|
||||
hs: synapse.server.HomeServer,
|
||||
room_version: Optional[str] = None,
|
||||
prev_event_ids: Optional[Collection[str]] = None,
|
||||
**kwargs
|
||||
) -> EventBase:
|
||||
"""Inject a generic event into a room
|
||||
|
||||
Args:
|
||||
hs: the homeserver under test
|
||||
room_version: the version of the room we're inserting into.
|
||||
if not specified, will be looked up
|
||||
prev_event_ids: prev_events for the event. If not specified, will be looked up
|
||||
kwargs: fields for the event to be created
|
||||
"""
|
||||
test_reactor = hs.get_reactor()
|
||||
|
||||
if room_version is None:
|
||||
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
|
||||
test_reactor.advance(0)
|
||||
room_version = get_awaitable_result(d)
|
||||
|
||||
builder = hs.get_event_builder_factory().for_room_version(
|
||||
KNOWN_ROOM_VERSIONS[room_version], kwargs
|
||||
)
|
||||
d = hs.get_event_creation_handler().create_new_client_event(
|
||||
builder, prev_event_ids=prev_event_ids
|
||||
)
|
||||
test_reactor.advance(0)
|
||||
event, context = get_awaitable_result(d)
|
||||
|
||||
d = hs.get_storage().persistence.persist_event(event, context)
|
||||
test_reactor.advance(0)
|
||||
get_awaitable_result(d)
|
||||
|
||||
return event
|
|
@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool
|
|||
from twisted.trial import unittest
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.ratelimiting import FederationRateLimitConfig
|
||||
from synapse.federation.transport import server as federation_server
|
||||
|
@ -55,6 +54,7 @@ from tests.server import (
|
|||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
from tests.test_utils import event_injection
|
||||
from tests.test_utils.logging_setup import setup_logging
|
||||
from tests.utils import default_config, setupdb
|
||||
|
||||
|
@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase):
|
|||
"""
|
||||
Inject a membership event into a room.
|
||||
|
||||
Deprecated: use event_injection.inject_room_member directly
|
||||
|
||||
Args:
|
||||
room: Room ID to inject the event into.
|
||||
user: MXID of the user to inject the membership for.
|
||||
membership: The membership type.
|
||||
"""
|
||||
event_builder_factory = self.hs.get_event_builder_factory()
|
||||
event_creation_handler = self.hs.get_event_creation_handler()
|
||||
|
||||
room_version = self.get_success(
|
||||
self.hs.get_datastore().get_room_version_id(room)
|
||||
)
|
||||
|
||||
builder = event_builder_factory.for_room_version(
|
||||
KNOWN_ROOM_VERSIONS[room_version],
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"sender": user,
|
||||
"state_key": user,
|
||||
"room_id": room,
|
||||
"content": {"membership": membership},
|
||||
},
|
||||
)
|
||||
|
||||
event, context = self.get_success(
|
||||
event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_storage().persistence.persist_event(event, context)
|
||||
)
|
||||
event_injection.inject_member_event(self.hs, room, user, membership)
|
||||
|
||||
|
||||
class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||
|
|
|
@ -74,7 +74,10 @@ def setupdb():
|
|||
db_conn.autocommit = True
|
||||
cur = db_conn.cursor()
|
||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
|
||||
cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
|
||||
cur.execute(
|
||||
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
|
||||
"template=template0;" % (POSTGRES_BASE_DB,)
|
||||
)
|
||||
cur.close()
|
||||
db_conn.close()
|
||||
|
||||
|
@ -509,8 +512,8 @@ class MockClock(object):
|
|||
|
||||
return t
|
||||
|
||||
def looping_call(self, function, interval):
|
||||
self.loopers.append([function, interval / 1000.0, self.now])
|
||||
def looping_call(self, function, interval, *args, **kwargs):
|
||||
self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
|
||||
|
||||
def cancel_call_later(self, timer, ignore_errs=False):
|
||||
if timer[2]:
|
||||
|
@ -540,9 +543,9 @@ class MockClock(object):
|
|||
self.timers.append(t)
|
||||
|
||||
for looped in self.loopers:
|
||||
func, interval, last = looped
|
||||
func, interval, last, args, kwargs = looped
|
||||
if last + interval < self.now:
|
||||
func()
|
||||
func(*args, **kwargs)
|
||||
looped[2] = self.now
|
||||
|
||||
def advance_time_msec(self, ms):
|
||||
|
|
5
tox.ini
5
tox.ini
|
@ -200,10 +200,13 @@ commands = mypy \
|
|||
synapse/replication \
|
||||
synapse/rest \
|
||||
synapse/spam_checker_api \
|
||||
synapse/storage/engines \
|
||||
synapse/storage/data_stores/main/ui_auth.py \
|
||||
synapse/storage/database.py \
|
||||
synapse/storage/engines \
|
||||
synapse/streams \
|
||||
synapse/util/caches/stream_change_cache.py \
|
||||
tests/replication/tcp/streams \
|
||||
tests/test_utils \
|
||||
tests/util/test_stream_change_cache.py
|
||||
|
||||
# To find all folders that pass mypy you run:
|
||||
|
|
Loading…
Reference in New Issue