Merge branch 'develop' into hs/super-wip-edus-down-sync
commit
6201ea56ee
Binary file not shown.
27
CHANGES.md
27
CHANGES.md
|
@ -1,3 +1,30 @@
|
|||
Synapse 1.20.1 (2020-09-24)
|
||||
===========================
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386))
|
||||
- Fix a bug introduced in v1.20.0 which caused variables to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394))
|
||||
|
||||
|
||||
Synapse 1.20.0 (2020-09-22)
|
||||
===========================
|
||||
|
||||
No significant changes since v1.20.0rc5.
|
||||
|
||||
Removal warning
|
||||
---------------
|
||||
|
||||
Historically, the [Synapse Admin
|
||||
API](https://github.com/matrix-org/synapse/tree/master/docs) has been
|
||||
accessible under the `/_matrix/client/api/v1/admin`,
|
||||
`/_matrix/client/unstable/admin`, `/_matrix/client/r0/admin` and
|
||||
`/_synapse/admin` prefixes. In a future release, we will be dropping support
|
||||
for accessing Synapse's Admin API using the `/_matrix/client/*` prefixes. This
|
||||
makes it easier for homeserver admins to lock down external access to the Admin
|
||||
API endpoints.
|
||||
|
||||
Synapse 1.20.0rc5 (2020-09-18)
|
||||
==============================
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Add an admin API `GET /_synapse/admin/v1/event_reports` to read entries of table `event_reports`. Contributed by @dklimpel.
|
|
@ -1 +1 @@
|
|||
Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this.
|
||||
Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this.
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang.
|
|
@ -0,0 +1 @@
|
|||
Don't send push notifications to expired user accounts.
|
|
@ -0,0 +1 @@
|
|||
Fixed a regression in v1.19.0 with reactivating users through the admin API.
|
|
@ -0,0 +1,2 @@
|
|||
Fix a bug where during device registration the length of the device name wasn't
|
||||
limited.
|
|
@ -0,0 +1 @@
|
|||
Factor out a `_send_dummy_event_for_room` method.
|
|
@ -0,0 +1 @@
|
|||
Improve logging of state resolution.
|
|
@ -0,0 +1 @@
|
|||
Add type annotations to `SimpleHttpClient`.
|
|
@ -0,0 +1 @@
|
|||
Include `guest_access` in the fields that are checked for null bytes when updating `room_stats_state`. Broke in v1.7.2.
|
|
@ -0,0 +1 @@
|
|||
Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers.
|
|
@ -0,0 +1 @@
|
|||
Add note to the reverse proxy settings documentation about disabling Apache's mod_security2. Contributed by Julian Fietkau (@jfietkau).
|
|
@ -0,0 +1 @@
|
|||
Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this.
|
|
@ -0,0 +1 @@
|
|||
Refactor ID generators to use `async with` syntax.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug which could cause errors in rooms with malformed membership events, on servers using sqlite.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail.
|
|
@ -0,0 +1 @@
|
|||
Add experimental support for sharding event persister.
|
|
@ -0,0 +1 @@
|
|||
Add `EventStreamPosition` type.
|
|
@ -0,0 +1 @@
|
|||
Add experimental support for sharding event persister.
|
|
@ -0,0 +1 @@
|
|||
Fix "Re-starting finished log context" warning when receiving an event we already had over federation.
|
|
@ -0,0 +1 @@
|
|||
Consolidate the SSO error template across all configuration.
|
|
@ -1,8 +1,18 @@
|
|||
matrix-synapse-py3 (1.20.0ubuntu1) UNRELEASED; urgency=medium
|
||||
matrix-synapse-py3 (1.20.1) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.20.1.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Thu, 24 Sep 2020 16:25:22 +0100
|
||||
|
||||
matrix-synapse-py3 (1.20.0) stable; urgency=medium
|
||||
|
||||
[ Synapse Packaging team ]
|
||||
* New synapse release 1.20.0.
|
||||
|
||||
[ Dexter Chua ]
|
||||
* Use Type=notify in systemd service
|
||||
|
||||
-- Dexter Chua <dec41@srcf.net> Wed, 26 Aug 2020 12:41:36 +0000
|
||||
-- Synapse Packaging team <packages@matrix.org> Tue, 22 Sep 2020 15:19:32 +0100
|
||||
|
||||
matrix-synapse-py3 (1.19.3) stable; urgency=medium
|
||||
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
Show reported events
|
||||
====================
|
||||
|
||||
This API returns information about reported events.
|
||||
|
||||
The api is::
|
||||
|
||||
GET /_synapse/admin/v1/event_reports?from=0&limit=10
|
||||
|
||||
To use it, you will need to authenticate by providing an ``access_token`` for a
|
||||
server admin: see `README.rst <README.rst>`_.
|
||||
|
||||
It returns a JSON body like the following:
|
||||
|
||||
.. code:: jsonc
|
||||
|
||||
{
|
||||
"event_reports": [
|
||||
{
|
||||
"content": {
|
||||
"reason": "foo",
|
||||
"score": -100
|
||||
},
|
||||
"event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY",
|
||||
"event_json": {
|
||||
"auth_events": [
|
||||
"$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M",
|
||||
"$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws"
|
||||
],
|
||||
"content": {
|
||||
"body": "matrix.org: This Week in Matrix",
|
||||
"format": "org.matrix.custom.html",
|
||||
"formatted_body": "<strong>matrix.org</strong>:<br><a href=\"https://matrix.org/blog/\"><strong>This Week in Matrix</strong></a>",
|
||||
"msgtype": "m.notice"
|
||||
},
|
||||
"depth": 546,
|
||||
"hashes": {
|
||||
"sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw"
|
||||
},
|
||||
"origin": "matrix.org",
|
||||
"origin_server_ts": 1592291711430,
|
||||
"prev_events": [
|
||||
"$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M"
|
||||
],
|
||||
"prev_state": [],
|
||||
"room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org",
|
||||
"sender": "@foobar:matrix.org",
|
||||
"signatures": {
|
||||
"matrix.org": {
|
||||
"ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg"
|
||||
}
|
||||
},
|
||||
"type": "m.room.message",
|
||||
"unsigned": {
|
||||
"age_ts": 1592291711430,
|
||||
}
|
||||
},
|
||||
"id": 2,
|
||||
"reason": "foo",
|
||||
"received_ts": 1570897107409,
|
||||
"room_alias": "#alias1:matrix.org",
|
||||
"room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org",
|
||||
"sender": "@foobar:matrix.org",
|
||||
"user_id": "@foo:matrix.org"
|
||||
},
|
||||
{
|
||||
"content": {
|
||||
"reason": "bar",
|
||||
"score": -100
|
||||
},
|
||||
"event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I",
|
||||
"event_json": {
|
||||
// hidden items
|
||||
// see above
|
||||
},
|
||||
"id": 3,
|
||||
"reason": "bar",
|
||||
"received_ts": 1598889612059,
|
||||
"room_alias": "#alias2:matrix.org",
|
||||
"room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org",
|
||||
"sender": "@foobar:matrix.org",
|
||||
"user_id": "@bar:matrix.org"
|
||||
}
|
||||
],
|
||||
"next_token": 2,
|
||||
"total": 4
|
||||
}
|
||||
|
||||
To paginate, check for ``next_token`` and if present, call the endpoint again
|
||||
with ``from`` set to the value of ``next_token``. This will return a new page.
|
||||
|
||||
If the endpoint does not return a ``next_token`` then there are no more
|
||||
reports to paginate through.
|
||||
|
||||
**URL parameters:**
|
||||
|
||||
- ``limit``: integer - Is optional but is used for pagination,
|
||||
denoting the maximum number of items to return in this call. Defaults to ``100``.
|
||||
- ``from``: integer - Is optional but used for pagination,
|
||||
denoting the offset in the returned results. This should be treated as an opaque value and
|
||||
not explicitly set to anything other than the return value of ``next_token`` from a previous call.
|
||||
Defaults to ``0``.
|
||||
- ``dir``: string - Direction of event report order. Whether to fetch the most recent first (``b``) or the
|
||||
oldest first (``f``). Defaults to ``b``.
|
||||
- ``user_id``: string - Is optional and filters to only return users with user IDs that contain this value.
|
||||
This is the user who reported the event and wrote the reason.
|
||||
- ``room_id``: string - Is optional and filters to only return rooms with room IDs that contain this value.
|
||||
|
||||
**Response**
|
||||
|
||||
The following fields are returned in the JSON response body:
|
||||
|
||||
- ``id``: integer - ID of event report.
|
||||
- ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent.
|
||||
- ``room_id``: string - The ID of the room in which the event being reported is located.
|
||||
- ``event_id``: string - The ID of the reported event.
|
||||
- ``user_id``: string - This is the user who reported the event and wrote the reason.
|
||||
- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank.
|
||||
- ``content``: object - Content of reported event.
|
||||
|
||||
- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank.
|
||||
- ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive".
|
||||
|
||||
- ``sender``: string - This is the ID of the user who sent the original message/event that was reported.
|
||||
- ``room_alias``: string - The alias of the room. ``null`` if the room does not have a canonical alias set.
|
||||
- ``event_json``: object - Details of the original event that was reported.
|
||||
- ``next_token``: integer - Indication for pagination. See above.
|
||||
- ``total``: integer - Total number of event reports related to the query (``user_id`` and ``room_id``).
|
||||
|
|
@ -121,6 +121,14 @@ example.com:8448 {
|
|||
|
||||
**NOTE**: ensure the `nocanon` options are included.
|
||||
|
||||
**NOTE 2**: It appears that Synapse is currently incompatible with the ModSecurity module for Apache (`mod_security2`). If you need it enabled for other services on your web server, you can disable it for Synapse's two VirtualHosts by including the following lines before each of the two `</VirtualHost>` above:
|
||||
|
||||
```
|
||||
<IfModule security2_module>
|
||||
SecRuleEngine off
|
||||
</IfModule>
|
||||
```
|
||||
|
||||
### HAProxy
|
||||
|
||||
```
|
||||
|
|
|
@ -1689,6 +1689,11 @@ oidc_config:
|
|||
#
|
||||
#skip_verification: true
|
||||
|
||||
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
|
||||
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
|
||||
#
|
||||
#allow_existing_users: true
|
||||
|
||||
# An external module can be provided here as a custom solution to mapping
|
||||
# attributes returned from a OIDC provider onto a matrix user.
|
||||
#
|
||||
|
|
|
@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = {
|
|||
"redactions": ["have_censored"],
|
||||
"room_stats_state": ["is_federatable"],
|
||||
"local_media_repository": ["safe_from_quarantine"],
|
||||
"users": ["shadow_banned"],
|
||||
}
|
||||
|
||||
|
||||
|
@ -627,6 +628,7 @@ class Porter(object):
|
|||
self.progress.set_state("Setting up sequence generators")
|
||||
await self._setup_state_group_id_seq()
|
||||
await self._setup_user_id_seq()
|
||||
await self._setup_events_stream_seqs()
|
||||
|
||||
self.progress.done()
|
||||
except Exception as e:
|
||||
|
@ -803,6 +805,29 @@ class Porter(object):
|
|||
|
||||
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
|
||||
|
||||
def _setup_events_stream_seqs(self):
|
||||
def r(txn):
|
||||
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
||||
curr_id = txn.fetchone()[0]
|
||||
if curr_id:
|
||||
next_id = curr_id + 1
|
||||
txn.execute(
|
||||
"ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,)
|
||||
)
|
||||
|
||||
txn.execute("SELECT -MIN(stream_ordering) FROM events")
|
||||
curr_id = txn.fetchone()[0]
|
||||
if curr_id:
|
||||
next_id = curr_id + 1
|
||||
txn.execute(
|
||||
"ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s",
|
||||
(next_id,),
|
||||
)
|
||||
|
||||
return self.postgres_store.db_pool.runInteraction(
|
||||
"_setup_events_stream_seqs", r
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
# The following is simply UI stuff
|
||||
|
|
16
setup.py
16
setup.py
|
@ -94,6 +94,22 @@ ALL_OPTIONAL_REQUIREMENTS = dependencies["ALL_OPTIONAL_REQUIREMENTS"]
|
|||
# Make `pip install matrix-synapse[all]` install all the optional dependencies.
|
||||
CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
|
||||
|
||||
# Developer dependencies should not get included in "all".
|
||||
#
|
||||
# We pin black so that our tests don't start failing on new releases.
|
||||
CONDITIONAL_REQUIREMENTS["lint"] = [
|
||||
"isort==5.0.3",
|
||||
"black==19.10b0",
|
||||
"flake8-comprehensions",
|
||||
"flake8",
|
||||
]
|
||||
|
||||
# Dependencies which are exclusively required by unit test code. This is
|
||||
# NOT a list of all modules that are necessary to run the unit tests.
|
||||
# Tests assume that all optional dependencies are installed.
|
||||
#
|
||||
# parameterized_class decorator was introduced in parameterized 0.7.0
|
||||
CONDITIONAL_REQUIREMENTS["test"] = ["mock>=2.0", "parameterized>=0.7.0"]
|
||||
|
||||
setup(
|
||||
name="matrix-synapse",
|
||||
|
|
|
@ -48,7 +48,7 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.20.0rc5"
|
||||
__version__ = "1.20.1"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
# We import here so that we don't have to install a bunch of deps when
|
||||
|
|
|
@ -218,11 +218,7 @@ class Auth:
|
|||
# Deny the request if the user account has expired.
|
||||
if self._account_validity.enabled and not allow_expired:
|
||||
user_id = user.to_string()
|
||||
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
||||
if (
|
||||
expiration_ts is not None
|
||||
and self.clock.time_msec() >= expiration_ts
|
||||
):
|
||||
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
|
||||
raise AuthError(
|
||||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||
)
|
||||
|
|
|
@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
urllib.parse.quote(protocol),
|
||||
)
|
||||
try:
|
||||
info = await self.get_json(uri, {})
|
||||
info = await self.get_json(uri)
|
||||
|
||||
if not _is_valid_3pe_metadata(info):
|
||||
logger.warning(
|
||||
|
|
|
@ -194,7 +194,10 @@ class Config:
|
|||
return file_stream.read()
|
||||
|
||||
def read_templates(
|
||||
self, filenames: List[str], custom_template_directory: Optional[str] = None,
|
||||
self,
|
||||
filenames: List[str],
|
||||
custom_template_directory: Optional[str] = None,
|
||||
autoescape: bool = False,
|
||||
) -> List[jinja2.Template]:
|
||||
"""Load a list of template files from disk using the given variables.
|
||||
|
||||
|
@ -210,6 +213,9 @@ class Config:
|
|||
custom_template_directory: A directory to try to look for the templates
|
||||
before using the default Synapse template directory instead.
|
||||
|
||||
autoescape: Whether to autoescape variables before inserting them into the
|
||||
template.
|
||||
|
||||
Raises:
|
||||
ConfigError: if the file's path is incorrect or otherwise cannot be read.
|
||||
|
||||
|
@ -233,7 +239,7 @@ class Config:
|
|||
search_directories.insert(0, custom_template_directory)
|
||||
|
||||
loader = jinja2.FileSystemLoader(search_directories)
|
||||
env = jinja2.Environment(loader=loader, autoescape=True)
|
||||
env = jinja2.Environment(loader=loader, autoescape=autoescape)
|
||||
|
||||
# Update the environment with our custom filters
|
||||
env.filters.update(
|
||||
|
|
|
@ -56,6 +56,7 @@ class OIDCConfig(Config):
|
|||
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
|
||||
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
|
||||
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
|
||||
self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
|
||||
|
||||
ump_config = oidc_config.get("user_mapping_provider", {})
|
||||
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
|
||||
|
@ -158,6 +159,11 @@ class OIDCConfig(Config):
|
|||
#
|
||||
#skip_verification: true
|
||||
|
||||
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
|
||||
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
|
||||
#
|
||||
#allow_existing_users: true
|
||||
|
||||
# An external module can be provided here as a custom solution to mapping
|
||||
# attributes returned from a OIDC provider onto a matrix user.
|
||||
#
|
||||
|
|
|
@ -45,7 +45,11 @@ _TLS_VERSION_MAP = {
|
|||
|
||||
class ServerContextFactory(ContextFactory):
|
||||
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
|
||||
connections."""
|
||||
connections.
|
||||
|
||||
TODO: replace this with an implementation of IOpenSSLServerConnectionCreator,
|
||||
per https://github.com/matrix-org/synapse/issues/1691
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
# TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
|
||||
|
|
|
@ -42,7 +42,6 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.logging.context import (
|
||||
PreserveLoggingContext,
|
||||
current_context,
|
||||
make_deferred_yieldable,
|
||||
preserve_fn,
|
||||
run_in_background,
|
||||
|
@ -233,8 +232,6 @@ class Keyring:
|
|||
"""
|
||||
|
||||
try:
|
||||
ctx = current_context()
|
||||
|
||||
# map from server name to a set of outstanding request ids
|
||||
server_to_request_ids = {}
|
||||
|
||||
|
@ -265,12 +262,8 @@ class Keyring:
|
|||
|
||||
# if there are no more requests for this server, we can drop the lock.
|
||||
if not server_requests:
|
||||
with PreserveLoggingContext(ctx):
|
||||
logger.debug("Releasing key lookup lock on %s", server_name)
|
||||
|
||||
# ... but not immediately, as that can cause stack explosions if
|
||||
# we get a long queue of lookups.
|
||||
self.clock.call_later(0, drop_server_lock, server_name)
|
||||
logger.debug("Releasing key lookup lock on %s", server_name)
|
||||
drop_server_lock(server_name)
|
||||
|
||||
return res
|
||||
|
||||
|
@ -335,20 +328,32 @@ class Keyring:
|
|||
)
|
||||
|
||||
# look for any requests which weren't satisfied
|
||||
with PreserveLoggingContext():
|
||||
for verify_request in remaining_requests:
|
||||
verify_request.key_ready.errback(
|
||||
SynapseError(
|
||||
401,
|
||||
"No key for %s with ids in %s (min_validity %i)"
|
||||
% (
|
||||
verify_request.server_name,
|
||||
verify_request.key_ids,
|
||||
verify_request.minimum_valid_until_ts,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
while remaining_requests:
|
||||
verify_request = remaining_requests.pop()
|
||||
rq_str = (
|
||||
"VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
|
||||
% (
|
||||
verify_request.server_name,
|
||||
verify_request.key_ids,
|
||||
verify_request.minimum_valid_until_ts,
|
||||
)
|
||||
)
|
||||
|
||||
# If we run the errback immediately, it may cancel our
|
||||
# loggingcontext while we are still in it, so instead we
|
||||
# schedule it for the next time round the reactor.
|
||||
#
|
||||
# (this also ensures that we don't get a stack overflow if we
|
||||
# has a massive queue of lookups waiting for this server).
|
||||
self.clock.call_later(
|
||||
0,
|
||||
verify_request.key_ready.errback,
|
||||
SynapseError(
|
||||
401,
|
||||
"Failed to find any key to satisfy %s" % (rq_str,),
|
||||
Codes.UNAUTHORIZED,
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
# we don't really expect to get here, because any errors should already
|
||||
# have been caught and logged. But if we do, let's log the error and make
|
||||
|
@ -410,10 +415,23 @@ class Keyring:
|
|||
# key was not valid at this point
|
||||
continue
|
||||
|
||||
with PreserveLoggingContext():
|
||||
verify_request.key_ready.callback(
|
||||
(server_name, key_id, fetch_key_result.verify_key)
|
||||
)
|
||||
# we have a valid key for this request. If we run the callback
|
||||
# immediately, it may cancel our loggingcontext while we are still in
|
||||
# it, so instead we schedule it for the next time round the reactor.
|
||||
#
|
||||
# (this also ensures that we don't get a stack overflow if we had
|
||||
# a massive queue of lookups waiting for this server).
|
||||
logger.debug(
|
||||
"Found key %s:%s for %s",
|
||||
server_name,
|
||||
key_id,
|
||||
verify_request.request_name,
|
||||
)
|
||||
self.clock.call_later(
|
||||
0,
|
||||
verify_request.key_ready.callback,
|
||||
(server_name, key_id, fetch_key_result.verify_key),
|
||||
)
|
||||
completed.append(verify_request)
|
||||
break
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
|
|||
from synapse.api import errors
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import (
|
||||
Codes,
|
||||
FederationDeniedError,
|
||||
HttpResponseException,
|
||||
RequestSendFailed,
|
||||
|
@ -265,6 +266,24 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
|
||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||
|
||||
def _check_device_name_length(self, name: str):
|
||||
"""
|
||||
Checks whether a device name is longer than the maximum allowed length.
|
||||
|
||||
Args:
|
||||
name: The name of the device.
|
||||
|
||||
Raises:
|
||||
SynapseError: if the device name is too long.
|
||||
"""
|
||||
if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Device display name is too long (max %i)"
|
||||
% (MAX_DEVICE_DISPLAY_NAME_LEN,),
|
||||
errcode=Codes.TOO_LARGE,
|
||||
)
|
||||
|
||||
async def check_device_registered(
|
||||
self, user_id, device_id, initial_device_display_name=None
|
||||
):
|
||||
|
@ -282,6 +301,9 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
Returns:
|
||||
str: device id (generated if none was supplied)
|
||||
"""
|
||||
|
||||
self._check_device_name_length(initial_device_display_name)
|
||||
|
||||
if device_id is not None:
|
||||
new_device = await self.store.store_device(
|
||||
user_id=user_id,
|
||||
|
@ -397,12 +419,8 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
|
||||
# Reject a new displayname which is too long.
|
||||
new_display_name = content.get("display_name")
|
||||
if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Device display name is too long (max %i)"
|
||||
% (MAX_DEVICE_DISPLAY_NAME_LEN,),
|
||||
)
|
||||
|
||||
self._check_device_name_length(new_display_name)
|
||||
|
||||
try:
|
||||
await self.store.update_device(
|
||||
|
|
|
@ -74,6 +74,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
|||
from synapse.types import (
|
||||
JsonDict,
|
||||
MutableStateMap,
|
||||
PersistedEventPosition,
|
||||
RoomStreamToken,
|
||||
StateMap,
|
||||
UserID,
|
||||
get_domain_from_id,
|
||||
|
@ -2956,7 +2958,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
return result["max_stream_id"]
|
||||
else:
|
||||
max_stream_id = await self.storage.persistence.persist_events(
|
||||
max_stream_token = await self.storage.persistence.persist_events(
|
||||
event_and_contexts, backfilled=backfilled
|
||||
)
|
||||
|
||||
|
@ -2967,12 +2969,12 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
if not backfilled: # Never notify for backfilled events
|
||||
for event, _ in event_and_contexts:
|
||||
await self._notify_persisted_event(event, max_stream_id)
|
||||
await self._notify_persisted_event(event, max_stream_token)
|
||||
|
||||
return max_stream_id
|
||||
return max_stream_token.stream
|
||||
|
||||
async def _notify_persisted_event(
|
||||
self, event: EventBase, max_stream_id: int
|
||||
self, event: EventBase, max_stream_token: RoomStreamToken
|
||||
) -> None:
|
||||
"""Checks to see if notifier/pushers should be notified about the
|
||||
event or not.
|
||||
|
@ -2998,9 +3000,11 @@ class FederationHandler(BaseHandler):
|
|||
elif event.internal_metadata.is_outlier():
|
||||
return
|
||||
|
||||
event_stream_id = event.internal_metadata.stream_ordering
|
||||
event_pos = PersistedEventPosition(
|
||||
self._instance_name, event.internal_metadata.stream_ordering
|
||||
)
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||
event, event_pos, max_stream_token, extra_users=extra_users
|
||||
)
|
||||
|
||||
async def _clean_room_for_join(self, room_id: str) -> None:
|
||||
|
|
|
@ -1138,7 +1138,7 @@ class EventCreationHandler:
|
|||
if prev_state_ids:
|
||||
raise AuthError(403, "Changing the room create event is forbidden")
|
||||
|
||||
event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
|
||||
event_pos, max_stream_token = await self.storage.persistence.persist_event(
|
||||
event, context=context
|
||||
)
|
||||
|
||||
|
@ -1149,7 +1149,7 @@ class EventCreationHandler:
|
|||
def _notify():
|
||||
try:
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||
event, event_pos, max_stream_token, extra_users=extra_users
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error notifying about new room event")
|
||||
|
@ -1161,7 +1161,7 @@ class EventCreationHandler:
|
|||
# matters as sometimes presence code can take a while.
|
||||
run_in_background(self._bump_active_time, requester.user)
|
||||
|
||||
return event_stream_id
|
||||
return event_pos.stream
|
||||
|
||||
async def _bump_active_time(self, user: UserID) -> None:
|
||||
try:
|
||||
|
@ -1182,54 +1182,7 @@ class EventCreationHandler:
|
|||
)
|
||||
|
||||
for room_id in room_ids:
|
||||
# For each room we need to find a joined member we can use to send
|
||||
# the dummy event with.
|
||||
|
||||
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
|
||||
|
||||
members = await self.state.get_current_users_in_room(
|
||||
room_id, latest_event_ids=latest_event_ids
|
||||
)
|
||||
dummy_event_sent = False
|
||||
for user_id in members:
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
continue
|
||||
requester = create_requester(user_id)
|
||||
try:
|
||||
event, context = await self.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": "org.matrix.dummy_event",
|
||||
"content": {},
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
},
|
||||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
|
||||
event.internal_metadata.proactively_send = False
|
||||
|
||||
# Since this is a dummy-event it is OK if it is sent by a
|
||||
# shadow-banned user.
|
||||
await self.send_nonmember_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=False,
|
||||
ignore_shadow_ban=True,
|
||||
)
|
||||
dummy_event_sent = True
|
||||
break
|
||||
except ConsentNotGivenError:
|
||||
logger.info(
|
||||
"Failed to send dummy event into room %s for user %s due to "
|
||||
"lack of consent. Will try another user" % (room_id, user_id)
|
||||
)
|
||||
except AuthError:
|
||||
logger.info(
|
||||
"Failed to send dummy event into room %s for user %s due to "
|
||||
"lack of power. Will try another user" % (room_id, user_id)
|
||||
)
|
||||
dummy_event_sent = await self._send_dummy_event_for_room(room_id)
|
||||
|
||||
if not dummy_event_sent:
|
||||
# Did not find a valid user in the room, so remove from future attempts
|
||||
|
@ -1242,6 +1195,59 @@ class EventCreationHandler:
|
|||
now = self.clock.time_msec()
|
||||
self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
|
||||
|
||||
async def _send_dummy_event_for_room(self, room_id: str) -> bool:
|
||||
"""Attempt to send a dummy event for the given room.
|
||||
|
||||
Args:
|
||||
room_id: room to try to send an event from
|
||||
|
||||
Returns:
|
||||
True if a dummy event was successfully sent. False if no user was able
|
||||
to send an event.
|
||||
"""
|
||||
|
||||
# For each room we need to find a joined member we can use to send
|
||||
# the dummy event with.
|
||||
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
|
||||
members = await self.state.get_current_users_in_room(
|
||||
room_id, latest_event_ids=latest_event_ids
|
||||
)
|
||||
for user_id in members:
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
continue
|
||||
requester = create_requester(user_id)
|
||||
try:
|
||||
event, context = await self.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": "org.matrix.dummy_event",
|
||||
"content": {},
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
},
|
||||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
|
||||
event.internal_metadata.proactively_send = False
|
||||
|
||||
# Since this is a dummy-event it is OK if it is sent by a
|
||||
# shadow-banned user.
|
||||
await self.send_nonmember_event(
|
||||
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
|
||||
)
|
||||
return True
|
||||
except ConsentNotGivenError:
|
||||
logger.info(
|
||||
"Failed to send dummy event into room %s for user %s due to "
|
||||
"lack of consent. Will try another user" % (room_id, user_id)
|
||||
)
|
||||
except AuthError:
|
||||
logger.info(
|
||||
"Failed to send dummy event into room %s for user %s due to "
|
||||
"lack of power. Will try another user" % (room_id, user_id)
|
||||
)
|
||||
return False
|
||||
|
||||
def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
|
||||
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
|
||||
to_expire = set()
|
||||
|
|
|
@ -114,6 +114,7 @@ class OidcHandler:
|
|||
hs.config.oidc_user_mapping_provider_config
|
||||
) # type: OidcMappingProvider
|
||||
self._skip_verification = hs.config.oidc_skip_verification # type: bool
|
||||
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
|
||||
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
|
@ -849,7 +850,8 @@ class OidcHandler:
|
|||
If we don't find the user that way, we should register the user,
|
||||
mapping the localpart and the display name from the UserInfo.
|
||||
|
||||
If a user already exists with the mxid we've mapped, raise an exception.
|
||||
If a user already exists with the mxid we've mapped and allow_existing_users
|
||||
is disabled, raise an exception.
|
||||
|
||||
Args:
|
||||
userinfo: an object representing the user
|
||||
|
@ -905,21 +907,31 @@ class OidcHandler:
|
|||
|
||||
localpart = map_username_to_mxid_localpart(attributes["localpart"])
|
||||
|
||||
user_id = UserID(localpart, self._hostname)
|
||||
if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
|
||||
# This mxid is taken
|
||||
raise MappingException(
|
||||
"mxid '{}' is already taken".format(user_id.to_string())
|
||||
user_id = UserID(localpart, self._hostname).to_string()
|
||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
if self._allow_existing_users:
|
||||
if len(users) == 1:
|
||||
registered_user_id = next(iter(users))
|
||||
elif user_id in users:
|
||||
registered_user_id = user_id
|
||||
else:
|
||||
raise MappingException(
|
||||
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
||||
user_id, list(users.keys())
|
||||
)
|
||||
)
|
||||
else:
|
||||
# This mxid is taken
|
||||
raise MappingException("mxid '{}' is already taken".format(user_id))
|
||||
else:
|
||||
# It's the first time this user is logging in and the mapped mxid was
|
||||
# not taken, register the user
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart,
|
||||
default_display_name=attributes["display_name"],
|
||||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
|
||||
# It's the first time this user is logging in and the mapped mxid was
|
||||
# not taken, register the user
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart,
|
||||
default_display_name=attributes["display_name"],
|
||||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
|
||||
await self._datastore.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id,
|
||||
)
|
||||
|
|
|
@ -967,7 +967,7 @@ class SyncHandler:
|
|||
raise NotImplementedError()
|
||||
else:
|
||||
joined_room_ids = await self.get_rooms_for_user_at(
|
||||
user_id, now_token.room_stream_id
|
||||
user_id, now_token.room_key
|
||||
)
|
||||
sync_result_builder = SyncResultBuilder(
|
||||
sync_config,
|
||||
|
@ -1916,7 +1916,7 @@ class SyncHandler:
|
|||
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
|
||||
|
||||
async def get_rooms_for_user_at(
|
||||
self, user_id: str, stream_ordering: int
|
||||
self, user_id: str, room_key: RoomStreamToken
|
||||
) -> FrozenSet[str]:
|
||||
"""Get set of joined rooms for a user at the given stream ordering.
|
||||
|
||||
|
@ -1942,15 +1942,15 @@ class SyncHandler:
|
|||
# If the membership's stream ordering is after the given stream
|
||||
# ordering, we need to go and work out if the user was in the room
|
||||
# before.
|
||||
for room_id, membership_stream_ordering in joined_rooms:
|
||||
if membership_stream_ordering <= stream_ordering:
|
||||
for room_id, event_pos in joined_rooms:
|
||||
if not event_pos.persisted_after(room_key):
|
||||
joined_room_ids.add(room_id)
|
||||
continue
|
||||
|
||||
logger.info("User joined room after current token: %s", room_id)
|
||||
|
||||
extrems = await self.store.get_forward_extremeties_for_room(
|
||||
room_id, stream_ordering
|
||||
room_id, event_pos.stream
|
||||
)
|
||||
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
|
||||
if user_id in users_in_room:
|
||||
|
|
|
@ -17,6 +17,18 @@
|
|||
import logging
|
||||
import urllib
|
||||
from io import BytesIO
|
||||
from typing import (
|
||||
Any,
|
||||
BinaryIO,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import treq
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone
|
|||
from twisted.web.client import Agent, HTTPConnectionPool, readBody
|
||||
from twisted.web.http import PotentialDataLoss
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IResponse
|
||||
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.http import (
|
||||
|
@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
|
|||
"synapse_http_client_responses", "", ["method", "code"]
|
||||
)
|
||||
|
||||
# the type of the headers list, to be passed to the t.w.h.Headers.
|
||||
# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
|
||||
# we simplify.
|
||||
RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
|
||||
|
||||
# the value actually has to be a List, but List is invariant so we can't specify that
|
||||
# the entries can either be Lists or bytes.
|
||||
RawHeaderValue = Sequence[Union[str, bytes]]
|
||||
|
||||
# the type of the query params, to be passed into `urlencode`
|
||||
QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
|
||||
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
|
||||
|
||||
|
||||
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
|
||||
"""
|
||||
|
@ -285,13 +311,26 @@ class SimpleHttpClient:
|
|||
ip_blacklist=self._ip_blacklist,
|
||||
)
|
||||
|
||||
async def request(self, method, uri, data=None, headers=None):
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
uri: str,
|
||||
data: Optional[bytes] = None,
|
||||
headers: Optional[Headers] = None,
|
||||
) -> IResponse:
|
||||
"""
|
||||
Args:
|
||||
method (str): HTTP method to use.
|
||||
uri (str): URI to query.
|
||||
data (bytes): Data to send in the request body, if applicable.
|
||||
headers (t.w.http_headers.Headers): Request headers.
|
||||
method: HTTP method to use.
|
||||
uri: URI to query.
|
||||
data: Data to send in the request body, if applicable.
|
||||
headers: Request headers.
|
||||
|
||||
Returns:
|
||||
Response object, once the headers have been read.
|
||||
|
||||
Raises:
|
||||
RequestTimedOutError if the request times out before the headers are read
|
||||
|
||||
"""
|
||||
# A small wrapper around self.agent.request() so we can easily attach
|
||||
# counters to it
|
||||
|
@ -324,6 +363,8 @@ class SimpleHttpClient:
|
|||
headers=headers,
|
||||
**self._extra_treq_args
|
||||
)
|
||||
# we use our own timeout mechanism rather than treq's as a workaround
|
||||
# for https://twistedmatrix.com/trac/ticket/9534.
|
||||
request_deferred = timeout_deferred(
|
||||
request_deferred,
|
||||
60,
|
||||
|
@ -353,18 +394,26 @@ class SimpleHttpClient:
|
|||
set_tag("error_reason", e.args[0])
|
||||
raise
|
||||
|
||||
async def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
||||
async def post_urlencoded_get_json(
|
||||
self,
|
||||
uri: str,
|
||||
args: Mapping[str, Union[str, List[str]]] = {},
|
||||
headers: Optional[RawHeaders] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Args:
|
||||
uri (str):
|
||||
args (dict[str, str|List[str]]): query params
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
uri: uri to query
|
||||
args: parameters to be url-encoded in the body
|
||||
headers: a map from header name to a list of values for that header
|
||||
|
||||
Returns:
|
||||
object: parsed json
|
||||
parsed json
|
||||
|
||||
Raises:
|
||||
RequestTimedOutException: if there is a timeout before the response headers
|
||||
are received. Note there is currently no timeout on reading the response
|
||||
body.
|
||||
|
||||
HttpResponseException: On a non-2xx HTTP response.
|
||||
|
||||
ValueError: if the response was not JSON
|
||||
|
@ -398,19 +447,24 @@ class SimpleHttpClient:
|
|||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
async def post_json_get_json(self, uri, post_json, headers=None):
|
||||
async def post_json_get_json(
|
||||
self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
|
||||
) -> Any:
|
||||
"""
|
||||
|
||||
Args:
|
||||
uri (str):
|
||||
post_json (object):
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
uri: URI to query.
|
||||
post_json: request body, to be encoded as json
|
||||
headers: a map from header name to a list of values for that header
|
||||
|
||||
Returns:
|
||||
object: parsed json
|
||||
parsed json
|
||||
|
||||
Raises:
|
||||
RequestTimedOutException: if there is a timeout before the response headers
|
||||
are received. Note there is currently no timeout on reading the response
|
||||
body.
|
||||
|
||||
HttpResponseException: On a non-2xx HTTP response.
|
||||
|
||||
ValueError: if the response was not JSON
|
||||
|
@ -440,21 +494,22 @@ class SimpleHttpClient:
|
|||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
async def get_json(self, uri, args={}, headers=None):
|
||||
""" Gets some json from the given URI.
|
||||
async def get_json(
|
||||
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
|
||||
) -> Any:
|
||||
"""Gets some json from the given URI.
|
||||
|
||||
Args:
|
||||
uri (str): The URI to request, not including query parameters
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
uri: The URI to request, not including query parameters
|
||||
args: A dictionary used to create query string
|
||||
headers: a map from header name to a list of values for that header
|
||||
Returns:
|
||||
Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as JSON.
|
||||
Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
|
||||
Raises:
|
||||
RequestTimedOutException: if there is a timeout before the response headers
|
||||
are received. Note there is currently no timeout on reading the response
|
||||
body.
|
||||
|
||||
HttpResponseException On a non-2xx HTTP response.
|
||||
|
||||
ValueError: if the response was not JSON
|
||||
|
@ -466,22 +521,27 @@ class SimpleHttpClient:
|
|||
body = await self.get_raw(uri, args, headers=headers)
|
||||
return json_decoder.decode(body.decode("utf-8"))
|
||||
|
||||
async def put_json(self, uri, json_body, args={}, headers=None):
|
||||
""" Puts some json to the given URI.
|
||||
async def put_json(
|
||||
self,
|
||||
uri: str,
|
||||
json_body: Any,
|
||||
args: QueryParams = {},
|
||||
headers: RawHeaders = None,
|
||||
) -> Any:
|
||||
"""Puts some json to the given URI.
|
||||
|
||||
Args:
|
||||
uri (str): The URI to request, not including query parameters
|
||||
json_body (dict): The JSON to put in the HTTP body,
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
uri: The URI to request, not including query parameters
|
||||
json_body: The JSON to put in the HTTP body,
|
||||
args: A dictionary used to create query strings
|
||||
headers: a map from header name to a list of values for that header
|
||||
Returns:
|
||||
Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as JSON.
|
||||
Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
|
||||
Raises:
|
||||
RequestTimedOutException: if there is a timeout before the response headers
|
||||
are received. Note there is currently no timeout on reading the response
|
||||
body.
|
||||
|
||||
HttpResponseException On a non-2xx HTTP response.
|
||||
|
||||
ValueError: if the response was not JSON
|
||||
|
@ -513,21 +573,23 @@ class SimpleHttpClient:
|
|||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
)
|
||||
|
||||
async def get_raw(self, uri, args={}, headers=None):
|
||||
""" Gets raw text from the given URI.
|
||||
async def get_raw(
|
||||
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
|
||||
) -> bytes:
|
||||
"""Gets raw text from the given URI.
|
||||
|
||||
Args:
|
||||
uri (str): The URI to request, not including query parameters
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
uri: The URI to request, not including query parameters
|
||||
args: A dictionary used to create query strings
|
||||
headers: a map from header name to a list of values for that header
|
||||
Returns:
|
||||
Succeeds when we get *any* 2xx HTTP response, with the
|
||||
Succeeds when we get a 2xx HTTP response, with the
|
||||
HTTP body as bytes.
|
||||
Raises:
|
||||
RequestTimedOutException: if there is a timeout before the response headers
|
||||
are received. Note there is currently no timeout on reading the response
|
||||
body.
|
||||
|
||||
HttpResponseException on a non-2xx HTTP response.
|
||||
"""
|
||||
if len(args):
|
||||
|
@ -552,16 +614,29 @@ class SimpleHttpClient:
|
|||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
||||
# The two should be factored out.
|
||||
|
||||
async def get_file(self, url, output_stream, max_size=None, headers=None):
|
||||
async def get_file(
|
||||
self,
|
||||
url: str,
|
||||
output_stream: BinaryIO,
|
||||
max_size: Optional[int] = None,
|
||||
headers: Optional[RawHeaders] = None,
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
|
||||
"""GETs a file from a given URL
|
||||
Args:
|
||||
url (str): The URL to GET
|
||||
output_stream (file): File to write the response body to.
|
||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
||||
header name to a list of values for that header
|
||||
url: The URL to GET
|
||||
output_stream: File to write the response body to.
|
||||
headers: A map from header name to a list of values for that header
|
||||
Returns:
|
||||
A (int,dict,string,int) tuple of the file length, dict of the response
|
||||
A tuple of the file length, dict of the response
|
||||
headers, absolute URI of the response and HTTP response code.
|
||||
|
||||
Raises:
|
||||
RequestTimedOutException: if there is a timeout before the response headers
|
||||
are received. Note there is currently no timeout on reading the response
|
||||
body.
|
||||
|
||||
SynapseError: if the response is not a 2xx, the remote file is too large, or
|
||||
another exception happens during the download.
|
||||
"""
|
||||
|
||||
actual_headers = {b"User-Agent": [self.user_agent]}
|
||||
|
|
|
@ -42,7 +42,13 @@ from synapse.logging.utils import log_function
|
|||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
PersistedEventPosition,
|
||||
RoomStreamToken,
|
||||
StreamToken,
|
||||
UserID,
|
||||
)
|
||||
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
@ -187,7 +193,7 @@ class Notifier:
|
|||
self.store = hs.get_datastore()
|
||||
self.pending_new_room_events = (
|
||||
[]
|
||||
) # type: List[Tuple[int, EventBase, Collection[UserID]]]
|
||||
) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
|
||||
|
||||
# Called when there are new things to stream over replication
|
||||
self.replication_callbacks = [] # type: List[Callable[[], None]]
|
||||
|
@ -246,8 +252,8 @@ class Notifier:
|
|||
def on_new_room_event(
|
||||
self,
|
||||
event: EventBase,
|
||||
room_stream_id: int,
|
||||
max_room_stream_id: int,
|
||||
event_pos: PersistedEventPosition,
|
||||
max_room_stream_token: RoomStreamToken,
|
||||
extra_users: Collection[UserID] = [],
|
||||
):
|
||||
""" Used by handlers to inform the notifier something has happened
|
||||
|
@ -261,16 +267,16 @@ class Notifier:
|
|||
until all previous events have been persisted before notifying
|
||||
the client streams.
|
||||
"""
|
||||
self.pending_new_room_events.append((room_stream_id, event, extra_users))
|
||||
self._notify_pending_new_room_events(max_room_stream_id)
|
||||
self.pending_new_room_events.append((event_pos, event, extra_users))
|
||||
self._notify_pending_new_room_events(max_room_stream_token)
|
||||
|
||||
self.notify_replication()
|
||||
|
||||
def _notify_pending_new_room_events(self, max_room_stream_id: int):
|
||||
def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
|
||||
"""Notify for the room events that were queued waiting for a previous
|
||||
event to be persisted.
|
||||
Args:
|
||||
max_room_stream_id: The highest stream_id below which all
|
||||
max_room_stream_token: The highest stream_id below which all
|
||||
events have been persisted.
|
||||
"""
|
||||
pending = self.pending_new_room_events
|
||||
|
@ -279,11 +285,9 @@ class Notifier:
|
|||
users = set() # type: Set[UserID]
|
||||
rooms = set() # type: Set[str]
|
||||
|
||||
for room_stream_id, event, extra_users in pending:
|
||||
if room_stream_id > max_room_stream_id:
|
||||
self.pending_new_room_events.append(
|
||||
(room_stream_id, event, extra_users)
|
||||
)
|
||||
for event_pos, event, extra_users in pending:
|
||||
if event_pos.persisted_after(max_room_stream_token):
|
||||
self.pending_new_room_events.append((event_pos, event, extra_users))
|
||||
else:
|
||||
if (
|
||||
event.type == EventTypes.Member
|
||||
|
@ -296,33 +300,32 @@ class Notifier:
|
|||
|
||||
if users or rooms:
|
||||
self.on_new_event(
|
||||
"room_key",
|
||||
RoomStreamToken(None, max_room_stream_id),
|
||||
users=users,
|
||||
rooms=rooms,
|
||||
"room_key", max_room_stream_token, users=users, rooms=rooms,
|
||||
)
|
||||
self._on_updated_room_token(max_room_stream_id)
|
||||
self._on_updated_room_token(max_room_stream_token)
|
||||
|
||||
def _on_updated_room_token(self, max_room_stream_id: int):
|
||||
def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
|
||||
"""Poke services that might care that the room position has been
|
||||
updated.
|
||||
"""
|
||||
|
||||
# poke any interested application service.
|
||||
run_as_background_process(
|
||||
"_notify_app_services", self._notify_app_services, max_room_stream_id
|
||||
"_notify_app_services", self._notify_app_services, max_room_stream_token
|
||||
)
|
||||
|
||||
run_as_background_process(
|
||||
"_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
|
||||
"_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token
|
||||
)
|
||||
|
||||
if self.federation_sender:
|
||||
self.federation_sender.notify_new_events(max_room_stream_id)
|
||||
self.federation_sender.notify_new_events(max_room_stream_token.stream)
|
||||
|
||||
async def _notify_app_services(self, max_room_stream_id: int):
|
||||
async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
|
||||
try:
|
||||
await self.appservice_handler.notify_interested_services(max_room_stream_id)
|
||||
await self.appservice_handler.notify_interested_services(
|
||||
max_room_stream_token.stream
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error notifying application services of event")
|
||||
|
||||
|
@ -332,9 +335,9 @@ class Notifier:
|
|||
except Exception:
|
||||
logger.exception("Error notifying application services of event")
|
||||
|
||||
async def _notify_pusher_pool(self, max_room_stream_id: int):
|
||||
async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
|
||||
try:
|
||||
await self._pusher_pool.on_new_notifications(max_room_stream_id)
|
||||
await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
|
||||
except Exception:
|
||||
logger.exception("Error pusher pool of event")
|
||||
|
||||
|
|
|
@ -60,6 +60,8 @@ class PusherPool:
|
|||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
self._account_validity = hs.config.account_validity
|
||||
|
||||
# We shard the handling of push notifications by user ID.
|
||||
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
@ -202,6 +204,14 @@ class PusherPool:
|
|||
)
|
||||
|
||||
for u in users_affected:
|
||||
# Don't push if the user account has expired
|
||||
if self._account_validity.enabled:
|
||||
expired = await self.store.is_account_expired(
|
||||
u, self.clock.time_msec()
|
||||
)
|
||||
if expired:
|
||||
continue
|
||||
|
||||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
p.on_new_notifications(max_stream_id)
|
||||
|
@ -222,6 +232,14 @@ class PusherPool:
|
|||
)
|
||||
|
||||
for u in users_affected:
|
||||
# Don't push if the user account has expired
|
||||
if self._account_validity.enabled:
|
||||
expired = await self.store.is_account_expired(
|
||||
u, self.clock.time_msec()
|
||||
)
|
||||
if expired:
|
||||
continue
|
||||
|
||||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
p.on_new_receipts(min_stream_id, max_stream_id)
|
||||
|
|
|
@ -37,6 +37,9 @@ logger = logging.getLogger(__name__)
|
|||
# installed when that optional dependency requirement is specified. It is passed
|
||||
# to setup() as extras_require in setup.py
|
||||
#
|
||||
# Note that these both represent runtime dependencies (and the versions
|
||||
# installed are checked at runtime).
|
||||
#
|
||||
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
|
||||
|
||||
REQUIREMENTS = [
|
||||
|
@ -92,20 +95,12 @@ CONDITIONAL_REQUIREMENTS = {
|
|||
"oidc": ["authlib>=0.14.0"],
|
||||
"systemd": ["systemd-python>=231"],
|
||||
"url_preview": ["lxml>=3.5.0"],
|
||||
# Dependencies which are exclusively required by unit test code. This is
|
||||
# NOT a list of all modules that are necessary to run the unit tests.
|
||||
# Tests assume that all optional dependencies are installed.
|
||||
#
|
||||
# parameterized_class decorator was introduced in parameterized 0.7.0
|
||||
"test": ["mock>=2.0", "parameterized>=0.7.0"],
|
||||
"sentry": ["sentry-sdk>=0.7.2"],
|
||||
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
|
||||
"jwt": ["pyjwt>=1.6.4"],
|
||||
# hiredis is not a *strict* dependency, but it makes things much faster.
|
||||
# (if it is not installed, we fall back to slow code.)
|
||||
"redis": ["txredisapi>=1.4.7", "hiredis"],
|
||||
# We pin black so that our tests don't start failing on new releases.
|
||||
"lint": ["isort==5.0.3", "black==19.10b0", "flake8-comprehensions", "flake8"],
|
||||
}
|
||||
|
||||
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
|
||||
|
@ -113,7 +108,7 @@ ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
|
|||
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
|
||||
# Exclude systemd as it's a system-based requirement.
|
||||
# Exclude lint as it's a dev-based requirement.
|
||||
if name not in ["systemd", "lint"]:
|
||||
if name not in ["systemd"]:
|
||||
ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
|
||||
|
||||
|
||||
|
|
|
@ -31,11 +31,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
|
|||
self._cache_id_gen = MultiWriterIdGenerator(
|
||||
db_conn,
|
||||
database,
|
||||
stream_name="caches",
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="cache_invalidation_stream_by_instance",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_id",
|
||||
sequence_name="cache_invalidation_stream_seq",
|
||||
writers=[],
|
||||
) # type: Optional[MultiWriterIdGenerator]
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
|
|
|
@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
|
|||
EventsStreamEventRow,
|
||||
EventsStreamRow,
|
||||
)
|
||||
from synapse.types import UserID
|
||||
from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
|
@ -151,8 +151,14 @@ class ReplicationDataHandler:
|
|||
extra_users = () # type: Tuple[UserID, ...]
|
||||
if event.type == EventTypes.Member:
|
||||
extra_users = (UserID.from_string(event.state_key),)
|
||||
max_token = self.store.get_room_max_stream_ordering()
|
||||
self.notifier.on_new_room_event(event, token, max_token, extra_users)
|
||||
|
||||
max_token = RoomStreamToken(
|
||||
None, self.store.get_room_max_stream_ordering()
|
||||
)
|
||||
event_pos = PersistedEventPosition(instance_name, token)
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_pos, max_token, extra_users
|
||||
)
|
||||
|
||||
# Notify any waiting deferreds. The list is ordered by position so we
|
||||
# just iterate through the list until we reach a position that is
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
<p>
|
||||
There was an error during authentication:
|
||||
</p>
|
||||
<div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
|
||||
<div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
|
||||
<p>
|
||||
If you are seeing this page after clicking a link sent to you via email, make
|
||||
sure you only click the confirmation link once, and that you open the
|
||||
|
|
|
@ -31,6 +31,7 @@ from synapse.rest.admin.devices import (
|
|||
DeviceRestServlet,
|
||||
DevicesRestServlet,
|
||||
)
|
||||
from synapse.rest.admin.event_reports import EventReportsRestServlet
|
||||
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
|
||||
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
|
||||
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
|
||||
|
@ -216,6 +217,7 @@ def register_servlets(hs, http_server):
|
|||
DeviceRestServlet(hs).register(http_server)
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
DeleteDevicesRestServlet(hs).register(http_server)
|
||||
EventReportsRestServlet(hs).register(http_server)
|
||||
|
||||
|
||||
def register_servlets_for_client_rest_resource(hs, http_server):
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 Dirk Klimpel
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventReportsRestServlet(RestServlet):
|
||||
"""
|
||||
List all reported events that are known to the homeserver. Results are returned
|
||||
in a dictionary containing report information. Supports pagination.
|
||||
The requester must have administrator access in Synapse.
|
||||
|
||||
GET /_synapse/admin/v1/event_reports
|
||||
returns:
|
||||
200 OK with list of reports if success otherwise an error.
|
||||
|
||||
Args:
|
||||
The parameters `from` and `limit` are required only for pagination.
|
||||
By default, a `limit` of 100 is used.
|
||||
The parameter `dir` can be used to define the order of results.
|
||||
The parameter `user_id` can be used to filter by user id.
|
||||
The parameter `room_id` can be used to filter by room id.
|
||||
Returns:
|
||||
A list of reported events and an integer representing the total number of
|
||||
reported events that exist given this query
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/event_reports$")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request):
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
start = parse_integer(request, "from", default=0)
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
direction = parse_string(request, "dir", default="b")
|
||||
user_id = parse_string(request, "user_id")
|
||||
room_id = parse_string(request, "room_id")
|
||||
|
||||
if start < 0:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"The start parameter must be a positive integer.",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
if limit < 0:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"The limit parameter must be a positive integer.",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
if direction not in ("f", "b"):
|
||||
raise SynapseError(
|
||||
400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
event_reports, total = await self.store.get_event_reports_paginate(
|
||||
start, limit, direction, user_id, room_id
|
||||
)
|
||||
ret = {"event_reports": event_reports, "total": total}
|
||||
if (start + limit) < total:
|
||||
ret["next_token"] = start + len(event_reports)
|
||||
|
||||
return 200, ret
|
|
@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
|
||||
raise OEmbedError() from e
|
||||
|
||||
async def _download_url(self, url, user):
|
||||
async def _download_url(self, url: str, user):
|
||||
# TODO: we should probably honour robots.txt... except in practice
|
||||
# we're most likely being explicitly triggered by a human rather than a
|
||||
# bot, so are we really a robot?
|
||||
|
@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||
|
||||
# If this URL can be accessed via oEmbed, use that instead.
|
||||
url_to_download = url
|
||||
url_to_download = url # type: Optional[str]
|
||||
oembed_url = self._get_oembed_url(url)
|
||||
if oembed_url:
|
||||
# The result might be a new URL to download, or it might be HTML content.
|
||||
|
@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
# FIXME: we should calculate a proper expiration based on the
|
||||
# Cache-Control and Expire headers. But for now, assume 1 hour.
|
||||
expires = ONE_HOUR
|
||||
etag = headers["ETag"][0] if "ETag" in headers else None
|
||||
etag = (
|
||||
headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
|
||||
)
|
||||
else:
|
||||
html_bytes = oembed_result.html.encode("utf-8") # type: ignore
|
||||
# we can only get here if we did an oembed request and have an oembed_result.html
|
||||
assert oembed_result.html is not None
|
||||
assert oembed_url is not None
|
||||
|
||||
html_bytes = oembed_result.html.encode("utf-8")
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
f.write(html_bytes)
|
||||
await finish()
|
||||
|
|
|
@ -25,7 +25,6 @@ from typing import (
|
|||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
|
@ -42,7 +41,7 @@ from synapse.logging.utils import log_function
|
|||
from synapse.state import v1, v2
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import Collection, MutableStateMap, StateMap
|
||||
from synapse.types import Collection, StateMap
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -472,10 +471,9 @@ class StateResolutionHandler:
|
|||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
# dict of set of event_ids -> _StateCacheEntry.
|
||||
self._state_cache = None
|
||||
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
|
||||
|
||||
# dict of set of event_ids -> _StateCacheEntry.
|
||||
self._state_cache = ExpiringCache(
|
||||
cache_name="state_cache",
|
||||
clock=self.clock,
|
||||
|
@ -519,57 +517,28 @@ class StateResolutionHandler:
|
|||
Returns:
|
||||
The resolved state
|
||||
"""
|
||||
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
|
||||
|
||||
group_names = frozenset(state_groups_ids.keys())
|
||||
|
||||
with (await self.resolve_linearizer.queue(group_names)):
|
||||
if self._state_cache is not None:
|
||||
cache = self._state_cache.get(group_names, None)
|
||||
if cache:
|
||||
return cache
|
||||
cache = self._state_cache.get(group_names, None)
|
||||
if cache:
|
||||
return cache
|
||||
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
||||
"Resolving state for %s with groups %s", room_id, list(group_names),
|
||||
)
|
||||
|
||||
state_groups_histogram.observe(len(state_groups_ids))
|
||||
|
||||
# start by assuming we won't have any conflicted state, and build up the new
|
||||
# state map by iterating through the state groups. If we discover a conflict,
|
||||
# we give up and instead use `resolve_events_with_store`.
|
||||
#
|
||||
# XXX: is this actually worthwhile, or should we just let
|
||||
# resolve_events_with_store do it?
|
||||
new_state = {} # type: MutableStateMap[str]
|
||||
conflicted_state = False
|
||||
for st in state_groups_ids.values():
|
||||
for key, e_id in st.items():
|
||||
if key in new_state:
|
||||
conflicted_state = True
|
||||
break
|
||||
new_state[key] = e_id
|
||||
if conflicted_state:
|
||||
break
|
||||
|
||||
if conflicted_state:
|
||||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
# resolve_events_with_store returns a StateMap, but we can
|
||||
# treat it as a MutableStateMap as it is above. It isn't
|
||||
# actually mutated anymore (and is frozen in
|
||||
# _make_state_cache_entry below).
|
||||
new_state = cast(
|
||||
MutableStateMap,
|
||||
await resolve_events_with_store(
|
||||
self.clock,
|
||||
room_id,
|
||||
room_version,
|
||||
list(state_groups_ids.values()),
|
||||
event_map=event_map,
|
||||
state_res_store=state_res_store,
|
||||
),
|
||||
)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = await resolve_events_with_store(
|
||||
self.clock,
|
||||
room_id,
|
||||
room_version,
|
||||
list(state_groups_ids.values()),
|
||||
event_map=event_map,
|
||||
state_res_store=state_res_store,
|
||||
)
|
||||
|
||||
# if the new state matches any of the input state groups, we can
|
||||
# use that state group again. Otherwise we will generate a state_id
|
||||
|
@ -579,8 +548,7 @@ class StateResolutionHandler:
|
|||
with Measure(self.clock, "state.create_group_ids"):
|
||||
cache = _make_state_cache_entry(new_state, state_groups_ids)
|
||||
|
||||
if self._state_cache is not None:
|
||||
self._state_cache[group_names] = cache
|
||||
self._state_cache[group_names] = cache
|
||||
|
||||
return cache
|
||||
|
||||
|
|
|
@ -160,14 +160,20 @@ class DataStore(
|
|||
)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We set the `writers` to an empty list here as we don't care about
|
||||
# missing updates over restarts, as we'll not have anything in our
|
||||
# caches to invalidate. (This reduces the amount of writes to the DB
|
||||
# that happen).
|
||||
self._cache_id_gen = MultiWriterIdGenerator(
|
||||
db_conn,
|
||||
database,
|
||||
instance_name="master",
|
||||
stream_name="caches",
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="cache_invalidation_stream_by_instance",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_id",
|
||||
sequence_name="cache_invalidation_stream_seq",
|
||||
writers=[],
|
||||
)
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
|
|
|
@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
async with self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as room_account_data has a unique constraint
|
||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||
# retry if there is a conflict.
|
||||
|
@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
async with self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as account_data has a unique constraint on
|
||||
# (user_id, account_data_type) so simple_upsert will retry if
|
||||
# there is a conflict.
|
||||
|
|
|
@ -394,7 +394,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
|||
rows.append((destination, stream_id, now_ms, edu_json))
|
||||
txn.executemany(sql, rows)
|
||||
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||
|
@ -443,7 +443,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
|||
txn, stream_id, local_messages_by_user_then_device
|
||||
)
|
||||
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
async with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_from_remote_to_device_inbox",
|
||||
|
|
|
@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
THe new stream ID.
|
||||
"""
|
||||
|
||||
with await self._device_list_id_gen.get_next() as stream_id:
|
||||
async with self._device_list_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_user_sig_change_to_streams",
|
||||
self._add_user_signature_change_txn,
|
||||
|
@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
if not device_ids:
|
||||
return
|
||||
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
return stream_ids[-1]
|
||||
|
||||
context = get_active_span_text_map()
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
len(hosts) * len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
|
|
|
@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
key (dict): the key data
|
||||
"""
|
||||
|
||||
with await self._cross_signing_id_gen.get_next() as stream_id:
|
||||
async with self._cross_signing_id_gen.get_next() as stream_id:
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_e2e_cross_signing_key",
|
||||
self._set_e2e_cross_signing_key_txn,
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import itertools
|
||||
import logging
|
||||
from collections import OrderedDict, namedtuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
@ -156,15 +156,15 @@ class PersistEventsStore:
|
|||
# Note: Multiple instances of this function cannot be in flight at
|
||||
# the same time for the same room.
|
||||
if backfilled:
|
||||
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
|
||||
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
else:
|
||||
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
|
||||
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
|
||||
with stream_ordering_manager as stream_orderings:
|
||||
async with stream_ordering_manager as stream_orderings:
|
||||
for (event, context), stream in zip(events_and_contexts, stream_orderings):
|
||||
event.internal_metadata.stream_ordering = stream
|
||||
|
||||
|
@ -1108,6 +1108,10 @@ class PersistEventsStore:
|
|||
def _store_room_members_txn(self, txn, events, backfilled):
|
||||
"""Store a room member in the database.
|
||||
"""
|
||||
|
||||
def str_or_none(val: Any) -> Optional[str]:
|
||||
return val if isinstance(val, str) else None
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="room_memberships",
|
||||
|
@ -1118,8 +1122,8 @@ class PersistEventsStore:
|
|||
"sender": event.user_id,
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
"display_name": event.content.get("displayname", None),
|
||||
"avatar_url": event.content.get("avatar_url", None),
|
||||
"display_name": str_or_none(event.content.get("displayname")),
|
||||
"avatar_url": str_or_none(event.content.get("avatar_url")),
|
||||
}
|
||||
for event in events
|
||||
],
|
||||
|
|
|
@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
self._stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
stream_name="events",
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_ordering",
|
||||
sequence_name="events_stream_seq",
|
||||
writers=hs.config.worker.writers.events,
|
||||
)
|
||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
stream_name="backfill",
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_ordering",
|
||||
sequence_name="events_backfill_stream_seq",
|
||||
positive=False,
|
||||
writers=hs.config.worker.writers.events,
|
||||
)
|
||||
else:
|
||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
||||
|
|
|
@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
|
||||
return next_id
|
||||
|
||||
with await self._group_updates_id_gen.get_next() as next_id:
|
||||
async with self._group_updates_id_gen.get_next() as next_id:
|
||||
res = await self.db_pool.runInteraction(
|
||||
"register_user_group_membership",
|
||||
_register_user_group_membership_txn,
|
||||
|
|
|
@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
|
|||
|
||||
class PresenceStore(SQLBaseStore):
|
||||
async def update_presence(self, presence_states):
|
||||
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
|
||||
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||
len(presence_states)
|
||||
)
|
||||
|
||||
with stream_ordering_manager as stream_orderings:
|
||||
async with stream_ordering_manager as stream_orderings:
|
||||
await self.db_pool.runInteraction(
|
||||
"update_presence",
|
||||
self._update_presence_txn,
|
||||
|
|
|
@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
if before or after:
|
||||
|
@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||
)
|
||||
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
Raises:
|
||||
NotFoundError if the rule does not exist.
|
||||
"""
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
await self.db_pool.runInteraction(
|
||||
"_set_push_rule_enabled_txn",
|
||||
|
@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
data={"actions": actions_json},
|
||||
)
|
||||
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
async with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
|
|
@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
|
|||
last_stream_ordering,
|
||||
profile_tag="",
|
||||
) -> None:
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
async with self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so simple_upsert will retry
|
||||
await self.db_pool.simple_upsert(
|
||||
|
@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
|
|||
},
|
||||
)
|
||||
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
async with self._pushers_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_pusher", delete_pusher_txn, stream_id
|
||||
)
|
||||
|
|
|
@ -577,7 +577,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
|||
"insert_receipt_conv", graph_to_linear
|
||||
)
|
||||
|
||||
with await self._receipts_id_gen.get_next() as stream_id:
|
||||
async with self._receipts_id_gen.get_next() as stream_id:
|
||||
event_ts = await self.db_pool.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
|
|
|
@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
desc="get_expiration_ts_for_user",
|
||||
)
|
||||
|
||||
async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
|
||||
"""
|
||||
Returns whether an user account is expired.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
current_ts: The current timestamp
|
||||
|
||||
Returns:
|
||||
Whether the user account has expired
|
||||
"""
|
||||
expiration_ts = await self.get_expiration_ts_for_user(user_id)
|
||||
return expiration_ts is not None and current_ts >= expiration_ts
|
||||
|
||||
async def set_account_validity_for_user(
|
||||
self,
|
||||
user_id: str,
|
||||
|
@ -379,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
|
||||
async def get_user_by_external_id(
|
||||
self, auth_provider: str, external_id: str
|
||||
) -> str:
|
||||
) -> Optional[str]:
|
||||
"""Look up a user by their external auth id
|
||||
|
||||
Args:
|
||||
|
@ -387,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
external_id: id on that system
|
||||
|
||||
Returns:
|
||||
str|None: the mxid of the user, or None if they are not known
|
||||
the mxid of the user, or None if they are not known
|
||||
"""
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="user_external_ids",
|
||||
|
|
|
@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
async with self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"store_room_txn", store_room_txn, next_id
|
||||
)
|
||||
|
@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
async with self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public", set_room_is_public_txn, next_id
|
||||
)
|
||||
|
@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
async with self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn,
|
||||
|
@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
desc="add_event_report",
|
||||
)
|
||||
|
||||
async def get_event_reports_paginate(
|
||||
self,
|
||||
start: int,
|
||||
limit: int,
|
||||
direction: str = "b",
|
||||
user_id: Optional[str] = None,
|
||||
room_id: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""Retrieve a paginated list of event reports
|
||||
|
||||
Args:
|
||||
start: event offset to begin the query from
|
||||
limit: number of rows to retrieve
|
||||
direction: Whether to fetch the most recent first (`"b"`) or the
|
||||
oldest first (`"f"`)
|
||||
user_id: search for user_id. Ignored if user_id is None
|
||||
room_id: search for room_id. Ignored if room_id is None
|
||||
Returns:
|
||||
event_reports: json list of event reports
|
||||
count: total number of event reports matching the filter criteria
|
||||
"""
|
||||
|
||||
def _get_event_reports_paginate_txn(txn):
|
||||
filters = []
|
||||
args = []
|
||||
|
||||
if user_id:
|
||||
filters.append("er.user_id LIKE ?")
|
||||
args.extend(["%" + user_id + "%"])
|
||||
if room_id:
|
||||
filters.append("er.room_id LIKE ?")
|
||||
args.extend(["%" + room_id + "%"])
|
||||
|
||||
if direction == "b":
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
|
||||
|
||||
sql = """
|
||||
SELECT COUNT(*) as total_event_reports
|
||||
FROM event_reports AS er
|
||||
{}
|
||||
""".format(
|
||||
where_clause
|
||||
)
|
||||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
er.id,
|
||||
er.received_ts,
|
||||
er.room_id,
|
||||
er.event_id,
|
||||
er.user_id,
|
||||
er.reason,
|
||||
er.content,
|
||||
events.sender,
|
||||
room_aliases.room_alias,
|
||||
event_json.json AS event_json
|
||||
FROM event_reports AS er
|
||||
LEFT JOIN room_aliases
|
||||
ON room_aliases.room_id = er.room_id
|
||||
JOIN events
|
||||
ON events.event_id = er.event_id
|
||||
JOIN event_json
|
||||
ON event_json.event_id = er.event_id
|
||||
{where_clause}
|
||||
ORDER BY er.received_ts {order}
|
||||
LIMIT ?
|
||||
OFFSET ?
|
||||
""".format(
|
||||
where_clause=where_clause, order=order,
|
||||
)
|
||||
|
||||
args += [limit, start]
|
||||
txn.execute(sql, args)
|
||||
event_reports = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
if count > 0:
|
||||
for row in event_reports:
|
||||
try:
|
||||
row["content"] = db_to_json(row["content"])
|
||||
row["event_json"] = db_to_json(row["event_json"])
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return event_reports, count
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_event_reports_paginate", _get_event_reports_paginate_txn
|
||||
)
|
||||
|
||||
def get_current_public_room_stream_id(self):
|
||||
return self._public_room_id_gen.get_current_token()
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
|
||||
|
||||
|
@ -37,7 +36,7 @@ from synapse.storage.roommember import (
|
|||
ProfileInfo,
|
||||
RoomsForUser,
|
||||
)
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
||||
|
@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
# for rooms the server is participating in.
|
||||
if self._current_state_events_membership_up_to_date:
|
||||
sql = """
|
||||
SELECT room_id, e.stream_ordering
|
||||
SELECT room_id, e.instance_name, e.stream_ordering
|
||||
FROM current_state_events AS c
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
|
@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
"""
|
||||
else:
|
||||
sql = """
|
||||
SELECT room_id, e.stream_ordering
|
||||
SELECT room_id, e.instance_name, e.stream_ordering
|
||||
FROM current_state_events AS c
|
||||
INNER JOIN room_memberships AS m USING (room_id, event_id)
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
|
@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
"""
|
||||
|
||||
txn.execute(sql, (user_id, Membership.JOIN))
|
||||
return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
|
||||
return frozenset(
|
||||
GetRoomsForUserWithStreamOrdering(
|
||||
room_id, PersistedEventPosition(instance, stream_id)
|
||||
)
|
||||
for room_id, instance, stream_id in txn
|
||||
)
|
||||
|
||||
async def get_users_server_still_shares_room_with(
|
||||
self, user_ids: Collection[str]
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- room_id and topoligical_ordering are denormalised from the events table in order to
|
||||
-- room_id and topological_ordering are denormalised from the events table in order to
|
||||
-- make the index work.
|
||||
CREATE TABLE IF NOT EXISTS event_labels (
|
||||
event_id TEXT,
|
||||
|
|
|
@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', (
|
|||
|
||||
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
|
||||
|
||||
-- If the server has never backfilled a room then doing `-MIN(...)` will give
|
||||
-- a negative result, hence why we do `GREATEST(...)`
|
||||
SELECT setval('events_backfill_stream_seq', (
|
||||
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
|
||||
SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events
|
||||
));
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE stream_positions (
|
||||
stream_name TEXT NOT NULL,
|
||||
instance_name TEXT NOT NULL,
|
||||
stream_id BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
|
|
@ -210,6 +210,7 @@ class StatsStore(StateDeltasStore):
|
|||
* topic
|
||||
* avatar
|
||||
* canonical_alias
|
||||
* guest_access
|
||||
|
||||
A is_federatable key can also be included with a boolean value.
|
||||
|
||||
|
@ -234,6 +235,7 @@ class StatsStore(StateDeltasStore):
|
|||
"topic",
|
||||
"avatar",
|
||||
"canonical_alias",
|
||||
"guest_access",
|
||||
):
|
||||
field = fields.get(col, sentinel)
|
||||
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
|
||||
|
|
|
@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
|
|||
)
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
async with self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
|
@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
|
|||
txn.execute(sql, (user_id, room_id, tag))
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
async with self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
|
|
|
@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
|
|||
return
|
||||
|
||||
# They are there, delete them.
|
||||
self.simple_delete_one_txn(
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, "erased_users", keyvalues={"user_id": user_id}
|
||||
)
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
|
|||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.databases import Databases
|
||||
from synapse.storage.databases.main.events import DeltaState
|
||||
from synapse.types import Collection, StateMap
|
||||
from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
|
@ -190,6 +190,7 @@ class EventsPersistenceStorage:
|
|||
self.persist_events_store = stores.persist_events
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self._event_persist_queue = _EventPeristenceQueue()
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
@ -198,7 +199,7 @@ class EventsPersistenceStorage:
|
|||
self,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
) -> int:
|
||||
) -> RoomStreamToken:
|
||||
"""
|
||||
Write events to the database
|
||||
Args:
|
||||
|
@ -228,11 +229,11 @@ class EventsPersistenceStorage:
|
|||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
return self.main_store.get_current_events_token()
|
||||
return RoomStreamToken(None, self.main_store.get_current_events_token())
|
||||
|
||||
async def persist_event(
|
||||
self, event: EventBase, context: EventContext, backfilled: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
) -> Tuple[PersistedEventPosition, RoomStreamToken]:
|
||||
"""
|
||||
Returns:
|
||||
The stream ordering of `event`, and the stream ordering of the
|
||||
|
@ -247,7 +248,10 @@ class EventsPersistenceStorage:
|
|||
await make_deferred_yieldable(deferred)
|
||||
|
||||
max_persisted_id = self.main_store.get_current_events_token()
|
||||
return (event.internal_metadata.stream_ordering, max_persisted_id)
|
||||
event_stream_id = event.internal_metadata.stream_ordering
|
||||
|
||||
pos = PersistedEventPosition(self._instance_name, event_stream_id)
|
||||
return pos, RoomStreamToken(None, max_persisted_id)
|
||||
|
||||
def _maybe_start_persisting(self, room_id: str):
|
||||
async def persisting_queue(item):
|
||||
|
|
|
@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
|
|||
)
|
||||
|
||||
GetRoomsForUserWithStreamOrdering = namedtuple(
|
||||
"_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
|
||||
"_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -12,16 +12,17 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, List, Set
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
import attr
|
||||
from typing_extensions import Deque
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.util.sequence import PostgresSequenceGenerator
|
||||
|
||||
|
@ -86,7 +87,7 @@ class StreamIdGenerator:
|
|||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
|
@ -101,10 +102,10 @@ class StreamIdGenerator:
|
|||
)
|
||||
self._unfinished_ids = deque() # type: Deque[int]
|
||||
|
||||
async def get_next(self):
|
||||
def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
with self._lock:
|
||||
|
@ -113,7 +114,7 @@ class StreamIdGenerator:
|
|||
|
||||
self._unfinished_ids.append(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_id
|
||||
|
@ -121,12 +122,12 @@ class StreamIdGenerator:
|
|||
with self._lock:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
return manager()
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
async def get_next_mult(self, n):
|
||||
def get_next_mult(self, n):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next(n) as stream_ids:
|
||||
async with stream_id_gen.get_next(n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
with self._lock:
|
||||
|
@ -140,7 +141,7 @@ class StreamIdGenerator:
|
|||
for next_id in next_ids:
|
||||
self._unfinished_ids.append(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
|
@ -149,7 +150,7 @@ class StreamIdGenerator:
|
|||
for next_id in next_ids:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
return manager()
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_current_token(self):
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
|
@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
|
|||
Args:
|
||||
db_conn
|
||||
db
|
||||
stream_name: A name for the stream.
|
||||
instance_name: The name of this instance.
|
||||
table: Database table associated with stream.
|
||||
instance_column: Column that stores the row's writer's instance name
|
||||
id_column: Column that stores the stream ID.
|
||||
sequence_name: The name of the postgres sequence used to generate new
|
||||
IDs.
|
||||
writers: A list of known writers to use to populate current positions
|
||||
on startup. Can be empty if nothing uses `get_current_token` or
|
||||
`get_positions` (e.g. caches stream).
|
||||
positive: Whether the IDs are positive (true) or negative (false).
|
||||
When using negative IDs we go backwards from -1 to -2, -3, etc.
|
||||
"""
|
||||
|
@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
|
|||
self,
|
||||
db_conn,
|
||||
db: DatabasePool,
|
||||
stream_name: str,
|
||||
instance_name: str,
|
||||
table: str,
|
||||
instance_column: str,
|
||||
id_column: str,
|
||||
sequence_name: str,
|
||||
writers: List[str],
|
||||
positive: bool = True,
|
||||
):
|
||||
self._db = db
|
||||
self._stream_name = stream_name
|
||||
self._instance_name = instance_name
|
||||
self._positive = positive
|
||||
self._writers = writers
|
||||
self._return_factor = 1 if positive else -1
|
||||
|
||||
# We lock as some functions may be called from DB threads.
|
||||
|
@ -216,9 +225,7 @@ class MultiWriterIdGenerator:
|
|||
# Note: If we are a negative stream then we still store all the IDs as
|
||||
# positive to make life easier for us, and simply negate the IDs when we
|
||||
# return them.
|
||||
self._current_positions = self._load_current_ids(
|
||||
db_conn, table, instance_column, id_column
|
||||
)
|
||||
self._current_positions = {} # type: Dict[str, int]
|
||||
|
||||
# Set of local IDs that we're still processing. The current position
|
||||
# should be less than the minimum of this set (if not empty).
|
||||
|
@ -251,90 +258,108 @@ class MultiWriterIdGenerator:
|
|||
|
||||
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
||||
|
||||
# This goes and fills out the above state from the database.
|
||||
self._load_current_ids(db_conn, table, instance_column, id_column)
|
||||
|
||||
def _load_current_ids(
|
||||
self, db_conn, table: str, instance_column: str, id_column: str
|
||||
) -> Dict[str, int]:
|
||||
# If positive stream aggregate via MAX. For negative stream use MIN
|
||||
# *and* negate the result to get a positive number.
|
||||
sql = """
|
||||
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
|
||||
GROUP BY %(instance)s
|
||||
""" % {
|
||||
"instance": instance_column,
|
||||
"id": id_column,
|
||||
"table": table,
|
||||
"agg": "MAX" if self._positive else "-MIN",
|
||||
}
|
||||
|
||||
):
|
||||
cur = db_conn.cursor()
|
||||
cur.execute(sql)
|
||||
|
||||
# `cur` is an iterable over returned rows, which are 2-tuples.
|
||||
current_positions = dict(cur)
|
||||
# Load the current positions of all writers for the stream.
|
||||
if self._writers:
|
||||
sql = """
|
||||
SELECT instance_name, stream_id FROM stream_positions
|
||||
WHERE stream_name = ?
|
||||
"""
|
||||
sql = self._db.engine.convert_param_style(sql)
|
||||
|
||||
cur.execute(sql, (self._stream_name,))
|
||||
|
||||
self._current_positions = {
|
||||
instance: stream_id * self._return_factor
|
||||
for instance, stream_id in cur
|
||||
if instance in self._writers
|
||||
}
|
||||
|
||||
# We set the `_persisted_upto_position` to be the minimum of all current
|
||||
# positions. If empty we use the max stream ID from the DB table.
|
||||
min_stream_id = min(self._current_positions.values(), default=None)
|
||||
|
||||
if min_stream_id is None:
|
||||
# We add a GREATEST here to ensure that the result is always
|
||||
# positive. (This can be a problem for e.g. backfill streams where
|
||||
# the server has never backfilled).
|
||||
sql = """
|
||||
SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
|
||||
FROM %(table)s
|
||||
""" % {
|
||||
"id": id_column,
|
||||
"table": table,
|
||||
"agg": "MAX" if self._positive else "-MIN",
|
||||
}
|
||||
cur.execute(sql)
|
||||
(stream_id,) = cur.fetchone()
|
||||
self._persisted_upto_position = stream_id
|
||||
else:
|
||||
# If we have a min_stream_id then we pull out everything greater
|
||||
# than it from the DB so that we can prefill
|
||||
# `_known_persisted_positions` and get a more accurate
|
||||
# `_persisted_upto_position`.
|
||||
#
|
||||
# We also check if any of the later rows are from this instance, in
|
||||
# which case we use that for this instance's current position. This
|
||||
# is to handle the case where we didn't finish persisting to the
|
||||
# stream positions table before restart (or the stream position
|
||||
# table otherwise got out of date).
|
||||
|
||||
sql = """
|
||||
SELECT %(instance)s, %(id)s FROM %(table)s
|
||||
WHERE ? %(cmp)s %(id)s
|
||||
""" % {
|
||||
"id": id_column,
|
||||
"table": table,
|
||||
"instance": instance_column,
|
||||
"cmp": "<=" if self._positive else ">=",
|
||||
}
|
||||
sql = self._db.engine.convert_param_style(sql)
|
||||
cur.execute(sql, (min_stream_id,))
|
||||
|
||||
self._persisted_upto_position = min_stream_id
|
||||
|
||||
with self._lock:
|
||||
for (instance, stream_id,) in cur:
|
||||
stream_id = self._return_factor * stream_id
|
||||
self._add_persisted_position(stream_id)
|
||||
|
||||
if instance == self._instance_name:
|
||||
self._current_positions[instance] = stream_id
|
||||
|
||||
cur.close()
|
||||
|
||||
return current_positions
|
||||
|
||||
def _load_next_id_txn(self, txn) -> int:
|
||||
return self._sequence_gen.get_next_id_txn(txn)
|
||||
|
||||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||
|
||||
async def get_next(self):
|
||||
def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
|
||||
|
||||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
with self._lock:
|
||||
assert self._current_positions.get(self._instance_name, 0) < next_id
|
||||
return _MultiWriterCtxManager(self)
|
||||
|
||||
self._unfinished_ids.add(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
# Multiply by the return factor so that the ID has correct sign.
|
||||
yield self._return_factor * next_id
|
||||
finally:
|
||||
self._mark_id_as_finished(next_id)
|
||||
|
||||
return manager()
|
||||
|
||||
async def get_next_mult(self, n: int):
|
||||
def get_next_mult(self, n: int):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
async with stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
next_ids = await self._db.runInteraction(
|
||||
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
with self._lock:
|
||||
assert max(self._current_positions.values(), default=0) < min(next_ids)
|
||||
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield [self._return_factor * i for i in next_ids]
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
|
||||
return manager()
|
||||
return _MultiWriterCtxManager(self, n)
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction):
|
||||
"""
|
||||
|
@ -352,6 +377,21 @@ class MultiWriterIdGenerator:
|
|||
txn.call_after(self._mark_id_as_finished, next_id)
|
||||
txn.call_on_exception(self._mark_id_as_finished, next_id)
|
||||
|
||||
# Update the `stream_positions` table with newly updated stream
|
||||
# ID (unless self._writers is not set in which case we don't
|
||||
# bother, as nothing will read it).
|
||||
#
|
||||
# We only do this on the success path so that the persisted current
|
||||
# position points to a persited row with the correct instance name.
|
||||
if self._writers:
|
||||
txn.call_after(
|
||||
run_as_background_process,
|
||||
"MultiWriterIdGenerator._update_table",
|
||||
self._db.runInteraction,
|
||||
"MultiWriterIdGenerator._update_table",
|
||||
self._update_stream_positions_table_txn,
|
||||
)
|
||||
|
||||
return self._return_factor * next_id
|
||||
|
||||
def _mark_id_as_finished(self, next_id: int):
|
||||
|
@ -482,3 +522,95 @@ class MultiWriterIdGenerator:
|
|||
# There was a gap in seen positions, so there is nothing more to
|
||||
# do.
|
||||
break
|
||||
|
||||
def _update_stream_positions_table_txn(self, txn):
|
||||
"""Update the `stream_positions` table with newly persisted position.
|
||||
"""
|
||||
|
||||
if not self._writers:
|
||||
return
|
||||
|
||||
# We upsert the value, ensuring on conflict that we always increase the
|
||||
# value (or decrease if stream goes backwards).
|
||||
sql = """
|
||||
INSERT INTO stream_positions (stream_name, instance_name, stream_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT (stream_name, instance_name)
|
||||
DO UPDATE SET
|
||||
stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
|
||||
""" % {
|
||||
"agg": "GREATEST" if self._positive else "LEAST",
|
||||
}
|
||||
|
||||
pos = (self.get_current_token_for_writer(self._instance_name),)
|
||||
txn.execute(sql, (self._stream_name, self._instance_name, pos))
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _AsyncCtxManagerWrapper:
|
||||
"""Helper class to convert a plain context manager to an async one.
|
||||
|
||||
This is mainly useful if you have a plain context manager but the interface
|
||||
requires an async one.
|
||||
"""
|
||||
|
||||
inner = attr.ib()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.inner.__enter__()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return self.inner.__exit__(exc_type, exc, tb)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _MultiWriterCtxManager:
|
||||
"""Async context manager returned by MultiWriterIdGenerator
|
||||
"""
|
||||
|
||||
id_gen = attr.ib(type=MultiWriterIdGenerator)
|
||||
multiple_ids = attr.ib(type=Optional[int], default=None)
|
||||
stream_ids = attr.ib(type=List[int], factory=list)
|
||||
|
||||
async def __aenter__(self) -> Union[int, List[int]]:
|
||||
self.stream_ids = await self.id_gen._db.runInteraction(
|
||||
"_load_next_mult_id",
|
||||
self.id_gen._load_next_mult_id_txn,
|
||||
self.multiple_ids or 1,
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
with self.id_gen._lock:
|
||||
assert max(self.id_gen._current_positions.values(), default=0) < min(
|
||||
self.stream_ids
|
||||
)
|
||||
|
||||
self.id_gen._unfinished_ids.update(self.stream_ids)
|
||||
|
||||
if self.multiple_ids is None:
|
||||
return self.stream_ids[0] * self.id_gen._return_factor
|
||||
else:
|
||||
return [i * self.id_gen._return_factor for i in self.stream_ids]
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
for i in self.stream_ids:
|
||||
self.id_gen._mark_id_as_finished(i)
|
||||
|
||||
if exc_type is not None:
|
||||
return False
|
||||
|
||||
# Update the `stream_positions` table with newly updated stream
|
||||
# ID (unless self._writers is not set in which case we don't
|
||||
# bother, as nothing will read it).
|
||||
#
|
||||
# We only do this on the success path so that the persisted current
|
||||
# position points to a persisted row with the correct instance name.
|
||||
if self.id_gen._writers:
|
||||
await self.id_gen._db.runInteraction(
|
||||
"MultiWriterIdGenerator._update_table",
|
||||
self.id_gen._update_stream_positions_table_txn,
|
||||
)
|
||||
|
||||
return False
|
||||
|
|
|
@ -495,6 +495,21 @@ class StreamToken:
|
|||
StreamToken.START = StreamToken.from_string("s0_0")
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class PersistedEventPosition:
|
||||
"""Position of a newly persisted event with instance that persisted it.
|
||||
|
||||
This can be used to test whether the event is persisted before or after a
|
||||
RoomStreamToken.
|
||||
"""
|
||||
|
||||
instance_name = attr.ib(type=str)
|
||||
stream = attr.ib(type=int)
|
||||
|
||||
def persisted_after(self, token: RoomStreamToken) -> bool:
|
||||
return token.stream < self.stream
|
||||
|
||||
|
||||
class ThirdPartyInstanceID(
|
||||
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
|
||||
):
|
||||
|
|
|
@ -23,6 +23,7 @@ from nacl.signing import SigningKey
|
|||
from signedjson.key import encode_verify_key_base64, get_verify_key
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred, ensureDeferred
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.crypto import keyring
|
||||
|
@ -33,7 +34,6 @@ from synapse.crypto.keyring import (
|
|||
)
|
||||
from synapse.logging.context import (
|
||||
LoggingContext,
|
||||
PreserveLoggingContext,
|
||||
current_context,
|
||||
make_deferred_yieldable,
|
||||
)
|
||||
|
@ -68,54 +68,40 @@ class MockPerspectiveServer:
|
|||
|
||||
|
||||
class KeyringTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.mock_perspective_server = MockPerspectiveServer()
|
||||
self.http_client = Mock()
|
||||
|
||||
config = self.default_config()
|
||||
config["trusted_key_servers"] = [
|
||||
{
|
||||
"server_name": self.mock_perspective_server.server_name,
|
||||
"verify_keys": self.mock_perspective_server.get_verify_keys(),
|
||||
}
|
||||
]
|
||||
|
||||
return self.setup_test_homeserver(
|
||||
handlers=None, http_client=self.http_client, config=config
|
||||
)
|
||||
|
||||
def check_context(self, _, expected):
|
||||
def check_context(self, val, expected):
|
||||
self.assertEquals(getattr(current_context(), "request", None), expected)
|
||||
return val
|
||||
|
||||
def test_verify_json_objects_for_server_awaits_previous_requests(self):
|
||||
key1 = signedjson.key.generate_signing_key(1)
|
||||
mock_fetcher = keyring.KeyFetcher()
|
||||
mock_fetcher.get_keys = Mock()
|
||||
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
|
||||
|
||||
kr = keyring.Keyring(self.hs)
|
||||
# a signed object that we are going to try to validate
|
||||
key1 = signedjson.key.generate_signing_key(1)
|
||||
json1 = {}
|
||||
signedjson.sign.sign_json(json1, "server10", key1)
|
||||
|
||||
persp_resp = {
|
||||
"server_keys": [
|
||||
self.mock_perspective_server.get_signed_key(
|
||||
"server10", signedjson.key.get_verify_key(key1)
|
||||
)
|
||||
]
|
||||
}
|
||||
persp_deferred = defer.Deferred()
|
||||
# start off a first set of lookups. We make the mock fetcher block until this
|
||||
# deferred completes.
|
||||
first_lookup_deferred = Deferred()
|
||||
|
||||
async def get_perspectives(**kwargs):
|
||||
self.assertEquals(current_context().request, "11")
|
||||
with PreserveLoggingContext():
|
||||
await persp_deferred
|
||||
return persp_resp
|
||||
async def first_lookup_fetch(keys_to_fetch):
|
||||
self.assertEquals(current_context().request, "context_11")
|
||||
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
|
||||
|
||||
self.http_client.post_json.side_effect = get_perspectives
|
||||
await make_deferred_yieldable(first_lookup_deferred)
|
||||
return {
|
||||
"server10": {
|
||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
|
||||
}
|
||||
}
|
||||
|
||||
# start off a first set of lookups
|
||||
@defer.inlineCallbacks
|
||||
def first_lookup():
|
||||
with LoggingContext("11") as context_11:
|
||||
context_11.request = "11"
|
||||
mock_fetcher.get_keys.side_effect = first_lookup_fetch
|
||||
|
||||
async def first_lookup():
|
||||
with LoggingContext("context_11") as context_11:
|
||||
context_11.request = "context_11"
|
||||
|
||||
res_deferreds = kr.verify_json_objects_for_server(
|
||||
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
|
||||
|
@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
# the unsigned json should be rejected pretty quickly
|
||||
self.assertTrue(res_deferreds[1].called)
|
||||
try:
|
||||
yield res_deferreds[1]
|
||||
await res_deferreds[1]
|
||||
self.assertFalse("unsigned json didn't cause a failure")
|
||||
except SynapseError:
|
||||
pass
|
||||
|
@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
self.assertFalse(res_deferreds[0].called)
|
||||
res_deferreds[0].addBoth(self.check_context, None)
|
||||
|
||||
yield make_deferred_yieldable(res_deferreds[0])
|
||||
await make_deferred_yieldable(res_deferreds[0])
|
||||
|
||||
# let verify_json_objects_for_server finish its work before we kill the
|
||||
# logcontext
|
||||
yield self.clock.sleep(0)
|
||||
d0 = ensureDeferred(first_lookup())
|
||||
|
||||
d0 = first_lookup()
|
||||
|
||||
# wait a tick for it to send the request to the perspectives server
|
||||
# (it first tries the datastore)
|
||||
self.pump()
|
||||
self.http_client.post_json.assert_called_once()
|
||||
mock_fetcher.get_keys.assert_called_once()
|
||||
|
||||
# a second request for a server with outstanding requests
|
||||
# should block rather than start a second call
|
||||
@defer.inlineCallbacks
|
||||
def second_lookup():
|
||||
with LoggingContext("12") as context_12:
|
||||
context_12.request = "12"
|
||||
self.http_client.post_json.reset_mock()
|
||||
self.http_client.post_json.return_value = defer.Deferred()
|
||||
|
||||
async def second_lookup_fetch(keys_to_fetch):
|
||||
self.assertEquals(current_context().request, "context_12")
|
||||
return {
|
||||
"server10": {
|
||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
|
||||
}
|
||||
}
|
||||
|
||||
mock_fetcher.get_keys.reset_mock()
|
||||
mock_fetcher.get_keys.side_effect = second_lookup_fetch
|
||||
second_lookup_state = [0]
|
||||
|
||||
async def second_lookup():
|
||||
with LoggingContext("context_12") as context_12:
|
||||
context_12.request = "context_12"
|
||||
|
||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||
[("server10", json1, 0, "test")]
|
||||
)
|
||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||
yield make_deferred_yieldable(res_deferreds_2[0])
|
||||
second_lookup_state[0] = 1
|
||||
await make_deferred_yieldable(res_deferreds_2[0])
|
||||
second_lookup_state[0] = 2
|
||||
|
||||
# let verify_json_objects_for_server finish its work before we kill the
|
||||
# logcontext
|
||||
yield self.clock.sleep(0)
|
||||
|
||||
d2 = second_lookup()
|
||||
d2 = ensureDeferred(second_lookup())
|
||||
|
||||
self.pump()
|
||||
self.http_client.post_json.assert_not_called()
|
||||
# the second request should be pending, but the fetcher should not yet have been
|
||||
# called
|
||||
self.assertEqual(second_lookup_state[0], 1)
|
||||
mock_fetcher.get_keys.assert_not_called()
|
||||
|
||||
# complete the first request
|
||||
persp_deferred.callback(persp_resp)
|
||||
first_lookup_deferred.callback(None)
|
||||
|
||||
# and now both verifications should succeed.
|
||||
self.get_success(d0)
|
||||
self.get_success(d2)
|
||||
|
||||
|
|
|
@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||
# These tests assume that it starts 1000 seconds in.
|
||||
self.reactor.advance(1000)
|
||||
|
||||
def test_device_is_created_with_invalid_name(self):
|
||||
self.get_failure(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
device_id="foo",
|
||||
initial_device_display_name="a"
|
||||
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
|
||||
),
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_device_is_created_if_doesnt_exist(self):
|
||||
res = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
|
|
|
@ -617,3 +617,38 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user_2:test")
|
||||
|
||||
# Test if the mxid is already taken
|
||||
store = self.hs.get_datastore()
|
||||
user3 = UserID.from_string("@test_user_3:test")
|
||||
self.get_success(
|
||||
store.register_user(user_id=user3.to_string(), password_hash=None)
|
||||
)
|
||||
userinfo = {"sub": "test3", "username": "test_user_3"}
|
||||
e = self.get_failure(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
)
|
||||
self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
|
||||
|
||||
@override_config({"oidc_config": {"allow_existing_users": True}})
|
||||
def test_map_userinfo_to_existing_user(self):
|
||||
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
||||
store = self.hs.get_datastore()
|
||||
user4 = UserID.from_string("@test_user_4:test")
|
||||
self.get_success(
|
||||
store.register_user(user_id=user4.to_string(), password_hash=None)
|
||||
)
|
||||
userinfo = {
|
||||
"sub": "test4",
|
||||
"username": "test_user_4",
|
||||
}
|
||||
token = {}
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user_4:test")
|
||||
|
|
|
@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_
|
|||
from synapse.handlers.room import RoomEventSource
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import PersistedEventPosition
|
||||
|
||||
from tests.server import FakeTransport
|
||||
|
||||
|
@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
|||
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
|
||||
)
|
||||
self.replicate()
|
||||
|
||||
expected_pos = PersistedEventPosition(
|
||||
"master", j2.internal_metadata.stream_ordering
|
||||
)
|
||||
self.check(
|
||||
"get_rooms_for_user_with_stream_ordering",
|
||||
(USER_ID_2,),
|
||||
{(ROOM_ID, j2.internal_metadata.stream_ordering)},
|
||||
{(ROOM_ID, expected_pos)},
|
||||
)
|
||||
|
||||
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
|
||||
|
@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
|||
# the membership change is only any use to us if the room is in the
|
||||
# joined_rooms list.
|
||||
if membership_changes:
|
||||
self.assertEqual(
|
||||
joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
|
||||
expected_pos = PersistedEventPosition(
|
||||
"master", j2.internal_metadata.stream_ordering
|
||||
)
|
||||
self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
|
||||
|
||||
event_id = 0
|
||||
|
||||
|
|
|
@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
|
|||
self.render(request)
|
||||
|
||||
self.assertEqual(400, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
|
||||
|
||||
# Ensure the display name was not updated.
|
||||
request, channel = self.make_request(
|
||||
|
|
|
@ -0,0 +1,382 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 Dirk Klimpel
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.rest.client.v2_alpha import report_event
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class EventReportsTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
report_event.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
|
||||
self.other_user = self.register_user("user", "pass")
|
||||
self.other_user_tok = self.login("user", "pass")
|
||||
|
||||
self.room_id1 = self.helper.create_room_as(
|
||||
self.other_user, tok=self.other_user_tok, is_public=True
|
||||
)
|
||||
self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
|
||||
|
||||
self.room_id2 = self.helper.create_room_as(
|
||||
self.other_user, tok=self.other_user_tok, is_public=True
|
||||
)
|
||||
self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok)
|
||||
|
||||
# Two rooms and two users. Every user sends and reports every room event
|
||||
for i in range(5):
|
||||
self._create_event_and_report(
|
||||
room_id=self.room_id1, user_tok=self.other_user_tok,
|
||||
)
|
||||
for i in range(5):
|
||||
self._create_event_and_report(
|
||||
room_id=self.room_id2, user_tok=self.other_user_tok,
|
||||
)
|
||||
for i in range(5):
|
||||
self._create_event_and_report(
|
||||
room_id=self.room_id1, user_tok=self.admin_user_tok,
|
||||
)
|
||||
for i in range(5):
|
||||
self._create_event_and_report(
|
||||
room_id=self.room_id2, user_tok=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.url = "/_synapse/admin/v1/event_reports"
|
||||
|
||||
def test_requester_is_no_admin(self):
|
||||
"""
|
||||
If the user is not a server admin, an error 403 is returned.
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url, access_token=self.other_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
def test_default_success(self):
|
||||
"""
|
||||
Testing list of reported events
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 20)
|
||||
self.assertNotIn("next_token", channel.json_body)
|
||||
self._check_fields(channel.json_body["event_reports"])
|
||||
|
||||
def test_limit(self):
|
||||
"""
|
||||
Testing list of reported events with limit
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 5)
|
||||
self.assertEqual(channel.json_body["next_token"], 5)
|
||||
self._check_fields(channel.json_body["event_reports"])
|
||||
|
||||
def test_from(self):
|
||||
"""
|
||||
Testing list of reported events with a defined starting point (from)
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 15)
|
||||
self.assertNotIn("next_token", channel.json_body)
|
||||
self._check_fields(channel.json_body["event_reports"])
|
||||
|
||||
def test_limit_and_from(self):
|
||||
"""
|
||||
Testing list of reported events with a defined starting point and limit
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(channel.json_body["next_token"], 15)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 10)
|
||||
self._check_fields(channel.json_body["event_reports"])
|
||||
|
||||
def test_filter_room(self):
|
||||
"""
|
||||
Testing list of reported events with a filter of room
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "?room_id=%s" % self.room_id1,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 10)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 10)
|
||||
self.assertNotIn("next_token", channel.json_body)
|
||||
self._check_fields(channel.json_body["event_reports"])
|
||||
|
||||
for report in channel.json_body["event_reports"]:
|
||||
self.assertEqual(report["room_id"], self.room_id1)
|
||||
|
||||
def test_filter_user(self):
|
||||
"""
|
||||
Testing list of reported events with a filter of user
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "?user_id=%s" % self.other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 10)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 10)
|
||||
self.assertNotIn("next_token", channel.json_body)
|
||||
self._check_fields(channel.json_body["event_reports"])
|
||||
|
||||
for report in channel.json_body["event_reports"]:
|
||||
self.assertEqual(report["user_id"], self.other_user)
|
||||
|
||||
def test_filter_user_and_room(self):
|
||||
"""
|
||||
Testing list of reported events with a filter of user and room
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 5)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 5)
|
||||
self.assertNotIn("next_token", channel.json_body)
|
||||
self._check_fields(channel.json_body["event_reports"])
|
||||
|
||||
for report in channel.json_body["event_reports"]:
|
||||
self.assertEqual(report["user_id"], self.other_user)
|
||||
self.assertEqual(report["room_id"], self.room_id1)
|
||||
|
||||
def test_valid_search_order(self):
|
||||
"""
|
||||
Testing search order. Order by timestamps.
|
||||
"""
|
||||
|
||||
# fetch the most recent first, largest timestamp
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?dir=b", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 20)
|
||||
report = 1
|
||||
while report < len(channel.json_body["event_reports"]):
|
||||
self.assertGreaterEqual(
|
||||
channel.json_body["event_reports"][report - 1]["received_ts"],
|
||||
channel.json_body["event_reports"][report]["received_ts"],
|
||||
)
|
||||
report += 1
|
||||
|
||||
# fetch the oldest first, smallest timestamp
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?dir=f", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 20)
|
||||
report = 1
|
||||
while report < len(channel.json_body["event_reports"]):
|
||||
self.assertLessEqual(
|
||||
channel.json_body["event_reports"][report - 1]["received_ts"],
|
||||
channel.json_body["event_reports"][report]["received_ts"],
|
||||
)
|
||||
report += 1
|
||||
|
||||
def test_invalid_search_order(self):
|
||||
"""
|
||||
Testing that a invalid search order returns a 400
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
|
||||
|
||||
def test_limit_is_negative(self):
|
||||
"""
|
||||
Testing that a negative list parameter returns a 400
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
def test_from_is_negative(self):
|
||||
"""
|
||||
Testing that a negative from parameter returns a 400
|
||||
"""
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
def test_next_token(self):
|
||||
"""
|
||||
Testing that `next_token` appears at the right place
|
||||
"""
|
||||
|
||||
# `next_token` does not appear
|
||||
# Number of results is the number of entries
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 20)
|
||||
self.assertNotIn("next_token", channel.json_body)
|
||||
|
||||
# `next_token` does not appear
|
||||
# Number of max results is larger than the number of entries
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 20)
|
||||
self.assertNotIn("next_token", channel.json_body)
|
||||
|
||||
# `next_token` does appear
|
||||
# Number of max results is smaller than the number of entries
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 19)
|
||||
self.assertEqual(channel.json_body["next_token"], 19)
|
||||
|
||||
# Check
|
||||
# Set `from` to value of `next_token` for request remaining entries
|
||||
# `next_token` does not appear
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(channel.json_body["total"], 20)
|
||||
self.assertEqual(len(channel.json_body["event_reports"]), 1)
|
||||
self.assertNotIn("next_token", channel.json_body)
|
||||
|
||||
def _create_event_and_report(self, room_id, user_tok):
|
||||
"""Create and report events
|
||||
"""
|
||||
resp = self.helper.send(room_id, tok=user_tok)
|
||||
event_id = resp["event_id"]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"rooms/%s/report/%s" % (room_id, event_id),
|
||||
json.dumps({"score": -100, "reason": "this makes me sad"}),
|
||||
access_token=user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
||||
def _check_fields(self, content):
|
||||
"""Checks that all attributes are present in a event report
|
||||
"""
|
||||
for c in content:
|
||||
self.assertIn("id", c)
|
||||
self.assertIn("received_ts", c)
|
||||
self.assertIn("room_id", c)
|
||||
self.assertIn("event_id", c)
|
||||
self.assertIn("user_id", c)
|
||||
self.assertIn("reason", c)
|
||||
self.assertIn("content", c)
|
||||
self.assertIn("sender", c)
|
||||
self.assertIn("room_alias", c)
|
||||
self.assertIn("event_json", c)
|
||||
self.assertIn("score", c["content"])
|
||||
self.assertIn("reason", c["content"])
|
||||
self.assertIn("auth_events", c["event_json"])
|
||||
self.assertIn("type", c["event_json"])
|
||||
self.assertIn("room_id", c["event_json"])
|
||||
self.assertIn("sender", c["event_json"])
|
||||
self.assertIn("content", c["event_json"])
|
|
@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.render(request)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self._is_erased("@user:test", False)
|
||||
d = self.store.mark_user_erased("@user:test")
|
||||
self.assertIsNone(self.get_success(d))
|
||||
self._is_erased("@user:test", True)
|
||||
|
||||
# Attempt to reactivate the user (without a password).
|
||||
request, channel = self.make_request(
|
||||
|
@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
self._is_erased("@user:test", False)
|
||||
|
||||
def test_set_user_as_admin(self):
|
||||
"""
|
||||
|
@ -996,6 +1001,15 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
# Ensure they're still alive
|
||||
self.assertEqual(0, channel.json_body["deactivated"])
|
||||
|
||||
def _is_erased(self, user_id, expect):
|
||||
"""Assert that the user is erased or not
|
||||
"""
|
||||
d = self.store.is_user_erased(user_id)
|
||||
if expect:
|
||||
self.assertTrue(self.get_success(d))
|
||||
else:
|
||||
self.assertFalse(self.get_success(d))
|
||||
|
||||
|
||||
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
|
|
@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
"""
|
||||
)
|
||||
|
||||
def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
|
||||
def _create_id_generator(
|
||||
self, instance_name="master", writers=["master"]
|
||||
) -> MultiWriterIdGenerator:
|
||||
def _create(conn):
|
||||
return MultiWriterIdGenerator(
|
||||
conn,
|
||||
self.db_pool,
|
||||
stream_name="test_stream",
|
||||
instance_name=instance_name,
|
||||
table="foobar",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_id",
|
||||
sequence_name="foobar_seq",
|
||||
writers=writers,
|
||||
)
|
||||
|
||||
return self.get_success(self.db_pool.runWithConnection(_create))
|
||||
|
@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
|
||||
(instance_name,),
|
||||
)
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
|
||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
|
||||
""",
|
||||
(instance_name,),
|
||||
)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
||||
|
||||
|
@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
|
||||
)
|
||||
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||
""",
|
||||
(instance_name, stream_id, stream_id),
|
||||
)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
|
||||
|
||||
|
@ -111,7 +129,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
# advanced after we leave the context manager.
|
||||
|
||||
async def _get_next_async():
|
||||
with await id_gen.get_next() as stream_id:
|
||||
async with id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
|
@ -139,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
ctx3 = self.get_success(id_gen.get_next())
|
||||
ctx4 = self.get_success(id_gen.get_next())
|
||||
|
||||
s1 = ctx1.__enter__()
|
||||
s2 = ctx2.__enter__()
|
||||
s3 = ctx3.__enter__()
|
||||
s4 = ctx4.__enter__()
|
||||
s1 = self.get_success(ctx1.__aenter__())
|
||||
s2 = self.get_success(ctx2.__aenter__())
|
||||
s3 = self.get_success(ctx3.__aenter__())
|
||||
s4 = self.get_success(ctx4.__aenter__())
|
||||
|
||||
self.assertEqual(s1, 8)
|
||||
self.assertEqual(s2, 9)
|
||||
|
@ -152,22 +170,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
ctx2.__exit__(None, None, None)
|
||||
self.get_success(ctx2.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
ctx1.__exit__(None, None, None)
|
||||
self.get_success(ctx1.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
|
||||
|
||||
ctx4.__exit__(None, None, None)
|
||||
self.get_success(ctx4.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
|
||||
|
||||
ctx3.__exit__(None, None, None)
|
||||
self.get_success(ctx3.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 11})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
|
||||
|
@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
self._insert_rows("first", 3)
|
||||
self._insert_rows("second", 4)
|
||||
|
||||
first_id_gen = self._create_id_generator("first")
|
||||
second_id_gen = self._create_id_generator("second")
|
||||
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
||||
|
||||
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
|
||||
|
@ -190,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
# advanced after we leave the context manager.
|
||||
|
||||
async def _get_next_async():
|
||||
with await first_id_gen.get_next() as stream_id:
|
||||
async with first_id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(
|
||||
|
@ -208,7 +226,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
# stream ID
|
||||
|
||||
async def _get_next_async():
|
||||
with await second_id_gen.get_next() as stream_id:
|
||||
async with second_id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 9)
|
||||
|
||||
self.assertEqual(
|
||||
|
@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
self._insert_row_with_id("first", 3)
|
||||
self._insert_row_with_id("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("first")
|
||||
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||
|
||||
|
@ -300,14 +318,18 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
self._insert_row_with_id("first", 3)
|
||||
self._insert_row_with_id("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("first")
|
||||
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
||||
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
with self.get_success(id_gen.get_next()) as stream_id:
|
||||
self.assertEqual(stream_id, 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
async def _get_next_async():
|
||||
async with id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
|
@ -315,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
# `persisted_upto_position` in this case, then it will be correct in the
|
||||
# other cases that are tested above (since they'll hit the same code).
|
||||
|
||||
def test_restart_during_out_of_order_persistence(self):
|
||||
"""Test that restarting a process while another process is writing out
|
||||
of order updates are handled correctly.
|
||||
"""
|
||||
|
||||
# Prefill table with 7 rows written by 'master'
|
||||
self._insert_rows("master", 7)
|
||||
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Persist two rows at once
|
||||
ctx1 = self.get_success(id_gen.get_next())
|
||||
ctx2 = self.get_success(id_gen.get_next())
|
||||
|
||||
s1 = self.get_success(ctx1.__aenter__())
|
||||
s2 = self.get_success(ctx2.__aenter__())
|
||||
|
||||
self.assertEqual(s1, 8)
|
||||
self.assertEqual(s2, 9)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# We finish persisting the second row before restart
|
||||
self.get_success(ctx2.__aexit__(None, None, None))
|
||||
|
||||
# We simulate a restart of another worker by just creating a new ID gen.
|
||||
id_gen_worker = self._create_id_generator("worker")
|
||||
|
||||
# Restarted worker should not see the second persisted row
|
||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Now if we persist the first row then both instances should jump ahead
|
||||
# correctly.
|
||||
self.get_success(ctx1.__aexit__(None, None, None))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 9})
|
||||
id_gen_worker.advance("master", 9)
|
||||
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
|
||||
|
||||
def test_writer_config_change(self):
|
||||
"""Test that changing the writer config correctly works.
|
||||
"""
|
||||
|
||||
self._insert_row_with_id("first", 3)
|
||||
self._insert_row_with_id("second", 5)
|
||||
|
||||
# Initial config has two writers
|
||||
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
# New config removes one of the configs. Note that if the writer is
|
||||
# removed from config we assume that it has been shut down and has
|
||||
# finished persisting, hence why the persisted upto position is 5.
|
||||
id_gen_2 = self._create_id_generator("second", writers=["second"])
|
||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
|
||||
|
||||
# This config points to a single, previously unused writer.
|
||||
id_gen_3 = self._create_id_generator("third", writers=["third"])
|
||||
self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
|
||||
|
||||
# Check that we get a sane next stream ID with this new config.
|
||||
|
||||
async def _get_next_async():
|
||||
async with id_gen_3.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 6)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
|
||||
|
||||
|
||||
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
|
||||
|
@ -341,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
"""
|
||||
)
|
||||
|
||||
def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
|
||||
def _create_id_generator(
|
||||
self, instance_name="master", writers=["master"]
|
||||
) -> MultiWriterIdGenerator:
|
||||
def _create(conn):
|
||||
return MultiWriterIdGenerator(
|
||||
conn,
|
||||
self.db_pool,
|
||||
stream_name="test_stream",
|
||||
instance_name=instance_name,
|
||||
table="foobar",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_id",
|
||||
sequence_name="foobar_seq",
|
||||
writers=writers,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
|
@ -364,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
txn.execute(
|
||||
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
|
||||
)
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||
""",
|
||||
(instance_name, -stream_id, -stream_id),
|
||||
)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
|
||||
|
||||
|
@ -373,16 +480,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
"""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
with self.get_success(id_gen.get_next()) as stream_id:
|
||||
self._insert_row("master", stream_id)
|
||||
async def _get_next_async():
|
||||
async with id_gen.get_next() as stream_id:
|
||||
self._insert_row("master", stream_id)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": -1})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
|
||||
|
||||
with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
|
||||
for stream_id in stream_ids:
|
||||
self._insert_row("master", stream_id)
|
||||
async def _get_next_async2():
|
||||
async with id_gen.get_next_mult(3) as stream_ids:
|
||||
for stream_id in stream_ids:
|
||||
self._insert_row("master", stream_id)
|
||||
|
||||
self.get_success(_get_next_async2())
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": -4})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
|
||||
|
@ -399,21 +512,27 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||
"""Tests that having multiple instances that get advanced over
|
||||
federation works corretly.
|
||||
"""
|
||||
id_gen_1 = self._create_id_generator("first")
|
||||
id_gen_2 = self._create_id_generator("second")
|
||||
id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
|
||||
id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
|
||||
|
||||
with self.get_success(id_gen_1.get_next()) as stream_id:
|
||||
self._insert_row("first", stream_id)
|
||||
id_gen_2.advance("first", stream_id)
|
||||
async def _get_next_async():
|
||||
async with id_gen_1.get_next() as stream_id:
|
||||
self._insert_row("first", stream_id)
|
||||
id_gen_2.advance("first", stream_id)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
|
||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
|
||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
|
||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
|
||||
|
||||
with self.get_success(id_gen_2.get_next()) as stream_id:
|
||||
self._insert_row("second", stream_id)
|
||||
id_gen_1.advance("second", stream_id)
|
||||
async def _get_next_async2():
|
||||
async with id_gen_2.get_next() as stream_id:
|
||||
self._insert_row("second", stream_id)
|
||||
id_gen_1.advance("second", stream_id)
|
||||
|
||||
self.get_success(_get_next_async2())
|
||||
|
||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
|
||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|
||||
|
|
8
tox.ini
8
tox.ini
|
@ -2,13 +2,12 @@
|
|||
envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort
|
||||
|
||||
[base]
|
||||
extras = test
|
||||
deps =
|
||||
mock
|
||||
python-subunit
|
||||
junitxml
|
||||
coverage
|
||||
coverage-enable-subprocess
|
||||
parameterized
|
||||
|
||||
# cyptography 2.2 requires setuptools >= 18.5
|
||||
#
|
||||
|
@ -36,7 +35,7 @@ setenv =
|
|||
[testenv]
|
||||
deps =
|
||||
{[base]deps}
|
||||
extras = all
|
||||
extras = all, test
|
||||
|
||||
whitelist_externals =
|
||||
sh
|
||||
|
@ -84,7 +83,6 @@ deps =
|
|||
# Old automat version for Twisted
|
||||
Automat == 0.3.0
|
||||
|
||||
mock
|
||||
lxml
|
||||
coverage
|
||||
coverage-enable-subprocess
|
||||
|
@ -97,7 +95,7 @@ commands =
|
|||
/bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
|
||||
|
||||
# Install Synapse itself. This won't update any libraries.
|
||||
pip install -e .
|
||||
pip install -e ".[test]"
|
||||
|
||||
{envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
|
||||
|
||||
|
|
Loading…
Reference in New Issue