Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit fbcc8703dc

CHANGES.md (23 lines changed)
@@ -1,10 +1,27 @@
-Synapse 1.31.0rc1 (2021-03-30)
-==============================
+Synapse 1.31.0 (2021-04-06)
+===========================
 
-**Note:** As announced in v1.25.0, and in line with the deprecation policy for platform dependencies, this is the last release to support Python 3.5 and PostgreSQL 9.5. Future versions of Synapse will require Python 3.6+ and PostgreSQL 9.6+.
+**Note:** As announced in v1.25.0, and in line with the deprecation policy for platform dependencies, this is the last release to support Python 3.5 and PostgreSQL 9.5. Future versions of Synapse will require Python 3.6+ and PostgreSQL 9.6+, as per our [deprecation policy](docs/deprecation_policy.md).
 
 This is also the last release that the Synapse team will be publishing packages for Debian Stretch and Ubuntu Xenial.
 
+
+Improved Documentation
+----------------------
+
+- Add a document describing the deprecation policy for platform dependencies. ([\#9723](https://github.com/matrix-org/synapse/issues/9723))
+
+
+Internal Changes
+----------------
+
+- Revert using `dmypy run` in lint script. ([\#9720](https://github.com/matrix-org/synapse/issues/9720))
+- Pin flake8-bugbear's version. ([\#9734](https://github.com/matrix-org/synapse/issues/9734))
+
+
+Synapse 1.31.0rc1 (2021-03-30)
+==============================
 
 Features
 --------
@@ -38,6 +38,7 @@ There are 3 steps to follow under **Installation Instructions**.
 - [URL previews](#url-previews)
 - [Troubleshooting Installation](#troubleshooting-installation)
 
 
 ## Choosing your server name
 
 It is important to choose the name for your server before you install Synapse,
README.rst (16 lines changed)
@@ -314,6 +314,15 @@ Testing with SyTest is recommended for verifying that changes related to the
 Client-Server API are functioning correctly. See the `installation instructions
 <https://github.com/matrix-org/sytest#installing>`_ for details.
 
+
+Platform dependencies
+=====================
+
+Synapse uses a number of platform dependencies such as Python and PostgreSQL,
+and aims to follow supported upstream versions. See the
+`<docs/deprecation_policy.md>`_ document for more details.
+
+
 Troubleshooting
 ===============
 
@@ -384,7 +393,12 @@ massive excess of outgoing federation requests (see `discussion
 indicate that your server is also issuing far more outgoing federation
 requests than can be accounted for by your users' activity, this is a
 likely cause. The misbehavior can be worked around by setting
-``use_presence: false`` in the Synapse config file.
+the following in the Synapse config file:
+
+.. code-block:: yaml
+
+   presence:
+     enabled: false
 
 People can't accept room invitations from me
 --------------------------------------------
@@ -0,0 +1 @@
+Prevent `synapse_forward_extremities` and `synapse_excess_extremity_events` Prometheus metrics from initially reporting zero-values after startup.

@@ -0,0 +1 @@
+Add a Synapse module for routing presence updates between users.

@@ -0,0 +1 @@
+Include request information in structured logging output.

@@ -0,0 +1 @@
+Update `scripts-dev/complement.sh` to use a local checkout of Complement, allow running a subset of tests and have it use Synapse's Complement test blacklist.

@@ -0,0 +1 @@
+Improve Jaeger tracing for `to_device` messages.

@@ -0,0 +1 @@
+Add `order_by` to the admin API `GET /_synapse/admin/v2/users`. Contributed by @dklimpel.

@@ -0,0 +1 @@
+Replace the `room_invite_state_types` configuration setting with `room_prejoin_state`.

@@ -0,0 +1 @@
+Experimental Spaces support: include `m.room.create` in the room state sent with room-invites.

@@ -0,0 +1 @@
+Fix recently added ratelimits to correctly honour the application service `rate_limited` flag.

@@ -0,0 +1 @@
+Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.

@@ -0,0 +1 @@
+Replace deprecated `imp` module with successor `importlib`. Contributed by Cristina Muñoz.

@@ -0,0 +1 @@
+Make the allowed_local_3pids regex example in the sample config stricter.

@@ -0,0 +1 @@
+Fix longstanding bug which caused `duplicate key value violates unique constraint "remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"` errors.

@@ -0,0 +1 @@
+Add type hints to expiring cache.

@@ -0,0 +1 @@
+Convert various testcases to `HomeserverTestCase`.

@@ -0,0 +1 @@
+Start linting mypy with `no_implicit_optional`.

@@ -0,0 +1 @@
+Add missing type hints to federation handler and server.

@@ -0,0 +1 @@
+Check that a `ConfigError` is raised, rather than simply `Exception`, when appropriate in homeserver config file generation tests.
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.31.0) stable; urgency=medium
+
+  * New synapse release 1.31.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 06 Apr 2021 13:08:29 +0100
+
 matrix-synapse-py3 (1.30.1) stable; urgency=medium
 
   * New synapse release 1.30.1.
@@ -173,18 +173,10 @@ report_stats: False
 
 ## API Configuration ##
 
-room_invite_state_types:
-  - "m.room.join_rules"
-  - "m.room.canonical_alias"
-  - "m.room.avatar"
-  - "m.room.name"
-
 {% if SYNAPSE_APPSERVICES %}
 app_service_config_files:
 {% for appservice in SYNAPSE_APPSERVICES %}  - "{{ appservice }}"
 {% endfor %}
-{% else %}
-app_service_config_files: []
 {% endif %}
 
 macaroon_secret_key: "{{ SYNAPSE_MACAROON_SECRET_KEY }}"
@@ -111,35 +111,16 @@ List Accounts
 =============
 
 This API returns all local user accounts.
+By default, the response is ordered by ascending user ID.
 
-The api is::
+The API is::
 
     GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
 
 To use it, you will need to authenticate by providing an ``access_token`` for a
 server admin: see `README.rst <README.rst>`_.
 
-The parameter ``from`` is optional but used for pagination, denoting the
-offset in the returned results. This should be treated as an opaque value and
-not explicitly set to anything other than the return value of ``next_token``
-from a previous call.
-
-The parameter ``limit`` is optional but is used for pagination, denoting the
-maximum number of items to return in this call. Defaults to ``100``.
-
-The parameter ``user_id`` is optional and filters to only return users with user IDs
-that contain this value. This parameter is ignored when using the ``name`` parameter.
-
-The parameter ``name`` is optional and filters to only return users with user ID localparts
-**or** displaynames that contain this value.
-
-The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
-Defaults to ``true`` to include guest users.
-
-The parameter ``deactivated`` is optional and if ``true`` will **include** deactivated users.
-Defaults to ``false`` to exclude deactivated users.
-
-A JSON body is returned with the following shape:
+A response body like the following is returned:
 
 .. code:: json

@@ -175,6 +156,66 @@ with ``from`` set to the value of ``next_token``. This will return a new page.
 If the endpoint does not return a ``next_token`` then there are no more users
 to paginate through.
 
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - Is optional and filters to only return users with user IDs
+  that contain this value. This parameter is ignored when using the ``name`` parameter.
+- ``name`` - Is optional and filters to only return users with user ID localparts
+  **or** displaynames that contain this value.
+- ``guests`` - string representing a bool - Is optional and if ``false`` will **exclude** guest users.
+  Defaults to ``true`` to include guest users.
+- ``deactivated`` - string representing a bool - Is optional and if ``true`` will **include** deactivated users.
+  Defaults to ``false`` to exclude deactivated users.
+- ``limit`` - string representing a positive integer - Is optional but is used for pagination,
+  denoting the maximum number of items to return in this call. Defaults to ``100``.
+- ``from`` - string representing a positive integer - Is optional but used for pagination,
+  denoting the offset in the returned results. This should be treated as an opaque value and
+  not explicitly set to anything other than the return value of ``next_token`` from a previous call.
+  Defaults to ``0``.
+- ``order_by`` - The method by which to sort the returned list of users.
+  If the ordered field has duplicates, the second order is always by ascending ``name``,
+  which guarantees a stable ordering. Valid values are:
+
+  - ``name`` - Users are ordered alphabetically by ``name``. This is the default.
+  - ``is_guest`` - Users are ordered by ``is_guest`` status.
+  - ``admin`` - Users are ordered by ``admin`` status.
+  - ``user_type`` - Users are ordered alphabetically by ``user_type``.
+  - ``deactivated`` - Users are ordered by ``deactivated`` status.
+  - ``shadow_banned`` - Users are ordered by ``shadow_banned`` status.
+  - ``displayname`` - Users are ordered alphabetically by ``displayname``.
+  - ``avatar_url`` - Users are ordered alphabetically by avatar URL.
+
+- ``dir`` - Direction of user order. Either ``f`` for forwards or ``b`` for backwards.
+  Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``.
+
+Caution. The database only has indexes on the columns ``name`` and ``created_ts``.
+This means that if a different sort order is used (``is_guest``, ``admin``,
+``user_type``, ``deactivated``, ``shadow_banned``, ``avatar_url`` or ``displayname``),
+this can cause a large load on the database, especially for large environments.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``users`` - An array of objects, each containing information about a user.
+  User objects contain the following fields:
+
+  - ``name`` - string - Fully-qualified user ID (ex. ``@user:server.com``).
+  - ``is_guest`` - bool - Status if that user is a guest account.
+  - ``admin`` - bool - Status if that user is a server administrator.
+  - ``user_type`` - string - Type of the user. Normal users are type ``None``.
+    This allows user type specific behaviour. There are also types ``support`` and ``bot``.
+  - ``deactivated`` - bool - Status if that user has been marked as deactivated.
+  - ``shadow_banned`` - bool - Status if that user has been marked as shadow banned.
+  - ``displayname`` - string - The user's display name if they have set one.
+  - ``avatar_url`` - string - The user's avatar URL if they have set one.
+
+- ``next_token`` - string representing a positive integer - Indication for pagination. See above.
+- ``total`` - integer - Total number of users.
+
 
 Query current sessions for a user
 =================================
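For orientation, here is a minimal client-side sketch of the listing and pagination behaviour documented above. The endpoint, query parameters and `next_token` semantics come from the docs in this diff; the homeserver URL and admin access token are placeholders, and the use of `requests` is an arbitrary choice.

```python
import requests

BASE_URL = "https://homeserver.example.com"  # placeholder homeserver
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder token

def list_all_users(order_by: str = "name", direction: str = "f") -> list:
    """Page through GET /_synapse/admin/v2/users, following next_token."""
    params = {"limit": "100", "from": "0", "order_by": order_by, "dir": direction}
    users = []
    while True:
        resp = requests.get(
            BASE_URL + "/_synapse/admin/v2/users", headers=HEADERS, params=params
        )
        resp.raise_for_status()
        body = resp.json()
        users.extend(body["users"])
        if "next_token" not in body:
            return users
        # next_token is opaque: feed it straight back in as `from`
        params["from"] = body["next_token"]
```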
@@ -128,6 +128,9 @@ Some guidelines follow:
   will be if no sub-options are enabled).
 - Lines should be wrapped at 80 characters.
 - Use two-space indents.
+- `true` and `false` are spelt thus (as opposed to `True`, etc.)
+- Use single quotes (`'`) rather than double-quotes (`"`) or backticks
+  (`` ` ``) to refer to configuration options.
 
 Example:
@@ -0,0 +1,33 @@
+Deprecation Policy for Platform Dependencies
+============================================
+
+Synapse has a number of platform dependencies, including Python and PostgreSQL.
+This document outlines the policy towards which versions we support, and when we
+drop support for versions in the future.
+
+
+Policy
+------
+
+Synapse follows the upstream support life cycles for Python and PostgreSQL,
+i.e. when a version reaches End of Life Synapse will withdraw support for that
+version in future releases.
+
+Details on the upstream support life cycles for Python and PostgreSQL are
+documented at https://endoflife.date/python and
+https://endoflife.date/postgresql.
+
+
+Context
+-------
+
+It is important for system admins to have a clear understanding of the platform
+requirements of Synapse and its deprecation policies so that they can
+effectively plan upgrading their infrastructure ahead of time. This is
+especially important in contexts where upgrading the infrastructure requires
+auditing and approval from a security team, or where otherwise upgrading is a
+long process.
+
+By following the upstream support life cycles Synapse can ensure that its
+dependencies continue to get security patches, while not requiring system admins
+to constantly update their platform dependencies to the latest versions.
@@ -0,0 +1,235 @@
+# Presence Router Module
+
+Synapse supports configuring a module that can specify additional users
+(local or remote) that should receive certain presence updates from local
+users.
+
+Note that routing presence via Application Service transactions is not
+currently supported.
+
+The presence routing module is implemented as a Python class, which will
+be imported by the running Synapse.
+
+## Python Presence Router Class
+
+The Python class is instantiated with two objects:
+
+* A configuration object of some type (see below).
+* An instance of `synapse.module_api.ModuleApi`.
+
+It then implements methods related to presence routing.
+
+Note that one method of `ModuleApi` that may be useful is:
+
+```python
+async def ModuleApi.send_local_online_presence_to(users: Iterable[str]) -> None
+```
+
+which can be given a list of local or remote MXIDs to broadcast known, online user
+presence to (for those users that the receiving user is considered interested in).
+It does not include state for users who are currently offline, and it can only be
+called on workers that support sending federation.
+
+### Module structure
+
+Below is a list of possible methods that can be implemented, and whether they are
+required.
+
+#### `parse_config`
+
+```python
+def parse_config(config_dict: dict) -> Any
+```
+
+**Required.** A static method that is passed a dictionary of config options, and
+should return a validated config object. This method is described further in
+[Configuration](#configuration).
+
+#### `get_users_for_states`
+
+```python
+async def get_users_for_states(
+    self,
+    state_updates: Iterable[UserPresenceState],
+) -> Dict[str, Set[UserPresenceState]]:
+```
+
+**Required.** An asynchronous method that is passed an iterable of user presence
+state. This method can determine whether a given presence update should be sent to certain
+users. It does this by returning a dictionary with keys representing local or remote
+Matrix User IDs, and values being a python set
+of `synapse.handlers.presence.UserPresenceState` instances.
+
+Synapse will then attempt to send the specified presence updates to each user when
+possible.
+
+#### `get_interested_users`
+
+```python
+async def get_interested_users(self, user_id: str) -> Union[Set[str], str]
+```
+
+**Required.** An asynchronous method that is passed a single Matrix User ID. This
+method is expected to return the users that the passed in user may be interested in the
+presence of. Returned users may be local or remote. The presence routed as a result of
+what this method returns is sent in addition to the updates already sent between users
+that share a room together. Presence updates are deduplicated.
+
+This method should return a python set of Matrix User IDs, or the object
+`synapse.events.presence_router.PresenceRouter.ALL_USERS` to indicate that the passed
+user should receive presence information for *all* known users.
+
+For clarity, if the user `@alice:example.org` is passed to this method, and the Set
+`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice
+should receive presence updates sent by Bob and Charlie, regardless of whether these
+users share a room.
+
+### Example
+
+Below is an example implementation of a presence router class.
+
+```python
+from typing import Dict, Iterable, List, Set, Union
+from synapse.events.presence_router import PresenceRouter
+from synapse.handlers.presence import UserPresenceState
+from synapse.module_api import ModuleApi
+
+class PresenceRouterConfig:
+    def __init__(self):
+        # Config options with their defaults
+        # A list of users to always send all user presence updates to
+        self.always_send_to_users = []  # type: List[str]
+
+        # A list of users to ignore presence updates for. Does not affect
+        # shared-room presence relationships
+        self.blacklisted_users = []  # type: List[str]
+
+class ExamplePresenceRouter:
+    """An example implementation of synapse.presence_router.PresenceRouter.
+    Supports routing all presence to a configured set of users, or a subset
+    of presence from certain users to members of certain rooms.
+
+    Args:
+        config: A configuration object.
+        module_api: An instance of Synapse's ModuleApi.
+    """
+    def __init__(self, config: PresenceRouterConfig, module_api: ModuleApi):
+        self._config = config
+        self._module_api = module_api
+
+    @staticmethod
+    def parse_config(config_dict: dict) -> PresenceRouterConfig:
+        """Parse a configuration dictionary from the homeserver config, do
+        some validation and return a typed PresenceRouterConfig.
+
+        Args:
+            config_dict: The configuration dictionary.
+
+        Returns:
+            A validated config object.
+        """
+        # Initialise a typed config object
+        config = PresenceRouterConfig()
+        always_send_to_users = config_dict.get("always_send_to_users")
+        blacklisted_users = config_dict.get("blacklisted_users")
+
+        # Do some validation of config options... otherwise raise a
+        # synapse.config.ConfigError.
+        config.always_send_to_users = always_send_to_users
+        config.blacklisted_users = blacklisted_users
+
+        return config
+
+    async def get_users_for_states(
+        self,
+        state_updates: Iterable[UserPresenceState],
+    ) -> Dict[str, Set[UserPresenceState]]:
+        """Given an iterable of user presence updates, determine where each one
+        needs to go. Returned results will not affect presence updates that are
+        sent between users who share a room.
+
+        Args:
+            state_updates: An iterable of user presence state updates.
+
+        Returns:
+          A dictionary of user_id -> set of UserPresenceState that the user should
+          receive.
+        """
+        destination_users = {}  # type: Dict[str, Set[UserPresenceState]]
+
+        # Ignore any updates for blacklisted users
+        desired_updates = set()
+        for update in state_updates:
+            if update.state_key not in self._config.blacklisted_users:
+                desired_updates.add(update)
+
+        # Send all presence updates to specific users
+        for user_id in self._config.always_send_to_users:
+            destination_users[user_id] = desired_updates
+
+        return destination_users
+
+    async def get_interested_users(
+        self,
+        user_id: str,
+    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+        """
+        Retrieve a list of users that `user_id` is interested in receiving the
+        presence of. This will be in addition to those they share a room with.
+        Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
+        that this user should receive all incoming local and remote presence updates.
+
+        Note that this method will only be called for local users.
+
+        Args:
+          user_id: A user requesting presence updates.
+
+        Returns:
+          A set of user IDs to return additional presence updates for, or
+          PresenceRouter.ALL_USERS to return presence updates for all other users.
+        """
+        if user_id in self._config.always_send_to_users:
+            return PresenceRouter.ALL_USERS
+
+        return set()
+```
+
+#### A note on `get_users_for_states` and `get_interested_users`
+
+Both of these methods are effectively two different sides of the same coin. The logic
+regarding which users should receive updates for other users should be the same
+between them.
+
+`get_users_for_states` is called when presence updates come in from either federation
+or local users, and is used to either direct local presence to remote users, or to
+wake up the sync streams of local users to collect remote presence.
+
+In contrast, `get_interested_users` is used to determine the users that presence should
+be fetched for when a local user is syncing. This presence is then retrieved, before
+being fed through `get_users_for_states` once again, with only the syncing user's
+routing information pulled from the resulting dictionary.
+
+Their routing logic should thus line up, else you may run into unintended behaviour.
+
+## Configuration
+
+Once you've crafted your module and installed it into the same Python environment as
+Synapse, amend your homeserver config file with the following.
+
+```yaml
+presence:
+  routing_module:
+    module: my_module.ExamplePresenceRouter
+    config:
+      # Any configuration options for your module. The below is an example
+      # of setting options for ExamplePresenceRouter.
+      always_send_to_users: ["@presence_gobbler:example.org"]
+      blacklisted_users:
+        - "@alice:example.com"
+        - "@bob:example.com"
+      ...
+```
+
+The contents of `config` will be passed as a Python dictionary to the static
+`parse_config` method of your class. The object returned by this method will
+then be passed to the `__init__` method of your module as `config`.
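One thing the example class in the new document does not exercise is `ModuleApi.send_local_online_presence_to`, described earlier in it. A hedged sketch of how a module might call it follows; the method name on the router class is hypothetical, and per the docs it only broadcasts presence of currently-online local users and must run on a worker that can send federation.

```python
async def on_new_interested_user(self, user_id: str) -> None:
    # Hypothetical helper on ExamplePresenceRouter: push the current online
    # presence of local users to a newly-interested (possibly remote) user.
    await self._module_api.send_local_online_presence_to([user_id])
```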
@@ -82,9 +82,28 @@ pid_file: DATADIR/homeserver.pid
 #
 #soft_file_limit: 0
 
-# Set to false to disable presence tracking on this homeserver.
+# Presence tracking allows users to see the state (e.g. online/offline)
+# of other local and remote users.
 #
-#use_presence: false
+presence:
+  # Uncomment to disable presence tracking on this homeserver. This option
+  # replaces the previous top-level 'use_presence' option.
+  #
+  #enabled: false
+
+  # Presence routers are third-party modules that can specify additional logic
+  # to where presence updates from users are routed.
+  #
+  presence_router:
+    # The custom module's class. Uncomment to use a custom presence router module.
+    #
+    #module: "my_custom_router.PresenceRouter"
+
+    # Configuration options of the custom module. Refer to your module's
+    # documentation for available options.
+    #
+    #config:
+    #  example_option: 'something'
 
 # Whether to require authentication to retrieve profile data (avatars,
 # display names) of other users through the client API. Defaults to
@@ -1246,9 +1265,9 @@ account_validity:
 #
 #allowed_local_3pids:
 #  - medium: email
-#    pattern: '.*@matrix\.org'
+#    pattern: '^[^@]+@matrix\.org$'
 #  - medium: email
-#    pattern: '.*@vector\.im'
+#    pattern: '^[^@]+@vector\.im$'
 #  - medium: msisdn
 #    pattern: '\+44'
@@ -1451,14 +1470,31 @@ metrics_flags:
 
 ## API Configuration ##
 
-# A list of event types that will be included in the room_invite_state
+# Controls for the state that is shared with users who receive an invite
+# to a room
 #
-#room_invite_state_types:
-#  - "m.room.join_rules"
-#  - "m.room.canonical_alias"
-#  - "m.room.avatar"
-#  - "m.room.encryption"
-#  - "m.room.name"
+room_prejoin_state:
+   # By default, the following state event types are shared with users who
+   # receive invites to the room:
+   #
+   # - m.room.join_rules
+   # - m.room.canonical_alias
+   # - m.room.avatar
+   # - m.room.encryption
+   # - m.room.name
+   #
+   # Uncomment the following to disable these defaults (so that only the event
+   # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+   #
+   #disable_default_event_types: true
+
+   # Additional state event types to share with users when they are invited
+   # to a room.
+   #
+   # By default, this list is empty (so only the default event types are shared).
+   #
+   #additional_event_types:
+   #  - org.example.custom.event.type
 
 
 # A list of application service config files to use
mypy.ini (1 line changed)

@@ -8,6 +8,7 @@ show_traceback = True
 mypy_path = stubs
 warn_unreachable = True
 local_partial_types = True
+no_implicit_optional = True
 
 # To find all folders that pass mypy you run:
 #
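For context, `no_implicit_optional` makes mypy reject the once-common shorthand of giving a non-`Optional` parameter a `None` default. A minimal illustration (not from this diff):

```python
from typing import Optional

# Rejected with no_implicit_optional = True: annotation says `int`, default is None.
#     def delay(seconds: int = None) -> None: ...

# Accepted: the Optional is spelled out explicitly.
def delay(seconds: Optional[int] = None) -> None:
    if seconds is not None:
        print(f"waiting {seconds}s")
```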
@@ -1,22 +1,49 @@
-#! /bin/bash -eu
+#!/usr/bin/env bash
 # This script is designed for developers who want to test their code
 # against Complement.
 #
 # It makes a Synapse image which represents the current checkout,
-# then downloads Complement and runs it with that image.
+# builds a synapse-complement image on top, then runs tests with it.
+#
+# By default the script will fetch the latest Complement master branch and
+# run tests with that. This can be overridden to use a custom Complement
+# checkout by setting the COMPLEMENT_DIR environment variable to the
+# filepath of a local Complement checkout.
+#
+# A regular expression of test method names can be supplied as the first
+# argument to the script. Complement will then only run those tests. If
+# no regex is supplied, all tests are run. For example;
+#
+# ./complement.sh "TestOutboundFederation(Profile|Send)"
+#
+
+# Exit if a line returns a non-zero exit code
+set -e
 
+# Change to the repository root
 cd "$(dirname $0)/.."
 
-# Build the base Synapse image from the local checkout
-docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
-
-# Download Complement
-wget -N https://github.com/matrix-org/complement/archive/master.tar.gz
-tar -xzf master.tar.gz
-cd complement-master
+# Check for a user-specified Complement checkout
+if [[ -z "$COMPLEMENT_DIR" ]]; then
+  echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..."
+  wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz
+  tar -xzf master.tar.gz
+  COMPLEMENT_DIR=complement-master
+  echo "Checkout available at 'complement-master'"
+fi
 
-# Build the Synapse image from Complement, based on the above image we just built
-docker build -t complement-synapse -f dockerfiles/Synapse.Dockerfile ./dockerfiles
+# Build the base Synapse image from the local checkout
+docker build -t matrixdotorg/synapse -f docker/Dockerfile .
 
-# Run the tests on the resulting image!
-COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -count=1 ./tests
+# Build the Synapse monolith image from Complement, based on the above image we just built
+docker build -t complement-synapse -f "$COMPLEMENT_DIR/dockerfiles/Synapse.Dockerfile" "$COMPLEMENT_DIR/dockerfiles"
+
+cd "$COMPLEMENT_DIR"
+
+EXTRA_COMPLEMENT_ARGS=""
+if [[ -n "$1" ]]; then
+  # A test name regex has been set, supply it to Complement
+  EXTRA_COMPLEMENT_ARGS+="-run $1 "
+fi
+
+# Run the tests!
+COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist -count=1 $EXTRA_COMPLEMENT_ARGS ./tests
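As a usage note (paths illustrative): `COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh` runs the suite against a local Complement checkout, while `./scripts-dev/complement.sh "TestOutboundFederation(Profile|Send)"` fetches master and runs only the tests matching the supplied regex.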
@@ -95,4 +95,4 @@ isort "${files[@]}"
 python3 -m black "${files[@]}"
 ./scripts-dev/config-lint.sh
 flake8 "${files[@]}"
-dmypy run
+mypy
setup.py (2 lines changed)

@@ -99,7 +99,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
     "isort==5.7.0",
     "black==20.8b1",
     "flake8-comprehensions",
-    "flake8-bugbear",
+    "flake8-bugbear==21.3.2",
     "flake8",
 ]
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.31.0rc1"
+__version__ = "1.31.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
@@ -59,6 +59,8 @@ class JoinRules:
     KNOCK = "knock"
     INVITE = "invite"
     PRIVATE = "private"
+    # As defined for MSC3083.
+    MSC3083_RESTRICTED = "restricted"
 
 
 class LoginType:
@@ -17,6 +17,7 @@ from collections import OrderedDict
 from typing import Hashable, Optional, Tuple
 
 from synapse.api.errors import LimitExceededError
+from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
 from synapse.util import Clock
 
@@ -31,10 +32,13 @@ class Ratelimiter:
         burst_count: How many actions that can be performed before being limited.
     """
 
-    def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
+    def __init__(
+        self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
+    ):
         self.clock = clock
         self.rate_hz = rate_hz
         self.burst_count = burst_count
+        self.store = store
 
         # A ordered dictionary keeping track of actions, when they were last
         # performed and how often. Each entry is a mapping from a key of arbitrary type
@@ -46,45 +50,10 @@ class Ratelimiter:
             OrderedDict()
         )  # type: OrderedDict[Hashable, Tuple[float, int, float]]
 
-    def can_requester_do_action(
-        self,
-        requester: Requester,
-        rate_hz: Optional[float] = None,
-        burst_count: Optional[int] = None,
-        update: bool = True,
-        _time_now_s: Optional[int] = None,
-    ) -> Tuple[bool, float]:
-        """Can the requester perform the action?
-
-        Args:
-            requester: The requester to key off when rate limiting. The user property
-                will be used.
-            rate_hz: The long term number of actions that can be performed in a second.
-                Overrides the value set during instantiation if set.
-            burst_count: How many actions that can be performed before being limited.
-                Overrides the value set during instantiation if set.
-            update: Whether to count this check as performing the action
-            _time_now_s: The current time. Optional, defaults to the current time according
-                to self.clock. Only used by tests.
-
-        Returns:
-            A tuple containing:
-                * A bool indicating if they can perform the action now
-                * The reactor timestamp for when the action can be performed next.
-                    -1 if rate_hz is less than or equal to zero
-        """
-        # Disable rate limiting of users belonging to any AS that is configured
-        # not to be rate limited in its registration file (rate_limited: true|false).
-        if requester.app_service and not requester.app_service.is_rate_limited():
-            return True, -1.0
-
-        return self.can_do_action(
-            requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
-        )
-
-    def can_do_action(
-        self,
-        key: Hashable,
+    async def can_do_action(
+        self,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
         rate_hz: Optional[float] = None,
         burst_count: Optional[int] = None,
         update: bool = True,
@@ -92,9 +61,16 @@ class Ratelimiter:
     ) -> Tuple[bool, float]:
         """Can the entity (e.g. user or IP address) perform the action?
 
+        Checks if the user has ratelimiting disabled in the database by looking
+        for null/zero values in the `ratelimit_override` table. (Non-zero
+        values aren't honoured, as they're specific to the event sending
+        ratelimiter, rather than all ratelimiters)
+
         Args:
-            key: The key we should use when rate limiting. Can be a user ID
-                (when sending events), an IP address, etc.
+            requester: The requester that is doing the action, if any. Used to check
+                if the user has ratelimits disabled in the database.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
             rate_hz: The long term number of actions that can be performed in a second.
                 Overrides the value set during instantiation if set.
             burst_count: How many actions that can be performed before being limited.
@@ -109,6 +85,30 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                     -1 if rate_hz is less than or equal to zero
         """
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+
+        if requester:
+            # Disable rate limiting of users belonging to any AS that is configured
+            # not to be rate limited in its registration file (rate_limited: true|false).
+            if requester.app_service and not requester.app_service.is_rate_limited():
+                return True, -1.0
+
+            # Check if ratelimiting has been disabled for the user.
+            #
+            # Note that we don't use the returned rate/burst count, as the table
+            # is specifically for the event sending ratelimiter. Instead, we
+            # only use it to (somewhat cheekily) infer whether the user should
+            # be subject to any rate limiting or not.
+            override = await self.store.get_ratelimit_for_user(
+                requester.authenticated_entity
+            )
+            if override and not override.messages_per_second:
+                return True, -1.0
+
         # Override default values if set
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
         rate_hz = rate_hz if rate_hz is not None else self.rate_hz
@@ -175,9 +175,10 @@ class Ratelimiter:
         else:
             del self.actions[key]
 
-    def ratelimit(
+    async def ratelimit(
         self,
-        key: Hashable,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
         rate_hz: Optional[float] = None,
         burst_count: Optional[int] = None,
         update: bool = True,
@@ -185,8 +186,16 @@ class Ratelimiter:
     ):
         """Checks if an action can be performed. If not, raises a LimitExceededError
 
+        Checks if the user has ratelimiting disabled in the database by looking
+        for null/zero values in the `ratelimit_override` table. (Non-zero
+        values aren't honoured, as they're specific to the event sending
+        ratelimiter, rather than all ratelimiters)
+
         Args:
-            key: An arbitrary key used to classify an action
+            requester: The requester that is doing the action, if any. Used to check
+                if the user has ratelimits disabled.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
             rate_hz: The long term number of actions that can be performed in a second.
                 Overrides the value set during instantiation if set.
             burst_count: How many actions that can be performed before being limited.
@@ -201,7 +210,8 @@ class Ratelimiter:
         """
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
 
-        allowed, time_allowed = self.can_do_action(
+        allowed, time_allowed = await self.can_do_action(
+            requester,
             key,
             rate_hz=rate_hz,
             burst_count=burst_count,
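To make the new calling convention concrete, here is a hedged sketch of how a call site changes with this diff. The signatures are taken from the hunks above; the `limiter` and `requester` objects are assumed to be constructed elsewhere as usual.

```python
# Before: synchronous, keyed explicitly on a string.
#     allowed, time_allowed = limiter.can_do_action(requester.user.to_string())

# After: asynchronous and requester-aware. With key=None the limiter falls
# back to the requester's user ID, and app-service requesters or per-user
# database overrides can short-circuit to (True, -1.0) before any
# token-bucket accounting happens.
async def check(limiter, requester) -> bool:
    allowed, time_allowed = await limiter.can_do_action(requester, key=None)
    return allowed
```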
@@ -57,7 +57,7 @@ class RoomVersion:
     state_res = attr.ib(type=int)  # one of the StateResolutionVersions
     enforce_key_validity = attr.ib(type=bool)
 
-    # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
+    # Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool)
     # Strictly enforce canonicaljson, do not allow:
     # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@@ -69,6 +69,8 @@ class RoomVersion:
     limit_notifications_power_levels = attr.ib(type=bool)
     # MSC2174/MSC2176: Apply updated redaction rules algorithm.
     msc2176_redaction_rules = attr.ib(type=bool)
+    # MSC3083: Support the 'restricted' join_rule.
+    msc3083_join_rules = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -82,6 +84,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V2 = RoomVersion(
         "2",
@@ -93,6 +96,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V3 = RoomVersion(
         "3",
@@ -104,6 +108,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V4 = RoomVersion(
         "4",
@@ -115,6 +120,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
    )
     V5 = RoomVersion(
         "5",
@@ -126,6 +132,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     V6 = RoomVersion(
         "6",
@@ -137,6 +144,7 @@ class RoomVersions:
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
         msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
@@ -148,6 +156,19 @@ class RoomVersions:
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
         msc2176_redaction_rules=True,
+        msc3083_join_rules=False,
+    )
+    MSC3083 = RoomVersion(
+        "org.matrix.msc3083",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        msc3083_join_rules=True,
     )
 
 
@@ -162,4 +183,5 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V6,
         RoomVersions.MSC2176,
     )
+    # Note that we do not include MSC3083 here unless it is enabled in the config.
 }  # type: Dict[str, RoomVersion]
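A brief illustrative note on how such a capability flag is meant to be consumed: callers gate behaviour on `room_version.msc3083_join_rules` rather than comparing version strings. The function below is a sketch under that assumption, not Synapse's actual join-rules validation.

```python
def check_join_rule_supported(room_version, join_rule: str) -> None:
    # "restricted" (JoinRules.MSC3083_RESTRICTED above) is only meaningful in
    # room versions that opt in via the msc3083_join_rules flag.
    if join_rule == "restricted" and not room_version.msc3083_join_rules:
        raise ValueError("restricted join rule unsupported in this room version")
```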
@@ -281,6 +281,7 @@ class GenericWorkerPresence(BasePresenceHandler):
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
 
+        self.presence_router = hs.get_presence_router()
         self._presence_enabled = hs.config.use_presence
 
         # The number of ongoing syncs on this process, by user id.
@@ -395,7 +396,7 @@ class GenericWorkerPresence(BasePresenceHandler):
         return _user_syncing()
 
     async def notify_from_replication(self, states, stream_id):
-        parties = await get_interested_parties(self.store, states)
+        parties = await get_interested_parties(self.store, self.presence_router, states)
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
@@ -1,4 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,38 +12,131 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.constants import EventTypes
-
-from ._base import Config
+import logging
+from typing import Iterable
+
+from synapse.api.constants import EventTypes
+from synapse.config._base import Config, ConfigError
+from synapse.config._util import validate_config
+from synapse.types import JsonDict
+
+logger = logging.getLogger(__name__)
 
 
 class ApiConfig(Config):
     section = "api"
 
-    def read_config(self, config, **kwargs):
-        self.room_invite_state_types = config.get(
-            "room_invite_state_types",
-            [
-                EventTypes.JoinRules,
-                EventTypes.CanonicalAlias,
-                EventTypes.RoomAvatar,
-                EventTypes.RoomEncryption,
-                EventTypes.Name,
-            ],
-        )
-
-    def generate_config_section(cls, **kwargs):
-        return """\
-        ## API Configuration ##
-
-        # A list of event types that will be included in the room_invite_state
-        #
-        #room_invite_state_types:
-        #  - "{JoinRules}"
-        #  - "{CanonicalAlias}"
-        #  - "{RoomAvatar}"
-        #  - "{RoomEncryption}"
-        #  - "{Name}"
-        """.format(
-            **vars(EventTypes)
-        )
+    def read_config(self, config: JsonDict, **kwargs):
+        validate_config(_MAIN_SCHEMA, config, ())
+        self.room_prejoin_state = list(self._get_prejoin_state_types(config))
+
+    def generate_config_section(cls, **kwargs) -> str:
+        formatted_default_state_types = "\n".join(
+            "           # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
+        )
+
+        return """\
+        ## API Configuration ##
+
+        # Controls for the state that is shared with users who receive an invite
+        # to a room
+        #
+        room_prejoin_state:
+           # By default, the following state event types are shared with users who
+           # receive invites to the room:
+           #
+%(formatted_default_state_types)s
+           #
+           # Uncomment the following to disable these defaults (so that only the event
+           # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+           #
+           #disable_default_event_types: true
+
+           # Additional state event types to share with users when they are invited
+           # to a room.
+           #
+           # By default, this list is empty (so only the default event types are shared).
+           #
+           #additional_event_types:
+           #  - org.example.custom.event.type
+        """ % {
+            "formatted_default_state_types": formatted_default_state_types
+        }
+
+    def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
+        """Get the event types to include in the prejoin state
+
+        Parses the config and returns an iterable of the event types to be included.
+        """
+        room_prejoin_state_config = config.get("room_prejoin_state") or {}
+
+        # backwards-compatibility support for room_invite_state_types
+        if "room_invite_state_types" in config:
+            # if both "room_invite_state_types" and "room_prejoin_state" are set, then
+            # we don't really know what to do.
+            if room_prejoin_state_config:
+                raise ConfigError(
+                    "Can't specify both 'room_invite_state_types' and 'room_prejoin_state' "
+                    "in config"
+                )
+
+            logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
+
+            yield from config["room_invite_state_types"]
+            return
+
+        if not room_prejoin_state_config.get("disable_default_event_types"):
+            yield from _DEFAULT_PREJOIN_STATE_TYPES
+
+        if self.spaces_enabled:
+            # MSC1772 suggests adding m.room.create to the prejoin state
+            yield EventTypes.Create
+
+        yield from room_prejoin_state_config.get("additional_event_types", [])
+
+
+_ROOM_INVITE_STATE_TYPES_WARNING = """\
+WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
+and replaced with 'room_prejoin_state'. New features may not work correctly
+unless 'room_invite_state_types' is removed. See the sample configuration file for
+details of 'room_prejoin_state'.
+--------------------------------------------------------------------------------
+"""
+
+_DEFAULT_PREJOIN_STATE_TYPES = [
+    EventTypes.JoinRules,
+    EventTypes.CanonicalAlias,
+    EventTypes.RoomAvatar,
+    EventTypes.RoomEncryption,
+    EventTypes.Name,
+]
+
+
+# room_prejoin_state can either be None (as it is in the default config), or
+# an object containing other config settings
+_ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
+    "oneOf": [
+        {
+            "type": "object",
+            "properties": {
+                "disable_default_event_types": {"type": "boolean"},
+                "additional_event_types": {
+                    "type": "array",
+                    "items": {"type": "string"},
+                },
+            },
+        },
+        {"type": "null"},
+    ]
+}
+
+# the legacy room_invite_state_types setting
+_ROOM_INVITE_STATE_TYPES_SCHEMA = {"type": "array", "items": {"type": "string"}}
+
+_MAIN_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA,
+        "room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA,
+    },
+}
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||||
from synapse.config._base import Config
|
from synapse.config._base import Config
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -27,7 +28,11 @@ class ExperimentalConfig(Config):
|
||||||
|
|
||||||
# MSC2858 (multiple SSO identity providers)
|
# MSC2858 (multiple SSO identity providers)
|
||||||
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
|
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
|
||||||
# Spaces (MSC1772, MSC2946, etc)
|
|
||||||
|
# Spaces (MSC1772, MSC2946, MSC3083, etc)
|
||||||
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
|
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
|
||||||
|
if self.spaces_enabled:
|
||||||
|
KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
|
||||||
|
|
||||||
# MSC3026 (busy presence state)
|
# MSC3026 (busy presence state)
|
||||||
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
|
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
|
||||||
|
|
|
@ -298,9 +298,9 @@ class RegistrationConfig(Config):
|
||||||
#
|
#
|
||||||
#allowed_local_3pids:
|
#allowed_local_3pids:
|
||||||
# - medium: email
|
# - medium: email
|
||||||
# pattern: '.*@matrix\\.org'
|
# pattern: '^[^@]+@matrix\\.org$'
|
||||||
# - medium: email
|
# - medium: email
|
||||||
# pattern: '.*@vector\\.im'
|
# pattern: '^[^@]+@vector\\.im$'
|
||||||
# - medium: msisdn
|
# - medium: msisdn
|
||||||
# pattern: '\\+44'
|
# pattern: '\\+44'
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ import yaml
|
||||||
from netaddr import AddrFormatError, IPNetwork, IPSet
|
from netaddr import AddrFormatError, IPNetwork, IPSet
|
||||||
|
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
|
from synapse.util.module_loader import load_module
|
||||||
from synapse.util.stringutils import parse_and_validate_server_name
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
|
@ -238,8 +239,21 @@ class ServerConfig(Config):
|
||||||
self.public_baseurl = config.get("public_baseurl")
|
self.public_baseurl = config.get("public_baseurl")
|
||||||
|
|
||||||
# Whether to enable user presence.
|
# Whether to enable user presence.
|
||||||
|
presence_config = config.get("presence") or {}
|
||||||
|
self.use_presence = presence_config.get("enabled")
|
||||||
|
if self.use_presence is None:
|
||||||
self.use_presence = config.get("use_presence", True)
|
self.use_presence = config.get("use_presence", True)
|
||||||
|
|
||||||
|
# Custom presence router module
|
||||||
|
self.presence_router_module_class = None
|
||||||
|
self.presence_router_config = None
|
||||||
|
presence_router_config = presence_config.get("presence_router")
|
||||||
|
if presence_router_config:
|
||||||
|
(
|
||||||
|
self.presence_router_module_class,
|
||||||
|
self.presence_router_config,
|
||||||
|
) = load_module(presence_router_config, ("presence", "presence_router"))
|
||||||
|
|
||||||
# Whether to update the user directory or not. This should be set to
|
# Whether to update the user directory or not. This should be set to
|
||||||
# false only if we are updating the user directory in a worker
|
# false only if we are updating the user directory in a worker
|
||||||
self.update_user_directory = config.get("update_user_directory", True)
|
self.update_user_directory = config.get("update_user_directory", True)
|
||||||
|
@ -834,9 +848,28 @@ class ServerConfig(Config):
|
||||||
#
|
#
|
||||||
#soft_file_limit: 0
|
#soft_file_limit: 0
|
||||||
|
|
||||||
# Set to false to disable presence tracking on this homeserver.
|
# Presence tracking allows users to see the state (e.g online/offline)
|
||||||
|
# of other local and remote users.
|
||||||
#
|
#
|
||||||
#use_presence: false
|
presence:
|
||||||
|
# Uncomment to disable presence tracking on this homeserver. This option
|
||||||
|
# replaces the previous top-level 'use_presence' option.
|
||||||
|
#
|
||||||
|
#enabled: false
|
||||||
|
|
||||||
|
# Presence routers are third-party modules that can specify additional logic
|
||||||
|
# to where presence updates from users are routed.
|
||||||
|
#
|
||||||
|
presence_router:
|
||||||
|
# The custom module's class. Uncomment to use a custom presence router module.
|
||||||
|
#
|
||||||
|
#module: "my_custom_router.PresenceRouter"
|
||||||
|
|
||||||
|
# Configuration options of the custom module. Refer to your module's
|
||||||
|
# documentation for available options.
|
||||||
|
#
|
||||||
|
#config:
|
||||||
|
# example_option: 'something'
|
||||||
|
|
||||||
# Whether to require authentication to retrieve profile data (avatars,
|
# Whether to require authentication to retrieve profile data (avatars,
|
||||||
# display names) of other users through the client API. Defaults to
|
# display names) of other users through the client API. Defaults to
|
||||||
|
|
|
@ -162,7 +162,7 @@ def check(
|
||||||
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
|
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
_is_membership_change_allowed(event, auth_events)
|
_is_membership_change_allowed(room_version_obj, event, auth_events)
|
||||||
logger.debug("Allowing! %s", event)
|
logger.debug("Allowing! %s", event)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -220,8 +220,19 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def _is_membership_change_allowed(
|
def _is_membership_change_allowed(
|
||||||
event: EventBase, auth_events: StateMap[EventBase]
|
room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Confirms that the event which changes membership is an allowed change.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_version: The version of the room.
|
||||||
|
event: The event to check.
|
||||||
|
auth_events: The current auth events of the room.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError if the event is not allowed.
|
||||||
|
"""
|
||||||
membership = event.content["membership"]
|
membership = event.content["membership"]
|
||||||
|
|
||||||
# Check if this is the room creator joining:
|
# Check if this is the room creator joining:
|
||||||
|
@ -315,14 +326,19 @@ def _is_membership_change_allowed(
|
||||||
if user_level < invite_level:
|
if user_level < invite_level:
|
||||||
raise AuthError(403, "You don't have permission to invite users")
|
raise AuthError(403, "You don't have permission to invite users")
|
||||||
elif Membership.JOIN == membership:
|
elif Membership.JOIN == membership:
|
||||||
# Joins are valid iff caller == target and they were:
|
# Joins are valid iff caller == target and:
|
||||||
# invited: They are accepting the invitation
|
# * They are not banned.
|
||||||
# joined: It's a NOOP
|
# * They are accepting a previously sent invitation.
|
||||||
|
# * They are already joined (it's a NOOP).
|
||||||
|
# * The room is public or restricted.
|
||||||
if event.user_id != target_user_id:
|
if event.user_id != target_user_id:
|
||||||
raise AuthError(403, "Cannot force another user to join.")
|
raise AuthError(403, "Cannot force another user to join.")
|
||||||
elif target_banned:
|
elif target_banned:
|
||||||
raise AuthError(403, "You are banned from this room")
|
raise AuthError(403, "You are banned from this room")
|
||||||
elif join_rule == JoinRules.PUBLIC:
|
elif join_rule == JoinRules.PUBLIC or (
|
||||||
|
room_version.msc3083_join_rules
|
||||||
|
and join_rule == JoinRules.MSC3083_RESTRICTED
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
elif join_rule == JoinRules.INVITE:
|
elif join_rule == JoinRules.INVITE:
|
||||||
if not caller_in_room and not caller_invited:
|
if not caller_in_room and not caller_invited:
|
||||||
|
|
|
@ -0,0 +1,104 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
|
||||||
|
|
||||||
|
from synapse.api.presence import UserPresenceState
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
|
class PresenceRouter:
|
||||||
|
"""
|
||||||
|
A module that the homeserver will call upon to help route user presence updates to
|
||||||
|
additional destinations. If a custom presence router is configured, calls will be
|
||||||
|
passed to that instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ALL_USERS = "ALL"
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.custom_presence_router = None
|
||||||
|
|
||||||
|
# Check whether a custom presence router module has been configured
|
||||||
|
if hs.config.presence_router_module_class:
|
||||||
|
# Initialise the module
|
||||||
|
self.custom_presence_router = hs.config.presence_router_module_class(
|
||||||
|
config=hs.config.presence_router_config, module_api=hs.get_module_api()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure the module has implemented the required methods
|
||||||
|
required_methods = ["get_users_for_states", "get_interested_users"]
|
||||||
|
for method_name in required_methods:
|
||||||
|
if not hasattr(self.custom_presence_router, method_name):
|
||||||
|
raise Exception(
|
||||||
|
"PresenceRouter module '%s' must implement all required methods: %s"
|
||||||
|
% (
|
||||||
|
hs.config.presence_router_module_class.__name__,
|
||||||
|
", ".join(required_methods),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_users_for_states(
|
||||||
|
self,
|
||||||
|
state_updates: Iterable[UserPresenceState],
|
||||||
|
) -> Dict[str, Set[UserPresenceState]]:
|
||||||
|
"""
|
||||||
|
Given an iterable of user presence updates, determine where each one
|
||||||
|
needs to go.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_updates: An iterable of user presence state updates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of user_id -> set of UserPresenceState, indicating which
|
||||||
|
presence updates each user should receive.
|
||||||
|
"""
|
||||||
|
if self.custom_presence_router is not None:
|
||||||
|
# Ask the custom module
|
||||||
|
return await self.custom_presence_router.get_users_for_states(
|
||||||
|
state_updates=state_updates
|
||||||
|
)
|
||||||
|
|
||||||
|
# Don't include any extra destinations for presence updates
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
|
||||||
|
"""
|
||||||
|
Retrieve a list of users that `user_id` is interested in receiving the
|
||||||
|
presence of. This will be in addition to those they share a room with.
|
||||||
|
Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
|
||||||
|
that this user should receive all incoming local and remote presence updates.
|
||||||
|
|
||||||
|
Note that this method will only be called for local users, but can return users
|
||||||
|
that are local or remote.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: A user requesting presence updates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A set of user IDs to return presence updates for, or ALL_USERS to return all
|
||||||
|
known updates.
|
||||||
|
"""
|
||||||
|
if self.custom_presence_router is not None:
|
||||||
|
# Ask the custom module for interested users
|
||||||
|
return await self.custom_presence_router.get_interested_users(
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# A custom presence router is not defined.
|
||||||
|
# Don't report any additional interested users
|
||||||
|
return set()
|
|
@ -102,7 +102,7 @@ class FederationClient(FederationBase):
|
||||||
max_len=1000,
|
max_len=1000,
|
||||||
expiry_ms=120 * 1000,
|
expiry_ms=120 * 1000,
|
||||||
reset_expiry_on_get=False,
|
reset_expiry_on_get=False,
|
||||||
)
|
) # type: ExpiringCache[str, EventBase]
|
||||||
|
|
||||||
def _clear_tried_cache(self):
|
def _clear_tried_cache(self):
|
||||||
"""Clear pdu_destination_tried cache"""
|
"""Clear pdu_destination_tried cache"""
|
||||||
|
|
|
@ -739,22 +739,20 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
|
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "<ReplicationLayer(%s)>" % self.server_name
|
return "<ReplicationLayer(%s)>" % self.server_name
|
||||||
|
|
||||||
async def exchange_third_party_invite(
|
async def exchange_third_party_invite(
|
||||||
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
|
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
|
||||||
):
|
) -> None:
|
||||||
ret = await self.handler.exchange_third_party_invite(
|
await self.handler.exchange_third_party_invite(
|
||||||
sender_user_id, target_user_id, room_id, signed
|
sender_user_id, target_user_id, room_id, signed
|
||||||
)
|
)
|
||||||
return ret
|
|
||||||
|
|
||||||
async def on_exchange_third_party_invite_request(self, event_dict: Dict):
|
async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
|
||||||
ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
|
await self.handler.on_exchange_third_party_invite_request(event_dict)
|
||||||
return ret
|
|
||||||
|
|
||||||
async def check_server_matches_acl(self, server_name: str, room_id: str):
|
async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
|
||||||
"""Check if the given server is allowed by the server ACLs in the room
|
"""Check if the given server is allowed by the server ACLs in the room
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -870,6 +868,7 @@ class FederationHandlerRegistry:
|
||||||
|
|
||||||
# A rate limiter for incoming room key requests per origin.
|
# A rate limiter for incoming room key requests per origin.
|
||||||
self._room_key_request_rate_limiter = Ratelimiter(
|
self._room_key_request_rate_limiter = Ratelimiter(
|
||||||
|
store=hs.get_datastore(),
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.config.rc_key_requests.per_second,
|
rate_hz=self.config.rc_key_requests.per_second,
|
||||||
burst_count=self.config.rc_key_requests.burst_count,
|
burst_count=self.config.rc_key_requests.burst_count,
|
||||||
|
@ -877,7 +876,7 @@ class FederationHandlerRegistry:
|
||||||
|
|
||||||
def register_edu_handler(
|
def register_edu_handler(
|
||||||
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
|
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
|
||||||
):
|
) -> None:
|
||||||
"""Sets the handler callable that will be used to handle an incoming
|
"""Sets the handler callable that will be used to handle an incoming
|
||||||
federation EDU of the given type.
|
federation EDU of the given type.
|
||||||
|
|
||||||
|
@ -896,7 +895,7 @@ class FederationHandlerRegistry:
|
||||||
|
|
||||||
def register_query_handler(
|
def register_query_handler(
|
||||||
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||||
):
|
) -> None:
|
||||||
"""Sets the handler callable that will be used to handle an incoming
|
"""Sets the handler callable that will be used to handle an incoming
|
||||||
federation query of the given type.
|
federation query of the given type.
|
||||||
|
|
||||||
|
@ -914,15 +913,17 @@ class FederationHandlerRegistry:
|
||||||
|
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
def register_instance_for_edu(self, edu_type: str, instance_name: str):
|
def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
|
||||||
"""Register that the EDU handler is on a different instance than master."""
|
"""Register that the EDU handler is on a different instance than master."""
|
||||||
self._edu_type_to_instance[edu_type] = [instance_name]
|
self._edu_type_to_instance[edu_type] = [instance_name]
|
||||||
|
|
||||||
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
|
def register_instances_for_edu(
|
||||||
|
self, edu_type: str, instance_names: List[str]
|
||||||
|
) -> None:
|
||||||
"""Register that the EDU handler is on multiple instances."""
|
"""Register that the EDU handler is on multiple instances."""
|
||||||
self._edu_type_to_instance[edu_type] = instance_names
|
self._edu_type_to_instance[edu_type] = instance_names
|
||||||
|
|
||||||
async def on_edu(self, edu_type: str, origin: str, content: dict):
|
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
|
||||||
if not self.config.use_presence and edu_type == EduTypes.Presence:
|
if not self.config.use_presence and edu_type == EduTypes.Presence:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -930,7 +931,9 @@ class FederationHandlerRegistry:
|
||||||
# the limit, drop them.
|
# the limit, drop them.
|
||||||
if (
|
if (
|
||||||
edu_type == EduTypes.RoomKeyRequest
|
edu_type == EduTypes.RoomKeyRequest
|
||||||
and not self._room_key_request_rate_limiter.can_do_action(origin)
|
and not await self._room_key_request_rate_limiter.can_do_action(
|
||||||
|
None, origin
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,7 @@ from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
|
||||||
from synapse.util.metrics import Measure, measure_func
|
from synapse.util.metrics import Measure, measure_func
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from synapse.events.presence_router import PresenceRouter
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -162,6 +163,7 @@ class FederationSender(AbstractFederationSender):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
|
self._presence_router = None # type: Optional[PresenceRouter]
|
||||||
self._transaction_manager = TransactionManager(hs)
|
self._transaction_manager = TransactionManager(hs)
|
||||||
|
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
@ -584,7 +586,22 @@ class FederationSender(AbstractFederationSender):
|
||||||
"""Given a list of states populate self.pending_presence_by_dest and
|
"""Given a list of states populate self.pending_presence_by_dest and
|
||||||
poke to send a new transaction to each destination
|
poke to send a new transaction to each destination
|
||||||
"""
|
"""
|
||||||
hosts_and_states = await get_interested_remotes(self.store, states, self.state)
|
# We pull the presence router here instead of __init__
|
||||||
|
# to prevent a dependency cycle:
|
||||||
|
#
|
||||||
|
# AuthHandler -> Notifier -> FederationSender
|
||||||
|
# -> PresenceRouter -> ModuleApi -> AuthHandler
|
||||||
|
if self._presence_router is None:
|
||||||
|
self._presence_router = self.hs.get_presence_router()
|
||||||
|
|
||||||
|
assert self._presence_router is not None
|
||||||
|
|
||||||
|
hosts_and_states = await get_interested_remotes(
|
||||||
|
self.store,
|
||||||
|
self._presence_router,
|
||||||
|
states,
|
||||||
|
self.state,
|
||||||
|
)
|
||||||
|
|
||||||
for destinations, states in hosts_and_states:
|
for destinations, states in hosts_and_states:
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
|
|
|
@ -29,6 +29,7 @@ from synapse.api.presence import UserPresenceState
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.federation.units import Edu
|
from synapse.federation.units import Edu
|
||||||
from synapse.handlers.presence import format_user_presence_state
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
|
from synapse.logging.opentracing import SynapseTags, set_tag
|
||||||
from synapse.metrics import sent_transactions_counter
|
from synapse.metrics import sent_transactions_counter
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.types import ReadReceipt
|
from synapse.types import ReadReceipt
|
||||||
|
@ -557,6 +558,13 @@ class PerDestinationQueue:
|
||||||
contents, stream_id = await self._store.get_new_device_msgs_for_remote(
|
contents, stream_id = await self._store.get_new_device_msgs_for_remote(
|
||||||
self._destination, last_device_stream_id, to_device_stream_id, limit
|
self._destination, last_device_stream_id, to_device_stream_id, limit
|
||||||
)
|
)
|
||||||
|
for content in contents:
|
||||||
|
message_id = content.get("message_id")
|
||||||
|
if not message_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
|
||||||
|
|
||||||
edus = [
|
edus = [
|
||||||
Edu(
|
Edu(
|
||||||
origin=self._server_name,
|
origin=self._server_name,
|
||||||
|
|
|
@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
|
||||||
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
|
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
|
||||||
|
|
||||||
async def on_PUT(self, origin, content, query, room_id):
|
async def on_PUT(self, origin, content, query, room_id):
|
||||||
content = await self.handler.on_exchange_third_party_invite_request(content)
|
await self.handler.on_exchange_third_party_invite_request(content)
|
||||||
return 200, content
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
||||||
|
|
|
@ -49,7 +49,7 @@ class BaseHandler:
|
||||||
|
|
||||||
# The rate_hz and burst_count are overridden on a per-user basis
|
# The rate_hz and burst_count are overridden on a per-user basis
|
||||||
self.request_ratelimiter = Ratelimiter(
|
self.request_ratelimiter = Ratelimiter(
|
||||||
clock=self.clock, rate_hz=0, burst_count=0
|
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
|
||||||
)
|
)
|
||||||
self._rc_message = self.hs.config.rc_message
|
self._rc_message = self.hs.config.rc_message
|
||||||
|
|
||||||
|
@ -57,6 +57,7 @@ class BaseHandler:
|
||||||
# by the presence of rate limits in the config
|
# by the presence of rate limits in the config
|
||||||
if self.hs.config.rc_admin_redaction:
|
if self.hs.config.rc_admin_redaction:
|
||||||
self.admin_redaction_ratelimiter = Ratelimiter(
|
self.admin_redaction_ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
||||||
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
||||||
|
@ -91,11 +92,6 @@ class BaseHandler:
|
||||||
if app_service is not None:
|
if app_service is not None:
|
||||||
return # do not ratelimit app service senders
|
return # do not ratelimit app service senders
|
||||||
|
|
||||||
# Disable rate limiting of users belonging to any AS that is configured
|
|
||||||
# not to be rate limited in its registration file (rate_limited: true|false).
|
|
||||||
if requester.app_service and not requester.app_service.is_rate_limited():
|
|
||||||
return
|
|
||||||
|
|
||||||
messages_per_second = self._rc_message.per_second
|
messages_per_second = self._rc_message.per_second
|
||||||
burst_count = self._rc_message.burst_count
|
burst_count = self._rc_message.burst_count
|
||||||
|
|
||||||
|
@ -113,11 +109,11 @@ class BaseHandler:
|
||||||
if is_admin_redaction and self.admin_redaction_ratelimiter:
|
if is_admin_redaction and self.admin_redaction_ratelimiter:
|
||||||
# If we have separate config for admin redactions, use a separate
|
# If we have separate config for admin redactions, use a separate
|
||||||
# ratelimiter as to not have user_ids clash
|
# ratelimiter as to not have user_ids clash
|
||||||
self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
|
await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
|
||||||
else:
|
else:
|
||||||
# Override rate and burst count per-user
|
# Override rate and burst count per-user
|
||||||
self.request_ratelimiter.ratelimit(
|
await self.request_ratelimiter.ratelimit(
|
||||||
user_id,
|
requester,
|
||||||
rate_hz=messages_per_second,
|
rate_hz=messages_per_second,
|
||||||
burst_count=burst_count,
|
burst_count=burst_count,
|
||||||
update=update,
|
update=update,
|
||||||
|
|
|
@ -18,7 +18,7 @@ import email.utils
|
||||||
import logging
|
import logging
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, SynapseError
|
from synapse.api.errors import StoreError, SynapseError
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
@ -241,7 +241,10 @@ class AccountValidityHandler:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def renew_account_for_user(
|
async def renew_account_for_user(
|
||||||
self, user_id: str, expiration_ts: int = None, email_sent: bool = False
|
self,
|
||||||
|
user_id: str,
|
||||||
|
expiration_ts: Optional[int] = None,
|
||||||
|
email_sent: bool = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Renews the account attached to a given user by pushing back the
|
"""Renews the account attached to a given user by pushing back the
|
||||||
expiration date by the current validity period in the server's
|
expiration date by the current validity period in the server's
|
||||||
|
|
|
@ -238,6 +238,7 @@ class AuthHandler(BaseHandler):
|
||||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||||
# as per `rc_login.failed_attempts`.
|
# as per `rc_login.failed_attempts`.
|
||||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||||
|
@ -248,6 +249,7 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# Ratelimitier for failed /login attempts
|
# Ratelimitier for failed /login attempts
|
||||||
self._failed_login_attempts_ratelimiter = Ratelimiter(
|
self._failed_login_attempts_ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||||
|
@ -352,7 +354,7 @@ class AuthHandler(BaseHandler):
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
# Check if we should be ratelimited due to too many previous failed attempts
|
# Check if we should be ratelimited due to too many previous failed attempts
|
||||||
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
|
await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)
|
||||||
|
|
||||||
# build a list of supported flows
|
# build a list of supported flows
|
||||||
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
||||||
|
@ -373,7 +375,9 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
except LoginError:
|
except LoginError:
|
||||||
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
||||||
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
|
await self._failed_uia_attempts_ratelimiter.can_do_action(
|
||||||
|
requester,
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# find the completed login type
|
# find the completed login type
|
||||||
|
@ -982,8 +986,8 @@ class AuthHandler(BaseHandler):
|
||||||
# We also apply account rate limiting using the 3PID as a key, as
|
# We also apply account rate limiting using the 3PID as a key, as
|
||||||
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._failed_login_attempts_ratelimiter.ratelimit(
|
await self._failed_login_attempts_ratelimiter.ratelimit(
|
||||||
(medium, address), update=False
|
None, (medium, address), update=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for login providers that support 3pid login types
|
# Check for login providers that support 3pid login types
|
||||||
|
@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler):
|
||||||
# this code path, which is fine as then the per-user ratelimit
|
# this code path, which is fine as then the per-user ratelimit
|
||||||
# will kick in below.
|
# will kick in below.
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._failed_login_attempts_ratelimiter.can_do_action(
|
await self._failed_login_attempts_ratelimiter.can_do_action(
|
||||||
(medium, address)
|
None, (medium, address)
|
||||||
)
|
)
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# Check if we've hit the failed ratelimit (but don't update it)
|
# Check if we've hit the failed ratelimit (but don't update it)
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._failed_login_attempts_ratelimiter.ratelimit(
|
await self._failed_login_attempts_ratelimiter.ratelimit(
|
||||||
qualified_user_id.lower(), update=False
|
None, qualified_user_id.lower(), update=False
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler):
|
||||||
# exception and masking the LoginError. The actual ratelimiting
|
# exception and masking the LoginError. The actual ratelimiting
|
||||||
# should have happened above.
|
# should have happened above.
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
self._failed_login_attempts_ratelimiter.can_do_action(
|
await self._failed_login_attempts_ratelimiter.can_do_action(
|
||||||
qualified_user_id.lower()
|
None, qualified_user_id.lower()
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
|
@ -631,7 +631,7 @@ class DeviceListUpdater:
|
||||||
max_len=10000,
|
max_len=10000,
|
||||||
expiry_ms=30 * 60 * 1000,
|
expiry_ms=30 * 60 * 1000,
|
||||||
iterable=True,
|
iterable=True,
|
||||||
)
|
) # type: ExpiringCache[str, Set[str]]
|
||||||
|
|
||||||
# Attempt to resync out of sync device lists every 30s.
|
# Attempt to resync out of sync device lists every 30s.
|
||||||
self._resync_retry_in_progress = False
|
self._resync_retry_in_progress = False
|
||||||
|
@ -760,7 +760,7 @@ class DeviceListUpdater:
|
||||||
"""Given a list of updates for a user figure out if we need to do a full
|
"""Given a list of updates for a user figure out if we need to do a full
|
||||||
resync, or whether we have enough data that we can just apply the delta.
|
resync, or whether we have enough data that we can just apply the delta.
|
||||||
"""
|
"""
|
||||||
seen_updates = self._seen_updates.get(user_id, set())
|
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
|
||||||
|
|
||||||
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
|
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
|
||||||
|
|
||||||
|
|
|
@ -21,10 +21,10 @@ from synapse.api.errors import SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.logging.context import run_in_background
|
from synapse.logging.context import run_in_background
|
||||||
from synapse.logging.opentracing import (
|
from synapse.logging.opentracing import (
|
||||||
|
SynapseTags,
|
||||||
get_active_span_text_map,
|
get_active_span_text_map,
|
||||||
log_kv,
|
log_kv,
|
||||||
set_tag,
|
set_tag,
|
||||||
start_active_span,
|
|
||||||
)
|
)
|
||||||
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
|
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
|
||||||
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
||||||
|
@ -81,6 +81,7 @@ class DeviceMessageHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
self._ratelimiter = Ratelimiter(
|
self._ratelimiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=hs.config.rc_key_requests.per_second,
|
rate_hz=hs.config.rc_key_requests.per_second,
|
||||||
burst_count=hs.config.rc_key_requests.burst_count,
|
burst_count=hs.config.rc_key_requests.burst_count,
|
||||||
|
@ -182,7 +183,10 @@ class DeviceMessageHandler:
|
||||||
) -> None:
|
) -> None:
|
||||||
sender_user_id = requester.user.to_string()
|
sender_user_id = requester.user.to_string()
|
||||||
|
|
||||||
set_tag("number_of_messages", len(messages))
|
message_id = random_string(16)
|
||||||
|
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
|
||||||
|
|
||||||
|
log_kv({"number_of_to_device_messages": len(messages)})
|
||||||
set_tag("sender", sender_user_id)
|
set_tag("sender", sender_user_id)
|
||||||
local_messages = {}
|
local_messages = {}
|
||||||
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
|
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
|
||||||
|
@ -191,8 +195,8 @@ class DeviceMessageHandler:
|
||||||
if (
|
if (
|
||||||
message_type == EduTypes.RoomKeyRequest
|
message_type == EduTypes.RoomKeyRequest
|
||||||
and user_id != sender_user_id
|
and user_id != sender_user_id
|
||||||
and self._ratelimiter.can_do_action(
|
and await self._ratelimiter.can_do_action(
|
||||||
(sender_user_id, requester.device_id)
|
requester, (sender_user_id, requester.device_id)
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
@ -204,23 +208,27 @@ class DeviceMessageHandler:
|
||||||
"content": message_content,
|
"content": message_content,
|
||||||
"type": message_type,
|
"type": message_type,
|
||||||
"sender": sender_user_id,
|
"sender": sender_user_id,
|
||||||
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
for device_id, message_content in by_device.items()
|
for device_id, message_content in by_device.items()
|
||||||
}
|
}
|
||||||
if messages_by_device:
|
if messages_by_device:
|
||||||
local_messages[user_id] = messages_by_device
|
local_messages[user_id] = messages_by_device
|
||||||
|
log_kv(
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": list(messages_by_device),
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
destination = get_domain_from_id(user_id)
|
destination = get_domain_from_id(user_id)
|
||||||
remote_messages.setdefault(destination, {})[user_id] = by_device
|
remote_messages.setdefault(destination, {})[user_id] = by_device
|
||||||
|
|
||||||
message_id = random_string(16)
|
|
||||||
|
|
||||||
context = get_active_span_text_map()
|
context = get_active_span_text_map()
|
||||||
|
|
||||||
remote_edu_contents = {}
|
remote_edu_contents = {}
|
||||||
for destination, messages in remote_messages.items():
|
for destination, messages in remote_messages.items():
|
||||||
with start_active_span("to_device_for_user"):
|
log_kv({"destination": destination})
|
||||||
set_tag("destination", destination)
|
|
||||||
remote_edu_contents[destination] = {
|
remote_edu_contents[destination] = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"sender": sender_user_id,
|
"sender": sender_user_id,
|
||||||
|
@ -229,7 +237,6 @@ class DeviceMessageHandler:
|
||||||
"org.matrix.opentracing_context": json_encoder.encode(context),
|
"org.matrix.opentracing_context": json_encoder.encode(context),
|
||||||
}
|
}
|
||||||
|
|
||||||
log_kv({"local_messages": local_messages})
|
|
||||||
stream_id = await self.store.add_messages_to_device_inbox(
|
stream_id = await self.store.add_messages_to_device_inbox(
|
||||||
local_messages, remote_edu_contents
|
local_messages, remote_edu_contents
|
||||||
)
|
)
|
||||||
|
@ -238,7 +245,6 @@ class DeviceMessageHandler:
|
||||||
"to_device_key", stream_id, users=local_messages.keys()
|
"to_device_key", stream_id, users=local_messages.keys()
|
||||||
)
|
)
|
||||||
|
|
||||||
log_kv({"remote_messages": remote_messages})
|
|
||||||
if self.federation_sender:
|
if self.federation_sender:
|
||||||
for destination in remote_messages.keys():
|
for destination in remote_messages.keys():
|
||||||
# Enqueue a new federation transaction to send the new
|
# Enqueue a new federation transaction to send the new
|
||||||
|
|
|
@ -38,7 +38,6 @@ from synapse.types import (
|
||||||
)
|
)
|
||||||
from synapse.util import json_decoder, unwrapFirstError
|
from synapse.util import json_decoder, unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -1008,7 +1007,7 @@ class E2eKeysHandler:
|
||||||
return signature_list, failures
|
return signature_list, failures
|
||||||
|
|
||||||
async def _get_e2e_cross_signing_verify_key(
|
async def _get_e2e_cross_signing_verify_key(
|
||||||
self, user_id: str, key_type: str, from_user_id: str = None
|
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||||
) -> Tuple[JsonDict, str, VerifyKey]:
|
) -> Tuple[JsonDict, str, VerifyKey]:
|
||||||
"""Fetch locally or remotely query for a cross-signing public key.
|
"""Fetch locally or remotely query for a cross-signing public key.
|
||||||
|
|
||||||
|
@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
|
||||||
# user_id -> list of updates waiting to be handled.
|
# user_id -> list of updates waiting to be handled.
|
||||||
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
|
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
|
||||||
|
|
||||||
# Recently seen stream ids. We don't bother keeping these in the DB,
|
|
||||||
# but they're useful to have them about to reduce the number of spurious
|
|
||||||
# resyncs.
|
|
||||||
self._seen_updates = ExpiringCache(
|
|
||||||
cache_name="signing_key_update_edu",
|
|
||||||
clock=self.clock,
|
|
||||||
max_len=10000,
|
|
||||||
expiry_ms=30 * 60 * 1000,
|
|
||||||
iterable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def incoming_signing_key_update(
|
async def incoming_signing_key_update(
|
||||||
self, origin: str, edu_content: JsonDict
|
self, origin: str, edu_content: JsonDict
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -21,7 +21,17 @@ import itertools
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Container
|
from collections.abc import Container
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||||
|
|
||||||
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
|
async def on_receive_pdu(
|
||||||
|
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
|
||||||
|
) -> None:
|
||||||
"""Process a PDU received via a federation /send/ transaction, or
|
"""Process a PDU received via a federation /send/ transaction, or
|
||||||
via backfill of missing prev_events
|
via backfill of missing prev_events
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
origin (str): server which initiated the /send/ transaction. Will
|
origin: server which initiated the /send/ transaction. Will
|
||||||
be used to fetch missing events or state.
|
be used to fetch missing events or state.
|
||||||
pdu (FrozenEvent): received PDU
|
pdu: received PDU
|
||||||
sent_to_us_directly (bool): True if this event was pushed to us; False if
|
sent_to_us_directly: True if this event was pushed to us; False if
|
||||||
we pulled it as the result of a missing prev_event.
|
we pulled it as the result of a missing prev_event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
await self._process_received_pdu(origin, pdu, state=state)
|
await self._process_received_pdu(origin, pdu, state=state)
|
||||||
|
|
||||||
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
async def _get_missing_events_for_pdu(
|
||||||
|
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
origin (str): Origin of the pdu. Will be called to get the missing events
|
origin: Origin of the pdu. Will be called to get the missing events
|
||||||
pdu: received pdu
|
pdu: received pdu
|
||||||
prevs (set(str)): List of event ids which we are missing
|
prevs: List of event ids which we are missing
|
||||||
min_depth (int): Minimum depth of events to return.
|
min_depth: Minimum depth of events to return.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
room_id = pdu.room_id
|
room_id = pdu.room_id
|
||||||
|
@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
|
||||||
origin: str,
|
origin: str,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
state: Optional[Iterable[EventBase]],
|
state: Optional[Iterable[EventBase]],
|
||||||
):
|
) -> None:
|
||||||
"""Called when we have a new pdu. We need to do auth checks and put it
|
"""Called when we have a new pdu. We need to do auth checks and put it
|
||||||
through the StateHandler.
|
through the StateHandler.
|
||||||
|
|
||||||
|
@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
|
||||||
logger.exception("Failed to resync device for %s", sender)
|
logger.exception("Failed to resync device for %s", sender)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
async def backfill(self, dest, room_id, limit, extremities):
|
async def backfill(
|
||||||
|
self, dest: str, room_id: str, limit: int, extremities: List[str]
|
||||||
|
) -> List[EventBase]:
|
||||||
"""Trigger a backfill request to `dest` for the given `room_id`
|
"""Trigger a backfill request to `dest` for the given `room_id`
|
||||||
|
|
||||||
This will attempt to get more events from the remote. If the other side
|
This will attempt to get more events from the remote. If the other side
|
||||||
|
@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
curr_state = await self.state_handler.get_current_state(room_id)
|
curr_state = await self.state_handler.get_current_state(room_id)
|
||||||
|
|
||||||
def get_domains_from_state(state):
|
def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
|
||||||
"""Get joined domains from state
|
"""Get joined domains from state
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state (dict[tuple, FrozenEvent]): State map from type/state
|
state: State map from type/state key to event.
|
||||||
key to event.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[tuple[str, int]]: Returns a list of servers with the
|
Returns a list of servers with the lowest depth of their joins.
|
||||||
lowest depth of their joins. Sorted by lowest depth first.
|
Sorted by lowest depth first.
|
||||||
"""
|
"""
|
||||||
joined_users = [
|
joined_users = [
|
||||||
(state_key, int(event.depth))
|
(state_key, int(event.depth))
|
||||||
|
@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
|
||||||
domain for domain, depth in curr_domains if domain != self.server_name
|
domain for domain, depth in curr_domains if domain != self.server_name
|
||||||
]
|
]
|
||||||
|
|
||||||
async def try_backfill(domains):
|
async def try_backfill(domains: List[str]) -> bool:
|
||||||
# TODO: Should we try multiple of these at a time?
|
# TODO: Should we try multiple of these at a time?
|
||||||
for dom in domains:
|
for dom in domains:
|
||||||
try:
|
try:
|
||||||
|
@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
|
||||||
}
|
}
|
||||||
|
|
||||||
for e_id, _ in sorted_extremeties_tuple:
|
for e_id, _ in sorted_extremeties_tuple:
|
||||||
likely_domains = get_domains_from_state(states[e_id])
|
likely_extremeties_domains = get_domains_from_state(states[e_id])
|
||||||
|
|
||||||
success = await try_backfill(
|
success = await try_backfill(
|
||||||
[dom for dom, _ in likely_domains if dom not in tried_domains]
|
[
|
||||||
|
dom
|
||||||
|
for dom, _ in likely_extremeties_domains
|
||||||
|
if dom not in tried_domains
|
||||||
|
]
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
tried_domains.update(dom for dom, _ in likely_domains)
|
tried_domains.update(dom for dom, _ in likely_extremeties_domains)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _get_events_and_persist(
|
async def _get_events_and_persist(
|
||||||
self, destination: str, room_id: str, events: Iterable[str]
|
self, destination: str, room_id: str, events: Iterable[str]
|
||||||
):
|
) -> None:
|
||||||
"""Fetch the given events from a server, and persist them as outliers.
|
"""Fetch the given events from a server, and persist them as outliers.
|
||||||
|
|
||||||
This function *does not* recursively get missing auth events of the
|
This function *does not* recursively get missing auth events of the
|
||||||
|
@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
|
||||||
event_infos,
|
event_infos,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sanity_check_event(self, ev):
|
def _sanity_check_event(self, ev: EventBase) -> None:
|
||||||
"""
|
"""
|
||||||
Do some early sanity checks of a received event
|
Do some early sanity checks of a received event
|
||||||
|
|
||||||
|
@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
|
||||||
or cascade of event fetches.
|
or cascade of event fetches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ev (synapse.events.EventBase): event to be checked
|
ev: event to be checked
|
||||||
|
|
||||||
Returns: None
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if the event does not pass muster
|
SynapseError if the event does not pass muster
|
||||||
|
@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
|
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
|
||||||
|
|
||||||
async def send_invite(self, target_host, event):
|
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
|
||||||
"""Sends the invite to the remote server for signing.
|
"""Sends the invite to the remote server for signing.
|
||||||
|
|
||||||
Invites must be signed by the invitee's server before distribution.
|
Invites must be signed by the invitee's server before distribution.
|
||||||
|
@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
run_in_background(self._handle_queued_pdus, room_queue)
|
run_in_background(self._handle_queued_pdus, room_queue)
|
||||||
|
|
||||||
async def _handle_queued_pdus(self, room_queue):
|
async def _handle_queued_pdus(
|
||||||
|
self, room_queue: List[Tuple[EventBase, str]]
|
||||||
|
) -> None:
|
||||||
"""Process PDUs which got queued up while we were busy send_joining.
|
"""Process PDUs which got queued up while we were busy send_joining.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_queue (list[FrozenEvent, str]): list of PDUs to be processed
|
room_queue: list of PDUs to be processed and the servers that sent them
|
||||||
and the servers that sent them
|
|
||||||
"""
|
"""
|
||||||
for p, origin in room_queue:
|
for p, origin in room_queue:
|
||||||
try:
|
try:
|
||||||
|
@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
async def on_send_join_request(self, origin, pdu):
|
async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
|
||||||
"""We have received a join event for a room. Fully process it and
|
"""We have received a join event for a room. Fully process it and
|
||||||
respond with the current state and auth chains.
|
respond with the current state and auth chains.
|
||||||
"""
|
"""
|
||||||
|
@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
async def on_invite_request(
|
async def on_invite_request(
|
||||||
self, origin: str, event: EventBase, room_version: RoomVersion
|
self, origin: str, event: EventBase, room_version: RoomVersion
|
||||||
):
|
) -> EventBase:
|
||||||
"""We've got an invite event. Process and persist it. Sign it.
|
"""We've got an invite event. Process and persist it. Sign it.
|
||||||
|
|
||||||
Respond with the now signed event.
|
Respond with the now signed event.
|
||||||
|
@ -1711,7 +1729,7 @@ class FederationHandler(BaseHandler):
|
||||||
member_handler = self.hs.get_room_member_handler()
|
member_handler = self.hs.get_room_member_handler()
|
||||||
# We don't rate limit based on room ID, as that should be done by
|
# We don't rate limit based on room ID, as that should be done by
|
||||||
# sending server.
|
# sending server.
|
||||||
member_handler.ratelimit_invite(None, event.state_key)
|
await member_handler.ratelimit_invite(None, None, event.state_key)
|
||||||
|
|
||||||
# keep a record of the room version, if we don't yet know it.
|
# keep a record of the room version, if we don't yet know it.
|
||||||
# (this may get overwritten if we later get a different room version in a
|
# (this may get overwritten if we later get a different room version in a
|
||||||
|
@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
async def on_send_leave_request(self, origin, pdu):
|
async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
|
||||||
""" We have received a leave event for a room. Fully process it."""
|
""" We have received a leave event for a room. Fully process it."""
|
||||||
event = pdu
|
event = pdu
|
||||||
|
|
||||||
|
@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_min_depth_for_context(self, context):
|
async def get_min_depth_for_context(self, context: str) -> int:
|
||||||
return await self.store.get_min_depth(context)
|
return await self.store.get_min_depth(context)
|
||||||
|
|
||||||
async def _handle_new_event(
|
async def _handle_new_event(
|
||||||
self, origin, event, state=None, auth_events=None, backfilled=False
|
self,
|
||||||
):
|
origin: str,
|
||||||
|
event: EventBase,
|
||||||
|
state: Optional[Iterable[EventBase]] = None,
|
||||||
|
auth_events: Optional[MutableStateMap[EventBase]] = None,
|
||||||
|
backfilled: bool = False,
|
||||||
|
) -> EventContext:
|
||||||
context = await self._prep_event(
|
context = await self._prep_event(
|
||||||
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
|
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
|
||||||
)
|
)
|
||||||
|
@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
|
||||||
logger.warning("Soft-failing %r because %s", event, e)
|
logger.warning("Soft-failing %r because %s", event, e)
|
||||||
event.internal_metadata.soft_failed = True
|
event.internal_metadata.soft_failed = True
|
||||||
|
|
||||||
async def on_query_auth(
|
|
||||||
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
|
|
||||||
):
|
|
||||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
|
||||||
if not in_room:
|
|
||||||
raise AuthError(403, "Host not in room.")
|
|
||||||
|
|
||||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
|
||||||
|
|
||||||
# Just go through and process each event in `remote_auth_chain`. We
|
|
||||||
# don't want to fall into the trap of `missing` being wrong.
|
|
||||||
for e in remote_auth_chain:
|
|
||||||
try:
|
|
||||||
await self._handle_new_event(origin, e)
|
|
||||||
except AuthError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Now get the current auth_chain for the event.
|
|
||||||
local_auth_chain = await self.store.get_auth_chain(
|
|
||||||
room_id, list(event.auth_event_ids()), include_given=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Check if we would now reject event_id. If so we need to tell
|
|
||||||
# everyone.
|
|
||||||
|
|
||||||
ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
|
|
||||||
|
|
||||||
logger.debug("on_query_auth returning: %s", ret)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
async def on_get_missing_events(
|
async def on_get_missing_events(
|
||||||
self, origin, room_id, earliest_events, latest_events, limit
|
self,
|
||||||
):
|
origin: str,
|
||||||
|
room_id: str,
|
||||||
|
earliest_events: List[str],
|
||||||
|
latest_events: List[str],
|
||||||
|
limit: int,
|
||||||
|
) -> List[EventBase]:
|
||||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
|
||||||
assumes that we have already processed all events in remote_auth
|
assumes that we have already processed all events in remote_auth
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
local_auth (list)
|
local_auth
|
||||||
remote_auth (list)
|
remote_auth
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict
|
dict
|
||||||
|
@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
async def exchange_third_party_invite(
|
async def exchange_third_party_invite(
|
||||||
self, sender_user_id, target_user_id, room_id, signed
|
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
|
||||||
):
|
) -> None:
|
||||||
third_party_invite = {"signed": signed}
|
third_party_invite = {"signed": signed}
|
||||||
|
|
||||||
event_dict = {
|
event_dict = {
|
||||||
|
@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
|
||||||
await member_handler.send_membership_event(None, event, context)
|
await member_handler.send_membership_event(None, event, context)
|
||||||
|
|
||||||
async def add_display_name_to_third_party_invite(
|
async def add_display_name_to_third_party_invite(
|
||||||
self, room_version, event_dict, event, context
|
self,
|
||||||
):
|
room_version: str,
|
||||||
|
event_dict: JsonDict,
|
||||||
|
event: EventBase,
|
||||||
|
context: EventContext,
|
||||||
|
) -> Tuple[EventBase, EventContext]:
|
||||||
key = (
|
key = (
|
||||||
EventTypes.ThirdPartyInvite,
|
EventTypes.ThirdPartyInvite,
|
||||||
event.content["third_party_invite"]["signed"]["token"],
|
event.content["third_party_invite"]["signed"]["token"],
|
||||||
|
@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
|
||||||
EventValidator().validate_new(event, self.config)
|
EventValidator().validate_new(event, self.config)
|
||||||
return (event, context)
|
return (event, context)
|
||||||
|
|
||||||
async def _check_signature(self, event, context):
|
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
|
||||||
"""
|
"""
|
||||||
Checks that the signature in the event is consistent with its invite.
|
Checks that the signature in the event is consistent with its invite.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (Event): The m.room.member event to check
|
event: The m.room.member event to check
|
||||||
context (EventContext):
|
context:
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError: if signature didn't match any keys, or key has been
|
AuthError: if signature didn't match any keys, or key has been
|
||||||
|
@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
raise last_exception
|
raise last_exception
|
||||||
|
|
||||||
async def _check_key_revocation(self, public_key, url):
|
async def _check_key_revocation(self, public_key: str, url: str) -> None:
|
||||||
"""
|
"""
|
||||||
Checks whether public_key has been revoked.
|
Checks whether public_key has been revoked.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
public_key (str): base-64 encoded public key.
|
public_key: base-64 encoded public key.
|
||||||
url (str): Key revocation URL.
|
url: Key revocation URL.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError: if they key has been revoked.
|
AuthError: if they key has been revoked.
|
||||||
|
|
|
@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler):
|
||||||
|
|
||||||
# Ratelimiters for `/requestToken` endpoints.
|
# Ratelimiters for `/requestToken` endpoints.
|
||||||
self._3pid_validation_ratelimiter_ip = Ratelimiter(
|
self._3pid_validation_ratelimiter_ip = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
||||||
)
|
)
|
||||||
self._3pid_validation_ratelimiter_address = Ratelimiter(
|
self._3pid_validation_ratelimiter_address = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=hs.get_clock(),
|
clock=hs.get_clock(),
|
||||||
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
def ratelimit_request_token_requests(
|
async def ratelimit_request_token_requests(
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
medium: str,
|
medium: str,
|
||||||
|
@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler):
|
||||||
address: The actual threepid ID, e.g. the phone number or email address
|
address: The actual threepid ID, e.g. the phone number or email address
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
|
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
||||||
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
|
None, (medium, request.getClientIP())
|
||||||
|
)
|
||||||
|
await self._3pid_validation_ratelimiter_address.ratelimit(
|
||||||
|
None, (medium, address)
|
||||||
|
)
|
||||||
|
|
||||||
async def threepid_from_creds(
|
async def threepid_from_creds(
|
||||||
self, id_server: str, creds: Dict[str, str]
|
self, id_server: str, creds: Dict[str, str]
|
||||||
|
|
|
@ -385,7 +385,7 @@ class EventCreationHandler:
|
||||||
self._events_shard_config = self.config.worker.events_shard_config
|
self._events_shard_config = self.config.worker.events_shard_config
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
self.room_invite_state_types = self.hs.config.room_invite_state_types
|
self.room_invite_state_types = self.hs.config.api.room_prejoin_state
|
||||||
|
|
||||||
self.membership_types_to_include_profile_data_in = (
|
self.membership_types_to_include_profile_data_in = (
|
||||||
{Membership.JOIN, Membership.INVITE}
|
{Membership.JOIN, Membership.INVITE}
|
||||||
|
|
|
@ -25,7 +25,17 @@ The methods that define policy are:
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Dict,
|
||||||
|
FrozenSet,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
from typing_extensions import ContextManager
|
from typing_extensions import ContextManager
|
||||||
|
@ -34,6 +44,7 @@ import synapse.metrics
|
||||||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
|
from synapse.events.presence_router import PresenceRouter
|
||||||
from synapse.logging.context import run_in_background
|
from synapse.logging.context import run_in_background
|
||||||
from synapse.logging.utils import log_function
|
from synapse.logging.utils import log_function
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
|
@ -42,7 +53,7 @@ from synapse.state import StateHandler
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
|
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.wheel_timer import WheelTimer
|
from synapse.util.wheel_timer import WheelTimer
|
||||||
|
|
||||||
|
@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.federation = hs.get_federation_sender()
|
self.federation = hs.get_federation_sender()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
self.presence_router = hs.get_presence_router()
|
||||||
self._presence_enabled = hs.config.use_presence
|
self._presence_enabled = hs.config.use_presence
|
||||||
|
|
||||||
federation_registry = hs.get_federation_registry()
|
federation_registry = hs.get_federation_registry()
|
||||||
|
@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
"""
|
"""
|
||||||
stream_id, max_token = await self.store.update_presence(states)
|
stream_id, max_token = await self.store.update_presence(states)
|
||||||
|
|
||||||
parties = await get_interested_parties(self.store, states)
|
parties = await get_interested_parties(self.store, self.presence_router, states)
|
||||||
room_ids_to_states, users_to_states = parties
|
room_ids_to_states, users_to_states = parties
|
||||||
|
|
||||||
self.notifier.on_new_event(
|
self.notifier.on_new_event(
|
||||||
|
@ -1041,7 +1053,12 @@ class PresenceEventSource:
|
||||||
#
|
#
|
||||||
# Presence -> Notifier -> PresenceEventSource -> Presence
|
# Presence -> Notifier -> PresenceEventSource -> Presence
|
||||||
#
|
#
|
||||||
|
# Same with get_module_api, get_presence_router
|
||||||
|
#
|
||||||
|
# AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
|
||||||
self.get_presence_handler = hs.get_presence_handler
|
self.get_presence_handler = hs.get_presence_handler
|
||||||
|
self.get_module_api = hs.get_module_api
|
||||||
|
self.get_presence_router = hs.get_presence_router
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
@ -1055,7 +1072,7 @@ class PresenceEventSource:
|
||||||
include_offline=True,
|
include_offline=True,
|
||||||
explicit_room_id=None,
|
explicit_room_id=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
) -> Tuple[List[UserPresenceState], int]:
|
||||||
# The process for getting presence events are:
|
# The process for getting presence events are:
|
||||||
# 1. Get the rooms the user is in.
|
# 1. Get the rooms the user is in.
|
||||||
# 2. Get the list of user in the rooms.
|
# 2. Get the list of user in the rooms.
|
||||||
|
@ -1068,7 +1085,17 @@ class PresenceEventSource:
|
||||||
# We don't try and limit the presence updates by the current token, as
|
# We don't try and limit the presence updates by the current token, as
|
||||||
# sending down the rare duplicate is not a concern.
|
# sending down the rare duplicate is not a concern.
|
||||||
|
|
||||||
|
user_id = user.to_string()
|
||||||
|
stream_change_cache = self.store.presence_stream_cache
|
||||||
|
|
||||||
with Measure(self.clock, "presence.get_new_events"):
|
with Measure(self.clock, "presence.get_new_events"):
|
||||||
|
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||||
|
# This user has been specified by a module to receive all current, online
|
||||||
|
# user presence. Removing from_key and setting include_offline to false
|
||||||
|
# will do effectively this.
|
||||||
|
from_key = None
|
||||||
|
include_offline = False
|
||||||
|
|
||||||
if from_key is not None:
|
if from_key is not None:
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
|
|
||||||
|
@ -1091,59 +1118,209 @@ class PresenceEventSource:
|
||||||
# doesn't return. C.f. #5503.
|
# doesn't return. C.f. #5503.
|
||||||
return [], max_token
|
return [], max_token
|
||||||
|
|
||||||
presence = self.get_presence_handler()
|
# Figure out which other users this user should receive updates for
|
||||||
stream_change_cache = self.store.presence_stream_cache
|
|
||||||
|
|
||||||
users_interested_in = await self._get_interested_in(user, explicit_room_id)
|
users_interested_in = await self._get_interested_in(user, explicit_room_id)
|
||||||
|
|
||||||
user_ids_changed = set() # type: Collection[str]
|
# We have a set of users that we're interested in the presence of. We want to
|
||||||
changed = None
|
# cross-reference that with the users that have actually changed their presence.
|
||||||
|
|
||||||
|
# Check whether this user should see all user updates
|
||||||
|
|
||||||
|
if users_interested_in == PresenceRouter.ALL_USERS:
|
||||||
|
# Provide presence state for all users
|
||||||
|
presence_updates = await self._filter_all_presence_updates_for_user(
|
||||||
|
user_id, include_offline, from_key
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove the user from the list of users to receive all presence
|
||||||
|
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||||
|
self.get_module_api()._send_full_presence_to_local_users.remove(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return presence_updates, max_token
|
||||||
|
|
||||||
|
# Make mypy happy. users_interested_in should now be a set
|
||||||
|
assert not isinstance(users_interested_in, str)
|
||||||
|
|
||||||
|
# The set of users that we're interested in and that have had a presence update.
|
||||||
|
# We'll actually pull the presence updates for these users at the end.
|
||||||
|
interested_and_updated_users = (
|
||||||
|
set()
|
||||||
|
) # type: Union[Set[str], FrozenSet[str]]
|
||||||
|
|
||||||
if from_key:
|
if from_key:
|
||||||
changed = stream_change_cache.get_all_entities_changed(from_key)
|
# First get all users that have had a presence update
|
||||||
|
updated_users = stream_change_cache.get_all_entities_changed(from_key)
|
||||||
|
|
||||||
if changed is not None and len(changed) < 500:
|
# Cross-reference users we're interested in with those that have had updates.
|
||||||
assert isinstance(user_ids_changed, set)
|
# Use a slightly-optimised method for processing smaller sets of updates.
|
||||||
|
if updated_users is not None and len(updated_users) < 500:
|
||||||
# For small deltas, its quicker to get all changes and then
|
# For small deltas, it's quicker to get all changes and then
|
||||||
# work out if we share a room or they're in our presence list
|
# cross-reference with the users we're interested in
|
||||||
get_updates_counter.labels("stream").inc()
|
get_updates_counter.labels("stream").inc()
|
||||||
for other_user_id in changed:
|
for other_user_id in updated_users:
|
||||||
if other_user_id in users_interested_in:
|
if other_user_id in users_interested_in:
|
||||||
user_ids_changed.add(other_user_id)
|
# mypy thinks this variable could be a FrozenSet as it's possibly set
|
||||||
|
# to one in the `get_entities_changed` call below, and `add()` is not
|
||||||
|
# method on a FrozenSet. That doesn't affect us here though, as
|
||||||
|
# `interested_and_updated_users` is clearly a set() above.
|
||||||
|
interested_and_updated_users.add(other_user_id) # type: ignore
|
||||||
else:
|
else:
|
||||||
# Too many possible updates. Find all users we can see and check
|
# Too many possible updates. Find all users we can see and check
|
||||||
# if any of them have changed.
|
# if any of them have changed.
|
||||||
get_updates_counter.labels("full").inc()
|
get_updates_counter.labels("full").inc()
|
||||||
|
|
||||||
if from_key:
|
interested_and_updated_users = (
|
||||||
user_ids_changed = stream_change_cache.get_entities_changed(
|
stream_change_cache.get_entities_changed(
|
||||||
users_interested_in, from_key
|
users_interested_in, from_key
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
user_ids_changed = users_interested_in
|
|
||||||
|
|
||||||
updates = await presence.current_state_for_users(user_ids_changed)
|
|
||||||
|
|
||||||
if include_offline:
|
|
||||||
return (list(updates.values()), max_token)
|
|
||||||
else:
|
|
||||||
return (
|
|
||||||
[s for s in updates.values() if s.state != PresenceState.OFFLINE],
|
|
||||||
max_token,
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# No from_key has been specified. Return the presence for all users
|
||||||
|
# this user is interested in
|
||||||
|
interested_and_updated_users = users_interested_in
|
||||||
|
|
||||||
|
# Retrieve the current presence state for each user
|
||||||
|
users_to_state = await self.get_presence_handler().current_state_for_users(
|
||||||
|
interested_and_updated_users
|
||||||
|
)
|
||||||
|
presence_updates = list(users_to_state.values())
|
||||||
|
|
||||||
|
# Remove the user from the list of users to receive all presence
|
||||||
|
if user_id in self.get_module_api()._send_full_presence_to_local_users:
|
||||||
|
self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
|
||||||
|
|
||||||
|
if not include_offline:
|
||||||
|
# Filter out offline presence states
|
||||||
|
presence_updates = self._filter_offline_presence_state(presence_updates)
|
||||||
|
|
||||||
|
return presence_updates, max_token
|
||||||
|
|
||||||
|
async def _filter_all_presence_updates_for_user(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
include_offline: bool,
|
||||||
|
from_key: Optional[int] = None,
|
||||||
|
) -> List[UserPresenceState]:
|
||||||
|
"""
|
||||||
|
Computes the presence updates a user should receive.
|
||||||
|
|
||||||
|
First pulls presence updates from the database. Then consults PresenceRouter
|
||||||
|
for whether any updates should be excluded by user ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The User ID of the user to compute presence updates for.
|
||||||
|
include_offline: Whether to include offline presence states from the results.
|
||||||
|
from_key: The minimum stream ID of updates to pull from the database
|
||||||
|
before filtering.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of presence states for the given user to receive.
|
||||||
|
"""
|
||||||
|
if from_key:
|
||||||
|
# Only return updates since the last sync
|
||||||
|
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
|
||||||
|
from_key
|
||||||
|
)
|
||||||
|
if not updated_users:
|
||||||
|
updated_users = []
|
||||||
|
|
||||||
|
# Get the actual presence update for each change
|
||||||
|
users_to_state = await self.get_presence_handler().current_state_for_users(
|
||||||
|
updated_users
|
||||||
|
)
|
||||||
|
presence_updates = list(users_to_state.values())
|
||||||
|
|
||||||
|
if not include_offline:
|
||||||
|
# Filter out offline states
|
||||||
|
presence_updates = self._filter_offline_presence_state(presence_updates)
|
||||||
|
else:
|
||||||
|
users_to_state = await self.store.get_presence_for_all_users(
|
||||||
|
include_offline=include_offline
|
||||||
|
)
|
||||||
|
|
||||||
|
presence_updates = list(users_to_state.values())
|
||||||
|
|
||||||
|
# TODO: This feels wildly inefficient, and it's unfortunate we need to ask the
|
||||||
|
# module for information on a number of users when we then only take the info
|
||||||
|
# for a single user
|
||||||
|
|
||||||
|
# Filter through the presence router
|
||||||
|
users_to_state_set = await self.get_presence_router().get_users_for_states(
|
||||||
|
presence_updates
|
||||||
|
)
|
||||||
|
|
||||||
|
# We only want the mapping for the syncing user
|
||||||
|
presence_updates = list(users_to_state_set[user_id])
|
||||||
|
|
||||||
|
# Return presence information for all users
|
||||||
|
return presence_updates
|
||||||
|
|
||||||
|
def _filter_offline_presence_state(
|
||||||
|
self, presence_updates: Iterable[UserPresenceState]
|
||||||
|
) -> List[UserPresenceState]:
|
||||||
|
"""Given an iterable containing user presence updates, return a list with any offline
|
||||||
|
presence states removed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
presence_updates: Presence states to filter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new list with any offline presence states removed.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
update
|
||||||
|
for update in presence_updates
|
||||||
|
if update.state != PresenceState.OFFLINE
|
||||||
|
]
|
||||||
|
|
||||||
def get_current_key(self):
|
def get_current_key(self):
|
||||||
return self.store.get_current_presence_token()
|
return self.store.get_current_presence_token()
|
||||||
|
|
||||||
@cached(num_args=2, cache_context=True)
|
@cached(num_args=2, cache_context=True)
|
||||||
async def _get_interested_in(self, user, explicit_room_id, cache_context):
|
async def _get_interested_in(
|
||||||
|
self,
|
||||||
|
user: UserID,
|
||||||
|
explicit_room_id: Optional[str] = None,
|
||||||
|
cache_context: Optional[_CacheContext] = None,
|
||||||
|
) -> Union[Set[str], str]:
|
||||||
"""Returns the set of users that the given user should see presence
|
"""Returns the set of users that the given user should see presence
|
||||||
updates for
|
updates for.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: The user to retrieve presence updates for.
|
||||||
|
explicit_room_id: The users that are in the room will be returned.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A set of user IDs to return presence updates for, or "ALL" to return all
|
||||||
|
known updates.
|
||||||
"""
|
"""
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
users_interested_in = set()
|
users_interested_in = set()
|
||||||
users_interested_in.add(user_id) # So that we receive our own presence
|
users_interested_in.add(user_id) # So that we receive our own presence
|
||||||
|
|
||||||
|
# cache_context isn't likely to ever be None due to the @cached decorator,
|
||||||
|
# but we can't have a non-optional argument after the optional argument
|
||||||
|
# explicit_room_id either. Assert cache_context is not None so we can use it
|
||||||
|
# without mypy complaining.
|
||||||
|
assert cache_context
|
||||||
|
|
||||||
|
# Check with the presence router whether we should poll additional users for
|
||||||
|
# their presence information
|
||||||
|
additional_users = await self.get_presence_router().get_interested_users(
|
||||||
|
user.to_string()
|
||||||
|
)
|
||||||
|
if additional_users == PresenceRouter.ALL_USERS:
|
||||||
|
# If the module requested that this user see the presence updates of *all*
|
||||||
|
# users, then simply return that instead of calculating what rooms this
|
||||||
|
# user shares
|
||||||
|
return PresenceRouter.ALL_USERS
|
||||||
|
|
||||||
|
# Add the additional users from the router
|
||||||
|
users_interested_in.update(additional_users)
|
||||||
|
|
||||||
|
# Find the users who share a room with this user
|
||||||
users_who_share_room = await self.store.get_users_who_share_room_with_user(
|
users_who_share_room = await self.store.get_users_who_share_room_with_user(
|
||||||
user_id, on_invalidate=cache_context.invalidate
|
user_id, on_invalidate=cache_context.invalidate
|
||||||
)
|
)
|
||||||
|
@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
|
||||||
|
|
||||||
|
|
||||||
async def get_interested_parties(
|
async def get_interested_parties(
|
||||||
store: DataStore, states: List[UserPresenceState]
|
store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState]
|
||||||
) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
|
) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
|
||||||
"""Given a list of states return which entities (rooms, users)
|
"""Given a list of states return which entities (rooms, users)
|
||||||
are interested in the given states.
|
are interested in the given states.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
store
|
store: The homeserver's data store.
|
||||||
states
|
presence_router: A module for augmenting the destinations for presence updates.
|
||||||
|
states: A list of incoming user presence updates.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A 2-tuple of `(room_ids_to_states, users_to_states)`,
|
A 2-tuple of `(room_ids_to_states, users_to_states)`,
|
||||||
|
@ -1337,11 +1515,22 @@ async def get_interested_parties(
|
||||||
# Always notify self
|
# Always notify self
|
||||||
users_to_states.setdefault(state.user_id, []).append(state)
|
users_to_states.setdefault(state.user_id, []).append(state)
|
||||||
|
|
||||||
|
# Ask a presence routing module for any additional parties if one
|
||||||
|
# is loaded.
|
||||||
|
router_users_to_states = await presence_router.get_users_for_states(states)
|
||||||
|
|
||||||
|
# Update the dictionaries with additional destinations and state to send
|
||||||
|
for user_id, user_states in router_users_to_states.items():
|
||||||
|
users_to_states.setdefault(user_id, []).extend(user_states)
|
||||||
|
|
||||||
return room_ids_to_states, users_to_states
|
return room_ids_to_states, users_to_states
|
||||||
|
|
||||||
|
|
||||||
async def get_interested_remotes(
|
async def get_interested_remotes(
|
||||||
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
|
store: DataStore,
|
||||||
|
presence_router: PresenceRouter,
|
||||||
|
states: List[UserPresenceState],
|
||||||
|
state_handler: StateHandler,
|
||||||
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
|
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
|
||||||
"""Given a list of presence states figure out which remote servers
|
"""Given a list of presence states figure out which remote servers
|
||||||
should be sent which.
|
should be sent which.
|
||||||
|
@ -1349,9 +1538,10 @@ async def get_interested_remotes(
|
||||||
All the presence states should be for local users only.
|
All the presence states should be for local users only.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
store
|
store: The homeserver's data store.
|
||||||
states
|
presence_router: A module for augmenting the destinations for presence updates.
|
||||||
state_handler
|
states: A list of incoming user presence updates.
|
||||||
|
state_handler:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of 2-tuples of destinations and states, where for
|
A list of 2-tuples of destinations and states, where for
|
||||||
|
@ -1363,7 +1553,9 @@ async def get_interested_remotes(
|
||||||
# First we look up the rooms each user is in (as well as any explicit
|
# First we look up the rooms each user is in (as well as any explicit
|
||||||
# subscriptions), then for each distinct room we look up the remote
|
# subscriptions), then for each distinct room we look up the remote
|
||||||
# hosts in those rooms.
|
# hosts in those rooms.
|
||||||
room_ids_to_states, users_to_states = await get_interested_parties(store, states)
|
room_ids_to_states, users_to_states = await get_interested_parties(
|
||||||
|
store, presence_router, states
|
||||||
|
)
|
||||||
|
|
||||||
for room_id, states in room_ids_to_states.items():
|
for room_id, states in room_ids_to_states.items():
|
||||||
hosts = await state_handler.get_current_hosts_in_room(room_id)
|
hosts = await state_handler.get_current_hosts_in_room(room_id)
|
||||||
|
|
|
@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there was a problem registering.
|
SynapseError if there was a problem registering.
|
||||||
"""
|
"""
|
||||||
self.check_registration_ratelimit(address)
|
await self.check_registration_ratelimit(address)
|
||||||
|
|
||||||
result = await self.spam_checker.check_registration_for_spam(
|
result = await self.spam_checker.check_registration_for_spam(
|
||||||
threepid,
|
threepid,
|
||||||
|
@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
errcode=Codes.EXCLUSIVE,
|
errcode=Codes.EXCLUSIVE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_registration_ratelimit(self, address: Optional[str]) -> None:
|
async def check_registration_ratelimit(self, address: Optional[str]) -> None:
|
||||||
"""A simple helper method to check whether the registration rate limit has been hit
|
"""A simple helper method to check whether the registration rate limit has been hit
|
||||||
for a given IP address
|
for a given IP address
|
||||||
|
|
||||||
|
@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
if not address:
|
if not address:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.ratelimiter.ratelimit(address)
|
await self.ratelimiter.ratelimit(None, address)
|
||||||
|
|
||||||
async def register_with_store(
|
async def register_with_store(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -76,22 +76,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
self.allow_per_room_profiles = self.config.allow_per_room_profiles
|
self.allow_per_room_profiles = self.config.allow_per_room_profiles
|
||||||
|
|
||||||
self._join_rate_limiter_local = Ratelimiter(
|
self._join_rate_limiter_local = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
|
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
|
||||||
)
|
)
|
||||||
self._join_rate_limiter_remote = Ratelimiter(
|
self._join_rate_limiter_remote = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
|
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
|
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._invites_per_room_limiter = Ratelimiter(
|
self._invites_per_room_limiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
|
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
|
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
|
||||||
)
|
)
|
||||||
self._invites_per_user_limiter = Ratelimiter(
|
self._invites_per_user_limiter = Ratelimiter(
|
||||||
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
|
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
|
||||||
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
|
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
|
||||||
|
@ -160,15 +164,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
async def forget(self, user: UserID, room_id: str) -> None:
|
async def forget(self, user: UserID, room_id: str) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
|
async def ratelimit_invite(
|
||||||
|
self,
|
||||||
|
requester: Optional[Requester],
|
||||||
|
room_id: Optional[str],
|
||||||
|
invitee_user_id: str,
|
||||||
|
):
|
||||||
"""Ratelimit invites by room and by target user.
|
"""Ratelimit invites by room and by target user.
|
||||||
|
|
||||||
If room ID is missing then we just rate limit by target user.
|
If room ID is missing then we just rate limit by target user.
|
||||||
"""
|
"""
|
||||||
if room_id:
|
if room_id:
|
||||||
self._invites_per_room_limiter.ratelimit(room_id)
|
await self._invites_per_room_limiter.ratelimit(requester, room_id)
|
||||||
|
|
||||||
self._invites_per_user_limiter.ratelimit(invitee_user_id)
|
await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
|
||||||
|
|
||||||
async def _local_membership_update(
|
async def _local_membership_update(
|
||||||
self,
|
self,
|
||||||
|
@ -238,7 +247,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
(
|
(
|
||||||
allowed,
|
allowed,
|
||||||
time_allowed,
|
time_allowed,
|
||||||
) = self._join_rate_limiter_local.can_requester_do_action(requester)
|
) = await self._join_rate_limiter_local.can_do_action(requester)
|
||||||
|
|
||||||
if not allowed:
|
if not allowed:
|
||||||
raise LimitExceededError(
|
raise LimitExceededError(
|
||||||
|
@ -441,9 +450,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
if effective_membership_state == Membership.INVITE:
|
if effective_membership_state == Membership.INVITE:
|
||||||
target_id = target.to_string()
|
target_id = target.to_string()
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
# Don't ratelimit application services.
|
await self.ratelimit_invite(requester, room_id, target_id)
|
||||||
if not requester.app_service or requester.app_service.is_rate_limited():
|
|
||||||
self.ratelimit_invite(room_id, target_id)
|
|
||||||
|
|
||||||
# block any attempts to invite the server notices mxid
|
# block any attempts to invite the server notices mxid
|
||||||
if target_id == self._server_notices_mxid:
|
if target_id == self._server_notices_mxid:
|
||||||
|
@ -554,7 +561,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
(
|
(
|
||||||
allowed,
|
allowed,
|
||||||
time_allowed,
|
time_allowed,
|
||||||
) = self._join_rate_limiter_remote.can_requester_do_action(
|
) = await self._join_rate_limiter_remote.can_do_action(
|
||||||
requester,
|
requester,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
||||||
from synapse.api.filtering import FilterCollection
|
from synapse.api.filtering import FilterCollection
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.logging.context import current_context
|
from synapse.logging.context import current_context
|
||||||
|
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
|
||||||
from synapse.push.clientformat import format_push_rules_for_user
|
from synapse.push.clientformat import format_push_rules_for_user
|
||||||
from synapse.storage.roommember import MemberSummary
|
from synapse.storage.roommember import MemberSummary
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
|
@ -252,13 +253,13 @@ class SyncHandler:
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
self.state_store = self.storage.state
|
self.state_store = self.storage.state
|
||||||
|
|
||||||
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
|
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
|
||||||
self.lazy_loaded_members_cache = ExpiringCache(
|
self.lazy_loaded_members_cache = ExpiringCache(
|
||||||
"lazy_loaded_members_cache",
|
"lazy_loaded_members_cache",
|
||||||
self.clock,
|
self.clock,
|
||||||
max_len=0,
|
max_len=0,
|
||||||
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
|
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
|
||||||
)
|
) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
|
||||||
|
|
||||||
async def wait_for_sync_for_user(
|
async def wait_for_sync_for_user(
|
||||||
self,
|
self,
|
||||||
|
@ -341,7 +342,14 @@ class SyncHandler:
|
||||||
full_state: bool = False,
|
full_state: bool = False,
|
||||||
) -> SyncResult:
|
) -> SyncResult:
|
||||||
"""Get the sync for client needed to match what the server has now."""
|
"""Get the sync for client needed to match what the server has now."""
|
||||||
return await self.generate_sync_result(sync_config, since_token, full_state)
|
with start_active_span("current_sync_for_user"):
|
||||||
|
log_kv({"since_token": since_token})
|
||||||
|
sync_result = await self.generate_sync_result(
|
||||||
|
sync_config, since_token, full_state
|
||||||
|
)
|
||||||
|
|
||||||
|
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
|
||||||
|
return sync_result
|
||||||
|
|
||||||
async def push_rules_for_user(self, user: UserID) -> JsonDict:
|
async def push_rules_for_user(self, user: UserID) -> JsonDict:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
@ -726,8 +734,10 @@ class SyncHandler:
|
||||||
|
|
||||||
def get_lazy_loaded_members_cache(
|
def get_lazy_loaded_members_cache(
|
||||||
self, cache_key: Tuple[str, Optional[str]]
|
self, cache_key: Tuple[str, Optional[str]]
|
||||||
) -> LruCache:
|
) -> LruCache[str, str]:
|
||||||
cache = self.lazy_loaded_members_cache.get(cache_key)
|
cache = self.lazy_loaded_members_cache.get(
|
||||||
|
cache_key
|
||||||
|
) # type: Optional[LruCache[str, str]]
|
||||||
if cache is None:
|
if cache is None:
|
||||||
logger.debug("creating LruCache for %r", cache_key)
|
logger.debug("creating LruCache for %r", cache_key)
|
||||||
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
|
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
|
||||||
|
@ -965,6 +975,7 @@ class SyncHandler:
|
||||||
# to query up to a given point.
|
# to query up to a given point.
|
||||||
# Always use the `now_token` in `SyncResultBuilder`
|
# Always use the `now_token` in `SyncResultBuilder`
|
||||||
now_token = self.event_sources.get_current_token()
|
now_token = self.event_sources.get_current_token()
|
||||||
|
log_kv({"now_token": now_token})
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Calculating sync response for %r between %s and %s",
|
"Calculating sync response for %r between %s and %s",
|
||||||
|
@ -1226,6 +1237,13 @@ class SyncHandler:
|
||||||
user_id, device_id, since_stream_id, now_token.to_device_key
|
user_id, device_id, since_stream_id, now_token.to_device_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
# We pop here as we shouldn't be sending the message ID down
|
||||||
|
# `/sync`
|
||||||
|
message_id = message.pop("message_id", None)
|
||||||
|
if message_id:
|
||||||
|
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Returning %d to-device messages between %d and %d (current token: %d)",
|
"Returning %d to-device messages between %d and %d (current token: %d)",
|
||||||
len(messages),
|
len(messages),
|
||||||
|
|
|
@ -19,7 +19,10 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import (
|
||||||
|
run_as_background_process,
|
||||||
|
wrap_as_background_process,
|
||||||
|
)
|
||||||
from synapse.replication.tcp.streams import TypingStream
|
from synapse.replication.tcp.streams import TypingStream
|
||||||
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
@ -86,6 +89,7 @@ class FollowerTypingHandler:
|
||||||
self._member_last_federation_poke = {}
|
self._member_last_federation_poke = {}
|
||||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||||
|
|
||||||
|
@wrap_as_background_process("typing._handle_timeouts")
|
||||||
def _handle_timeouts(self) -> None:
|
def _handle_timeouts(self) -> None:
|
||||||
logger.debug("Checking for typing timeouts")
|
logger.debug("Checking for typing timeouts")
|
||||||
|
|
||||||
|
|
|
@ -590,7 +590,7 @@ class SimpleHttpClient:
|
||||||
uri: str,
|
uri: str,
|
||||||
json_body: Any,
|
json_body: Any,
|
||||||
args: Optional[QueryParams] = None,
|
args: Optional[QueryParams] = None,
|
||||||
headers: RawHeaders = None,
|
headers: Optional[RawHeaders] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Puts some json to the given URI.
|
"""Puts some json to the given URI.
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Type, Union
|
from typing import Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
@ -26,7 +26,11 @@ from twisted.web.server import Request, Site
|
||||||
from synapse.config.server import ListenerConfig
|
from synapse.config.server import ListenerConfig
|
||||||
from synapse.http import get_request_user_agent, redact_uri
|
from synapse.http import get_request_user_agent, redact_uri
|
||||||
from synapse.http.request_metrics import RequestMetrics, requests_counter
|
from synapse.http.request_metrics import RequestMetrics, requests_counter
|
||||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
from synapse.logging.context import (
|
||||||
|
ContextRequest,
|
||||||
|
LoggingContext,
|
||||||
|
PreserveLoggingContext,
|
||||||
|
)
|
||||||
from synapse.types import Requester
|
from synapse.types import Requester
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -63,7 +67,7 @@ class SynapseRequest(Request):
|
||||||
|
|
||||||
# The requester, if authenticated. For federation requests this is the
|
# The requester, if authenticated. For federation requests this is the
|
||||||
# server name, for client requests this is the Requester object.
|
# server name, for client requests this is the Requester object.
|
||||||
self.requester = None # type: Optional[Union[Requester, str]]
|
self._requester = None # type: Optional[Union[Requester, str]]
|
||||||
|
|
||||||
# we can't yet create the logcontext, as we don't know the method.
|
# we can't yet create the logcontext, as we don't know the method.
|
||||||
self.logcontext = None # type: Optional[LoggingContext]
|
self.logcontext = None # type: Optional[LoggingContext]
|
||||||
|
@ -93,6 +97,31 @@ class SynapseRequest(Request):
|
||||||
self.site.site_tag,
|
self.site.site_tag,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requester(self) -> Optional[Union[Requester, str]]:
|
||||||
|
return self._requester
|
||||||
|
|
||||||
|
@requester.setter
|
||||||
|
def requester(self, value: Union[Requester, str]) -> None:
|
||||||
|
# Store the requester, and update some properties based on it.
|
||||||
|
|
||||||
|
# This should only be called once.
|
||||||
|
assert self._requester is None
|
||||||
|
|
||||||
|
self._requester = value
|
||||||
|
|
||||||
|
# A logging context should exist by now (and have a ContextRequest).
|
||||||
|
assert self.logcontext is not None
|
||||||
|
assert self.logcontext.request is not None
|
||||||
|
|
||||||
|
(
|
||||||
|
requester,
|
||||||
|
authenticated_entity,
|
||||||
|
) = self.get_authenticated_entity()
|
||||||
|
self.logcontext.request.requester = requester
|
||||||
|
# If there's no authenticated entity, it was the requester.
|
||||||
|
self.logcontext.request.authenticated_entity = authenticated_entity or requester
|
||||||
|
|
||||||
def get_request_id(self):
|
def get_request_id(self):
|
||||||
return "%s-%i" % (self.get_method(), self.request_seq)
|
return "%s-%i" % (self.get_method(), self.request_seq)
|
||||||
|
|
||||||
|
@ -126,13 +155,60 @@ class SynapseRequest(Request):
|
||||||
return self.method.decode("ascii")
|
return self.method.decode("ascii")
|
||||||
return method
|
return method
|
||||||
|
|
||||||
|
def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:
|
||||||
|
"""
|
||||||
|
Get the "authenticated" entity of the request, which might be the user
|
||||||
|
performing the action, or a user being puppeted by a server admin.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple:
|
||||||
|
The first item is a string representing the user making the request.
|
||||||
|
|
||||||
|
The second item is a string or None representing the user who
|
||||||
|
authenticated when making this request. See
|
||||||
|
Requester.authenticated_entity.
|
||||||
|
"""
|
||||||
|
# Convert the requester into a string that we can log
|
||||||
|
if isinstance(self._requester, str):
|
||||||
|
return self._requester, None
|
||||||
|
elif isinstance(self._requester, Requester):
|
||||||
|
requester = self._requester.user.to_string()
|
||||||
|
authenticated_entity = self._requester.authenticated_entity
|
||||||
|
|
||||||
|
# If this is a request where the target user doesn't match the user who
|
||||||
|
# authenticated (e.g. and admin is puppetting a user) then we return both.
|
||||||
|
if self._requester.user.to_string() != authenticated_entity:
|
||||||
|
return requester, authenticated_entity
|
||||||
|
|
||||||
|
return requester, None
|
||||||
|
elif self._requester is not None:
|
||||||
|
# This shouldn't happen, but we log it so we don't lose information
|
||||||
|
# and can see that we're doing something wrong.
|
||||||
|
return repr(self._requester), None # type: ignore[unreachable]
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
def render(self, resrc):
|
def render(self, resrc):
|
||||||
# this is called once a Resource has been found to serve the request; in our
|
# this is called once a Resource has been found to serve the request; in our
|
||||||
# case the Resource in question will normally be a JsonResource.
|
# case the Resource in question will normally be a JsonResource.
|
||||||
|
|
||||||
# create a LogContext for this request
|
# create a LogContext for this request
|
||||||
request_id = self.get_request_id()
|
request_id = self.get_request_id()
|
||||||
self.logcontext = LoggingContext(request_id, request=request_id)
|
self.logcontext = LoggingContext(
|
||||||
|
request_id,
|
||||||
|
request=ContextRequest(
|
||||||
|
request_id=request_id,
|
||||||
|
ip_address=self.getClientIP(),
|
||||||
|
site_tag=self.site.site_tag,
|
||||||
|
# The requester is going to be unknown at this point.
|
||||||
|
requester=None,
|
||||||
|
authenticated_entity=None,
|
||||||
|
method=self.get_method(),
|
||||||
|
url=self.get_redacted_uri(),
|
||||||
|
protocol=self.clientproto.decode("ascii", errors="replace"),
|
||||||
|
user_agent=get_request_user_agent(self),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# override the Server header which is set by twisted
|
# override the Server header which is set by twisted
|
||||||
self.setHeader("Server", self.site.server_version_string)
|
self.setHeader("Server", self.site.server_version_string)
|
||||||
|
@ -277,25 +353,6 @@ class SynapseRequest(Request):
|
||||||
# to the client (nb may be negative)
|
# to the client (nb may be negative)
|
||||||
response_send_time = self.finish_time - self._processing_finished_time
|
response_send_time = self.finish_time - self._processing_finished_time
|
||||||
|
|
||||||
# Convert the requester into a string that we can log
|
|
||||||
authenticated_entity = None
|
|
||||||
if isinstance(self.requester, str):
|
|
||||||
authenticated_entity = self.requester
|
|
||||||
elif isinstance(self.requester, Requester):
|
|
||||||
authenticated_entity = self.requester.authenticated_entity
|
|
||||||
|
|
||||||
# If this is a request where the target user doesn't match the user who
|
|
||||||
# authenticated (e.g. and admin is puppetting a user) then we log both.
|
|
||||||
if self.requester.user.to_string() != authenticated_entity:
|
|
||||||
authenticated_entity = "{},{}".format(
|
|
||||||
authenticated_entity,
|
|
||||||
self.requester.user.to_string(),
|
|
||||||
)
|
|
||||||
elif self.requester is not None:
|
|
||||||
# This shouldn't happen, but we log it so we don't lose information
|
|
||||||
# and can see that we're doing something wrong.
|
|
||||||
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
|
|
||||||
|
|
||||||
user_agent = get_request_user_agent(self, "-")
|
user_agent = get_request_user_agent(self, "-")
|
||||||
|
|
||||||
code = str(self.code)
|
code = str(self.code)
|
||||||
|
@ -305,6 +362,13 @@ class SynapseRequest(Request):
|
||||||
code += "!"
|
code += "!"
|
||||||
|
|
||||||
log_level = logging.INFO if self._should_log_request() else logging.DEBUG
|
log_level = logging.INFO if self._should_log_request() else logging.DEBUG
|
||||||
|
|
||||||
|
# If this is a request where the target user doesn't match the user who
|
||||||
|
# authenticated (e.g. and admin is puppetting a user) then we log both.
|
||||||
|
requester, authenticated_entity = self.get_authenticated_entity()
|
||||||
|
if authenticated_entity:
|
||||||
|
requester = "{}.{}".format(authenticated_entity, requester)
|
||||||
|
|
||||||
self.site.access_logger.log(
|
self.site.access_logger.log(
|
||||||
log_level,
|
log_level,
|
||||||
"%s - %s - {%s}"
|
"%s - %s - {%s}"
|
||||||
|
@ -312,7 +376,7 @@ class SynapseRequest(Request):
|
||||||
' %sB %s "%s %s %s" "%s" [%d dbevts]',
|
' %sB %s "%s %s %s" "%s" [%d dbevts]',
|
||||||
self.getClientIP(),
|
self.getClientIP(),
|
||||||
self.site.site_tag,
|
self.site.site_tag,
|
||||||
authenticated_entity,
|
requester,
|
||||||
processing_time,
|
processing_time,
|
||||||
response_send_time,
|
response_send_time,
|
||||||
usage.ru_utime,
|
usage.ru_utime,
|
||||||
|
|
|
@ -22,7 +22,6 @@ them.
|
||||||
|
|
||||||
See doc/log_contexts.rst for details on how this works.
|
See doc/log_contexts.rst for details on how this works.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
@ -30,6 +29,7 @@ import types
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
|
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
|
import attr
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.internet import defer, threads
|
from twisted.internet import defer, threads
|
||||||
|
@ -181,6 +181,29 @@ class ContextResourceUsage:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class ContextRequest:
|
||||||
|
"""
|
||||||
|
A bundle of attributes from the SynapseRequest object.
|
||||||
|
|
||||||
|
This exists to:
|
||||||
|
|
||||||
|
* Avoid a cycle between LoggingContext and SynapseRequest.
|
||||||
|
* Be a single variable that can be passed from parent LoggingContexts to
|
||||||
|
their children.
|
||||||
|
"""
|
||||||
|
|
||||||
|
request_id = attr.ib(type=str)
|
||||||
|
ip_address = attr.ib(type=str)
|
||||||
|
site_tag = attr.ib(type=str)
|
||||||
|
requester = attr.ib(type=Optional[str])
|
||||||
|
authenticated_entity = attr.ib(type=Optional[str])
|
||||||
|
method = attr.ib(type=str)
|
||||||
|
url = attr.ib(type=str)
|
||||||
|
protocol = attr.ib(type=str)
|
||||||
|
user_agent = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
|
LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -256,7 +279,7 @@ class LoggingContext:
|
||||||
self,
|
self,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
parent_context: "Optional[LoggingContext]" = None,
|
parent_context: "Optional[LoggingContext]" = None,
|
||||||
request: Optional[str] = None,
|
request: Optional[ContextRequest] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.previous_context = current_context()
|
self.previous_context = current_context()
|
||||||
self.name = name
|
self.name = name
|
||||||
|
@ -281,7 +304,11 @@ class LoggingContext:
|
||||||
self.parent_context = parent_context
|
self.parent_context = parent_context
|
||||||
|
|
||||||
if self.parent_context is not None:
|
if self.parent_context is not None:
|
||||||
self.parent_context.copy_to(self)
|
# we track the current request_id
|
||||||
|
self.request = self.parent_context.request
|
||||||
|
|
||||||
|
# we also track the current scope:
|
||||||
|
self.scope = self.parent_context.scope
|
||||||
|
|
||||||
if request is not None:
|
if request is not None:
|
||||||
# the request param overrides the request from the parent context
|
# the request param overrides the request from the parent context
|
||||||
|
@ -289,7 +316,7 @@ class LoggingContext:
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
if self.request:
|
if self.request:
|
||||||
return str(self.request)
|
return self.request.request_id
|
||||||
return "%s@%x" % (self.name, id(self))
|
return "%s@%x" % (self.name, id(self))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -556,8 +583,23 @@ class LoggingContextFilter(logging.Filter):
|
||||||
# we end up in a death spiral of infinite loops, so let's check, for
|
# we end up in a death spiral of infinite loops, so let's check, for
|
||||||
# robustness' sake.
|
# robustness' sake.
|
||||||
if context is not None:
|
if context is not None:
|
||||||
# Logging is interested in the request.
|
# Logging is interested in the request ID. Note that for backwards
|
||||||
record.request = context.request # type: ignore
|
# compatibility this is stored as the "request" on the record.
|
||||||
|
record.request = str(context) # type: ignore
|
||||||
|
|
||||||
|
# Add some data from the HTTP request.
|
||||||
|
request = context.request
|
||||||
|
if request is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
record.ip_address = request.ip_address # type: ignore
|
||||||
|
record.site_tag = request.site_tag # type: ignore
|
||||||
|
record.requester = request.requester # type: ignore
|
||||||
|
record.authenticated_entity = request.authenticated_entity # type: ignore
|
||||||
|
record.method = request.method # type: ignore
|
||||||
|
record.url = request.url # type: ignore
|
||||||
|
record.protocol = request.protocol # type: ignore
|
||||||
|
record.user_agent = request.user_agent # type: ignore
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@@ -630,8 +672,8 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
 def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.

-    The nested logging context will have a 'request' made up of the parent context's
-    request, plus the given suffix.
+    The nested logging context will have a 'name' made up of the parent context's
+    name, plus the given suffix.

     CPU/db usage stats will be added to the parent context's on exit.

@@ -641,7 +683,7 @@ def nested_logging_context(suffix: str) -> LoggingContext:
         # ... do stuff

     Args:
-        suffix: suffix to add to the parent context's 'request'.
+        suffix: suffix to add to the parent context's 'name'.

     Returns:
         LoggingContext: new logging context.

@@ -653,11 +695,17 @@ def nested_logging_context(suffix: str) -> LoggingContext:
         )
         parent_context = None
         prefix = ""
+        request = None
     else:
         assert isinstance(curr_context, LoggingContext)
         parent_context = curr_context
-        prefix = str(parent_context.request)
-    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
+        prefix = str(parent_context.name)
+        request = parent_context.request
+    return LoggingContext(
+        prefix + "-" + suffix,
+        parent_context=parent_context,
+        request=request,
+    )


 def preserve_fn(f):
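A rough illustration of the naming scheme this hunk produces: child contexts now derive their name from the parent's *name* rather than its request, while still carrying the parent's request for log correlation. This uses plain objects, not Synapse's `LoggingContext`:

```python
from typing import Optional

class Ctx:
    def __init__(self, name: str, request: Optional[str] = None,
                 parent: Optional["Ctx"] = None):
        self.name = name
        self.request = request
        self.parent = parent

def nested(parent: Ctx, suffix: str) -> Ctx:
    # The child keeps the parent's request (for correlating log lines) but
    # gets a more specific name.
    return Ctx(parent.name + "-" + suffix, request=parent.request, parent=parent)

root = Ctx("persist_events", request="PUT-77")
child = nested(root, "txn1")
assert child.name == "persist_events-txn1"
assert child.request == "PUT-77"
```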
synapse/logging/opentracing.py

@@ -259,6 +259,14 @@ except ImportError:
 logger = logging.getLogger(__name__)


+class SynapseTags:
+    # The message ID of any to_device message processed
+    TO_DEVICE_MESSAGE_ID = "to_device.message_id"
+
+    # Whether the sync response has new data to be returned to the client.
+    SYNC_RESULT = "sync.new_data"
+
+
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
 # None means 'block everything'.
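A hypothetical sketch of how a tag-constants class like `SynapseTags` is typically consumed; `opentracing.global_tracer()` and the span/tag calls are the standard opentracing-python API (a no-op tracer unless a backend is installed), while the `Tags` class here is an invented stand-in:

```python
import opentracing

class Tags:
    TO_DEVICE_MESSAGE_ID = "to_device.message_id"
    SYNC_RESULT = "sync.new_data"

tracer = opentracing.global_tracer()
with tracer.start_active_span("process_to_device") as scope:
    # Module-level constants avoid typo'd tag names scattered across call
    # sites, and give one place to document each tag's meaning.
    scope.span.set_tag(Tags.TO_DEVICE_MESSAGE_ID, "$abc123")
```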
synapse/metrics/__init__.py

@@ -214,7 +214,12 @@ class GaugeBucketCollector:
     Prometheus, and optimise for that case.
     """

-    __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric")
+    __slots__ = (
+        "_name",
+        "_documentation",
+        "_bucket_bounds",
+        "_metric",
+    )

     def __init__(
         self,

@@ -242,10 +247,15 @@ class GaugeBucketCollector:
         if self._bucket_bounds[-1] != float("inf"):
             self._bucket_bounds.append(float("inf"))

-        self._metric = self._values_to_metric([])
+        # We initially set this to None. We won't report metrics until
+        # this has been initialised after a successful data update
+        self._metric = None  # type: Optional[GaugeHistogramMetricFamily]
+
         registry.register(self)

     def collect(self):
-        yield self._metric
+        # Don't report metrics unless we've already collected some data
+        if self._metric is not None:
+            yield self._metric

     def update_data(self, values: Iterable[float]):
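A cut-down sketch of the "don't report until the first update" pattern above, against the real `prometheus_client` collector interface. The bucket arithmetic is simplified; the point is that `collect()` yields nothing until `update_data()` has run once, so early scrapes simply omit the metric instead of exporting a misleading all-zero histogram:

```python
from typing import Iterable, Optional
from prometheus_client.core import REGISTRY, GaugeHistogramMetricFamily

class LazyGaugeBuckets:
    def __init__(self, name: str, documentation: str, bounds: Iterable[float]):
        self._name = name
        self._documentation = documentation
        self._bounds = sorted(bounds)
        if not self._bounds or self._bounds[-1] != float("inf"):
            self._bounds.append(float("inf"))
        self._metric = None  # type: Optional[GaugeHistogramMetricFamily]
        REGISTRY.register(self)

    def collect(self):
        # Prometheus scrapes by calling collect(); yielding nothing means the
        # metric is absent rather than bogus.
        if self._metric is not None:
            yield self._metric

    def update_data(self, values: Iterable[float]) -> None:
        vals = list(values)
        # gauge-histogram buckets are cumulative: bound -> count of values <= bound
        buckets = [
            (str(bound), float(sum(1 for v in vals if v <= bound)))
            for bound in self._bounds
        ]
        self._metric = GaugeHistogramMetricFamily(
            self._name, self._documentation, buckets=buckets, gsum_value=sum(vals)
        )
```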
synapse/metrics/background_process_metrics.py

@@ -16,7 +16,7 @@
 import logging
 import threading
 from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Set
+from typing import TYPE_CHECKING, Dict, Optional, Set, Union

 from prometheus_client.core import REGISTRY, Counter, Gauge

@@ -199,11 +199,11 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
         _background_process_start_count.labels(desc).inc()
         _background_process_in_flight_count.labels(desc).inc()

-        with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
+        with BackgroundProcessLoggingContext(desc, count) as context:
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
-                    ctx = start_active_span(desc, tags={"request_id": context.request})
+                    ctx = start_active_span(desc, tags={"request_id": str(context)})
                 with ctx:
                     return await maybe_awaitable(func(*args, **kwargs))
             except Exception:

@@ -242,13 +242,19 @@ class BackgroundProcessLoggingContext(LoggingContext):
     processes.
     """

-    __slots__ = ["_proc"]
+    __slots__ = ["_id", "_proc"]

-    def __init__(self, name: str, request: Optional[str] = None):
-        super().__init__(name, request=request)
+    def __init__(self, name: str, id: Optional[Union[int, str]] = None):
+        super().__init__(name)
+        self._id = id
+
         self._proc = _BackgroundProcess(name, self)

+    def __str__(self) -> str:
+        if self._id is not None:
+            return "%s-%s" % (self.name, self._id)
+        return "%s@%x" % (self.name, id(self))
+
     def start(self, rusage: "Optional[resource._RUsage]"):
         """Log context has started running (again)."""
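A minimal sketch of the naming behaviour added above, as a plain-Python analogue (this is illustrative, not Synapse's `LoggingContext`): if an explicit id is supplied the context renders as `name-id`, otherwise it falls back to the object's address:

```python
from typing import Optional, Union

class NamedContext:
    def __init__(self, name: str, ctx_id: Optional[Union[int, str]] = None):
        self.name = name
        self._id = ctx_id

    def __str__(self) -> str:
        # Mirrors the __str__ added in the diff: stable "name-id" when an id
        # was given, "name@address" otherwise.
        if self._id is not None:
            return "%s-%s" % (self.name, self._id)
        return "%s@%x" % (self.name, id(self))

assert str(NamedContext("replication-conn", 7)) == "replication-conn-7"
print(NamedContext("bg-task"))  # e.g. bg-task@7f3c2a1b9d60
```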
synapse/module_api/__init__.py

@@ -50,11 +50,20 @@ class ModuleApi:
         self._auth = hs.get_auth()
         self._auth_handler = auth_handler
         self._server_name = hs.hostname
+        self._presence_stream = hs.get_event_sources().sources["presence"]

         # We expose these as properties below in order to attach a helpful docstring.
         self._http_client = hs.get_simple_http_client()  # type: SimpleHttpClient
         self._public_room_list_manager = PublicRoomListManager(hs)

+        # The next time these users sync, they will receive the current presence
+        # state of all local users. Users are added by send_local_online_presence_to,
+        # and removed after a successful sync.
+        #
+        # We make this a private variable to deter modules from accessing it directly,
+        # though other classes in Synapse will still do so.
+        self._send_full_presence_to_local_users = set()
+
     @property
     def http_client(self):
         """Allows making outbound HTTP requests to remote resources.

@@ -385,6 +394,47 @@ class ModuleApi:

         return event

+    async def send_local_online_presence_to(self, users: Iterable[str]) -> None:
+        """
+        Forces the equivalent of a presence initial_sync for a set of local or remote
+        users. The users will receive presence for all currently online users that they
+        are considered interested in.
+
+        Updates to remote users will be sent immediately, whereas local users will receive
+        them on their next sync attempt.
+
+        Note that this method can only be run on the main or federation_sender worker
+        processes.
+        """
+        if not self._hs.should_send_federation():
+            raise Exception(
+                "send_local_online_presence_to can only be run "
+                "on processes that send federation",
+            )
+
+        for user in users:
+            if self._hs.is_mine_id(user):
+                # Modify SyncHandler._generate_sync_entry_for_presence to call
+                # presence_source.get_new_events with an empty `from_key` if
+                # that user's ID were in a list modified by ModuleApi somewhere.
+                # That user would then get all presence state on next incremental sync.
+
+                # Force a presence initial_sync for this user next time
+                self._send_full_presence_to_local_users.add(user)
+            else:
+                # Retrieve presence state for currently online users that this user
+                # is considered interested in
+                presence_events, _ = await self._presence_stream.get_new_events(
+                    UserID.from_string(user), from_key=None, include_offline=False
+                )
+
+                # Send to remote destinations
+                await make_deferred_yieldable(
+                    # We pull the federation sender here as we can only do so on workers
+                    # that support sending presence
+                    self._hs.get_federation_sender().send_presence(presence_events)
+                )
+
+
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
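Hypothetical module code consuming the new API. `api` is the `ModuleApi` instance Synapse hands to a loaded module; the method name matches the diff above, but the surrounding module scaffolding is invented for illustration:

```python
class MyPresenceModule:
    def __init__(self, config: dict, api):
        self._api = api

    async def on_users_reconnected(self, user_ids: list) -> None:
        # Ask Synapse to replay current online presence to these users: local
        # users receive it on their next /sync, remote users are pushed over
        # federation immediately. Per the docstring above, this must run on
        # the main or federation_sender worker.
        await self._api.send_local_online_presence_to(user_ids)
```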
synapse/notifier.py

@@ -39,6 +39,7 @@ from synapse.api.errors import AuthError
 from synapse.events import EventBase
 from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import log_kv, start_active_span
 from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.streams.config import PaginationConfig

@@ -136,6 +137,15 @@ class _NotifierUserStream:
         self.last_notified_ms = time_now_ms
         noify_deferred = self.notify_deferred

+        log_kv(
+            {
+                "notify": self.user_id,
+                "stream": stream_key,
+                "stream_id": stream_id,
+                "listeners": self.count_listeners(),
+            }
+        )
+
         users_woken_by_stream_counter.labels(stream_key).inc()

         with PreserveLoggingContext():

@@ -404,6 +414,13 @@ class Notifier:
         with Measure(self.clock, "on_new_event"):
             user_streams = set()

+            log_kv(
+                {
+                    "waking_up_explicit_users": len(users),
+                    "waking_up_explicit_rooms": len(rooms),
+                }
+            )
+
             for user in users:
                 user_stream = self.user_to_user_stream.get(str(user))
                 if user_stream is not None:

@@ -476,12 +493,34 @@ class Notifier:
                     (end_time - now) / 1000.0,
                     self.hs.get_reactor(),
                 )

+                with start_active_span("wait_for_events.deferred"):
+                    log_kv(
+                        {
+                            "wait_for_events": "sleep",
+                            "token": prev_token,
+                        }
+                    )
+
                     with PreserveLoggingContext():
                         await listener.deferred

+                    log_kv(
+                        {
+                            "wait_for_events": "woken",
+                            "token": user_stream.current_token,
+                        }
+                    )

                 current_token = user_stream.current_token

                 result = await callback(prev_token, current_token)
+                log_kv(
+                    {
+                        "wait_for_events": "result",
+                        "result": bool(result),
+                    }
+                )
                 if result:
                     break

@@ -489,8 +528,10 @@ class Notifier:
                 # has happened between the old prev_token and the current_token
                 prev_token = current_token
             except defer.TimeoutError:
+                log_kv({"wait_for_events": "timeout"})
                 break
             except defer.CancelledError:
+                log_kv({"wait_for_events": "cancelled"})
                 break

         if result is None:

@@ -507,7 +548,7 @@ class Notifier:
         pagination_config: PaginationConfig,
         timeout: int,
         is_guest: bool = False,
-        explicit_room_id: str = None,
+        explicit_room_id: Optional[str] = None,
     ) -> EventStreamResult:
         """For the given user and rooms, return any new events for them. If
         there are no new events wait for up to `timeout` milliseconds for any
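The `log_kv` calls above attach structured key/value events to the active tracing span. A rough equivalent using the plain opentracing API (which is what `synapse.logging.opentracing` wraps); the tracer falls back to a no-op when no tracing backend is configured, so these calls are cheap to leave in:

```python
import opentracing

def log_kv(key_values: dict) -> None:
    span = opentracing.global_tracer().active_span
    if span is not None:  # no-op when tracing is disabled
        span.log_kv(key_values)

with opentracing.global_tracer().start_active_span("wait_for_events.deferred"):
    log_kv({"wait_for_events": "sleep", "token": "s123_456"})
    # ... wait for the listener to fire ...
    log_kv({"wait_for_events": "woken", "token": "s123_457"})
```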
synapse/replication/http/register.py

@@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)

-        self.registration_handler.check_registration_ratelimit(content["address"])
+        await self.registration_handler.check_registration_ratelimit(content["address"])

         await self.registration_handler.register_with_store(
             user_id=user_id,
synapse/replication/tcp/protocol.py

@@ -184,8 +184,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):

         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        ctx_name = "replication-conn-%s" % self.conn_id
-        self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
+        self._logging_context = BackgroundProcessLoggingContext(
+            "replication-conn", self.conn_id
+        )

     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -60,7 +60,7 @@ class ConstantProperty(Generic[T, V]):

     constant = attr.ib()  # type: V

-    def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V:
+    def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
         return self.constant

     def __set__(self, obj: Optional[T], value: V):
synapse/rest/admin/users.py

@@ -36,6 +36,7 @@ from synapse.rest.admin._base import (
 )
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.storage.databases.main.media_repository import MediaSortOrder
+from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.types import JsonDict, UserID

 if TYPE_CHECKING:

@@ -117,8 +118,26 @@ class UsersRestServletV2(RestServlet):
         guests = parse_boolean(request, "guests", default=True)
         deactivated = parse_boolean(request, "deactivated", default=False)

+        order_by = parse_string(
+            request,
+            "order_by",
+            default=UserSortOrder.NAME.value,
+            allowed_values=(
+                UserSortOrder.NAME.value,
+                UserSortOrder.DISPLAYNAME.value,
+                UserSortOrder.GUEST.value,
+                UserSortOrder.ADMIN.value,
+                UserSortOrder.DEACTIVATED.value,
+                UserSortOrder.USER_TYPE.value,
+                UserSortOrder.AVATAR_URL.value,
+                UserSortOrder.SHADOW_BANNED.value,
+            ),
+        )
+
+        direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+
         users, total = await self.store.get_users_paginate(
-            start, limit, user_id, name, guests, deactivated
+            start, limit, user_id, name, guests, deactivated, order_by, direction
         )
         ret = {"users": users, "total": total}
         if (start + limit) < total:
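A hedged sketch of calling the updated admin endpoint with the new query parameters. The URL shape follows Synapse's list-users admin API (`/_synapse/admin/v2/users`); `ACCESS_TOKEN` is a placeholder for a server admin's token:

```python
import requests

resp = requests.get(
    "https://homeserver.example/_synapse/admin/v2/users",
    # "dir": "f" ascending, "b" descending, matching the allowed_values above
    params={"from": 0, "limit": 10, "order_by": "displayname", "dir": "b"},
    headers={"Authorization": "Bearer ACCESS_TOKEN"},
)
body = resp.json()
print(body["total"], [u["name"] for u in body["users"]])
```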
synapse/rest/client/v1/login.py

@@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet):

         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_address.per_second,
             burst_count=self.hs.config.rc_login_address.burst_count,
         )
         self._account_ratelimiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_account.per_second,
             burst_count=self.hs.config.rc_login_account.burst_count,

@@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet):
             appservice = self.auth.get_appservice_by_req(request)

             if appservice.is_rate_limited():
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(
+                    None, request.getClientIP()
+                )

             result = await self._do_appservice_login(login_submission, appservice)
         elif self.jwt_enabled and (
             login_submission["type"] == LoginRestServlet.JWT_TYPE
             or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
         ):
-            self._address_ratelimiter.ratelimit(request.getClientIP())
+            await self._address_ratelimiter.ratelimit(None, request.getClientIP())
             result = await self._do_jwt_login(login_submission)
         elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
-            self._address_ratelimiter.ratelimit(request.getClientIP())
+            await self._address_ratelimiter.ratelimit(None, request.getClientIP())
             result = await self._do_token_login(login_submission)
         else:
-            self._address_ratelimiter.ratelimit(request.getClientIP())
+            await self._address_ratelimiter.ratelimit(None, request.getClientIP())
             result = await self._do_other_login(login_submission)
     except KeyError:
         raise SynapseError(400, "Missing JSON keys.")

@@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet):
         # too often. This happens here rather than before as we don't
         # necessarily know the user before now.
         if ratelimit:
-            self._account_ratelimiter.ratelimit(user_id.lower())
+            await self._account_ratelimiter.ratelimit(None, user_id.lower())

         if create_non_existent_users:
             canonical_uid = await self.auth_handler.check_user_exists(user_id)
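Across these call sites the `Ratelimiter` entry points become coroutines (so they can consult the database for per-user overrides) and take the requester as a new first positional argument, with `None` meaning "no requester, limit purely by key". A self-contained toy with the same call shape, not Synapse's implementation (which also takes a `store` and supports overrides):

```python
import asyncio
from typing import Optional, Tuple

class AsyncRatelimiter:
    """A toy token-bucket limiter with the new async call shape."""

    def __init__(self, rate_hz: float, burst_count: int):
        self._rate_hz = rate_hz
        self._burst = burst_count
        self._actions = {}  # key -> (actions_counted, last_update_time)

    async def can_do_action(
        self, requester: Optional[object], key: str, _time_now_s: float
    ) -> Tuple[bool, float]:
        # A real implementation would `await` a database lookup here for
        # per-requester overrides; that is why the API became a coroutine.
        count, window_start = self._actions.get(key, (0.0, _time_now_s))
        # replenish allowance at rate_hz actions per second
        count = max(0.0, count - (_time_now_s - window_start) * self._rate_hz)
        if count + 1 <= self._burst:
            self._actions[key] = (count + 1, _time_now_s)
            return True, _time_now_s + (count + 1) / self._rate_hz
        return False, _time_now_s + (count + 1 - self._burst) / self._rate_hz

async def main() -> None:
    limiter = AsyncRatelimiter(rate_hz=0.1, burst_count=1)
    print(await limiter.can_do_action(None, "10.0.0.1", _time_now_s=0))  # (True, 10.0)
    print(await limiter.can_do_action(None, "10.0.0.1", _time_now_s=5))  # (False, 10.0)

asyncio.run(main())
```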
synapse/rest/client/v2_alpha/account.py

@@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)

-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )

         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to

@@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )

-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )

         if next_link:
             # Raise if the provided next_link value isn't valid

@@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )

-        self.identity_handler.ratelimit_request_token_requests(
+        await self.identity_handler.ratelimit_request_token_requests(
             request, "msisdn", msisdn
         )
synapse/rest/client/v2_alpha/register.py

@@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )

-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )

         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email

@@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )

-        self.identity_handler.ratelimit_request_token_requests(
+        await self.identity_handler.ratelimit_request_token_requests(
             request, "msisdn", msisdn
         )

@@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet):

         client_addr = request.getClientIP()

-        self.ratelimiter.ratelimit(client_addr, update=False)
+        await self.ratelimiter.ratelimit(None, client_addr, update=False)

         kind = b"user"
         if b"kind" in request.args:
synapse/rest/media/v1/preview_url_resource.py

@@ -175,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=ONE_HOUR,
-        )
+        )  # type: ExpiringCache[str, ObservableDeferred]

         if self._worker_run_media_background_jobs:
             self._cleaner_loop = self.clock.looping_call(
synapse/server.py

@@ -51,6 +51,7 @@ from synapse.crypto import context_factory
 from synapse.crypto.context_factory import RegularPolicyForHTTPS
 from synapse.crypto.keyring import Keyring
 from synapse.events.builder import EventBuilderFactory
+from synapse.events.presence_router import PresenceRouter
 from synapse.events.spamcheck import SpamChecker
 from synapse.events.third_party_rules import ThirdPartyEventRules
 from synapse.events.utils import EventClientSerializer

@@ -329,6 +330,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_registration_ratelimiter(self) -> Ratelimiter:
         return Ratelimiter(
+            store=self.get_datastore(),
             clock=self.get_clock(),
             rate_hz=self.config.rc_registration.per_second,
             burst_count=self.config.rc_registration.burst_count,

@@ -424,6 +426,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         else:
             raise Exception("Workers cannot write typing")

+    @cache_in_self
+    def get_presence_router(self) -> PresenceRouter:
+        return PresenceRouter(self)
+
     @cache_in_self
     def get_typing_handler(self) -> FollowerTypingHandler:
         if self.config.worker.writers.typing == self.get_instance_name():
synapse/state/__init__.py

@@ -22,6 +22,7 @@ from typing import (
     Callable,
     DefaultDict,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Optional,

@@ -515,7 +516,7 @@ class StateResolutionHandler:
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
             iterable=True,
             reset_expiry_on_get=True,
-        )
+        )  # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]

         #
         # stuff for tracking time spent on state-res by room

@@ -536,7 +537,7 @@ class StateResolutionHandler:
         state_groups_ids: Dict[int, StateMap[str]],
         event_map: Optional[Dict[str, EventBase]],
         state_res_store: "StateResolutionStore",
-    ):
+    ) -> _StateCacheEntry:
         """Resolves conflicts between a set of state groups

         Always generates a new state group (unless we hit the cache), so should
synapse/storage/databases/main/__init__.py

@@ -21,6 +21,7 @@ from typing import List, Optional, Tuple
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     IdGenerator,

@@ -292,6 +293,8 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
+        order_by: UserSortOrder = UserSortOrder.USER_ID.value,
+        direction: str = "f",
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the

@@ -304,6 +307,8 @@ class DataStore(
             name: search for local part of user_id or display name
             guests: whether to in include guest users
             deactivated: whether to include deactivated users
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending
         Returns:
             A tuple of a list of mappings from user to information and a count of total users.
         """

@@ -312,6 +317,14 @@ class DataStore(
             filters = []
             args = [self.hs.config.server_name]

+            # Set ordering
+            order_by_column = UserSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
             # `name` is in database already in lower case
             if name:
                 filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")

@@ -339,10 +352,15 @@ class DataStore(
             txn.execute(sql, args)
             count = txn.fetchone()[0]

-            sql = (
-                "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
-                + sql_base
-                + " ORDER BY u.name LIMIT ? OFFSET ?"
-            )
+            sql = """
+                SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url
+                {sql_base}
+                ORDER BY {order_by_column} {order}, u.name ASC
+                LIMIT ? OFFSET ?
+                """.format(
+                sql_base=sql_base,
+                order_by_column=order_by_column,
+                order=order,
+            )
             args += [limit, start]
             txn.execute(sql, args)
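A note on why the `.format()` interpolation above is reasonable despite building SQL from strings: the column name is first round-tripped through the `UserSortOrder` enum, so any value that is not a known member raises `ValueError` before it can reach the query, and the user-supplied `LIMIT`/`OFFSET` values still travel as bound parameters. A minimal demonstration of the same guard:

```python
from enum import Enum

class SortColumn(Enum):
    NAME = "name"
    DISPLAYNAME = "displayname"

def build_query(order_by: str, direction: str) -> str:
    column = SortColumn(order_by).value  # raises ValueError on unknown input
    order = "DESC" if direction == "b" else "ASC"
    # only enum-validated identifiers are interpolated; values stay as ? params
    return "SELECT name FROM users ORDER BY {} {} LIMIT ? OFFSET ?".format(
        column, order
    )

print(build_query("displayname", "b"))
# build_query("name; DROP TABLE users", "f") -> ValueError, never reaches SQL
```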
synapse/storage/databases/main/events_worker.py

@@ -16,7 +16,7 @@
 import logging
 import threading
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple, overload
+from typing import Container, Dict, Iterable, List, Optional, Tuple, overload

 from constantly import NamedConstant, Names
 from typing_extensions import Literal

@@ -544,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore):
     async def get_stripped_room_state_from_event_context(
         self,
         context: EventContext,
-        state_types_to_include: List[EventTypes],
+        state_types_to_include: Container[str],
         membership_user_id: Optional[str] = None,
     ) -> List[JsonDict]:
         """
synapse/storage/databases/main/group_server.py

@@ -1027,8 +1027,8 @@ class GroupServerStore(GroupServerWorkerStore):
         user_id: str,
         is_admin: bool = False,
         is_public: bool = True,
-        local_attestation: dict = None,
-        remote_attestation: dict = None,
+        local_attestation: Optional[dict] = None,
+        remote_attestation: Optional[dict] = None,
     ) -> None:
         """Add a user to the group server.
synapse/storage/databases/main/media_repository.py

@@ -22,6 +22,9 @@ from synapse.storage.database import DatabasePool
 BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
     "media_repository_drop_index_wo_method"
 )
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
+    "media_repository_drop_index_wo_method_2"
+)


 class MediaSortOrder(Enum):

@@ -85,23 +88,35 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             unique=True,
         )

+        # the original impl of _drop_media_index_without_method was broken (see
+        # https://github.com/matrix-org/synapse/issues/8649), so we replace the original
+        # impl with a no-op and run the fixed migration as
+        # media_repository_drop_index_wo_method_2.
+        self.db_pool.updates.register_noop_background_update(
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+        )
         self.db_pool.updates.register_background_update_handler(
-            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
             self._drop_media_index_without_method,
         )

     async def _drop_media_index_without_method(self, progress, batch_size):
+        """background update handler which removes the old constraints.
+
+        Note that this is only run on postgres.
+        """
+
         def f(txn):
             txn.execute(
                 "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
             )
             txn.execute(
-                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
             )

         await self.db_pool.runInteraction("drop_media_indices_without_method", f)
         await self.db_pool.updates._end_background_update(
-            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
        )
        return 1
synapse/storage/databases/main/schema/delta/59/ (new postgres-only schema delta)

@@ -0,0 +1,22 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- drop old constraints on remote_media_cache_thumbnails
+--
+-- This was originally part of 57.07, but it was done wrong, per
+-- https://github.com/matrix-org/synapse/issues/8649, so we do it again.
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5911, 'media_repository_drop_index_wo_method_2', '{}', 'remote_media_repository_thumbnails_method_idx');
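The redo pattern used by the two hunks above, in schematic form: never reuse a background update's name once a broken version of it may have run (and been marked complete) on some deployments. Instead, neutralise the old name and ship the fix under a new one. `updates` here stands in for Synapse's background-updater object; the method names follow the diff:

```python
OLD = "media_repository_drop_index_wo_method"
NEW = "media_repository_drop_index_wo_method_2"

def register(updates, handler):
    # Servers that already "completed" the broken update are unaffected;
    # servers that haven't will now complete the old entry instantly as a no-op.
    updates.register_noop_background_update(OLD)
    # Everyone runs the corrected logic under the fresh name, which the new
    # schema delta inserts into the background_updates table.
    updates.register_background_update_handler(NEW, handler)
```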
synapse/storage/databases/main/stats.py

@@ -66,18 +66,37 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 class UserSortOrder(Enum):
     """
     Enum to define the sorting method used when returning users
-    with get_users_media_usage_paginate
+    with get_users_paginate in __init__.py
+    and get_users_media_usage_paginate in stats.py

-    MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest.
-    MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest.
+    When moves this to __init__.py gets `builtins.ImportError` with
+    `most likely due to a circular import`
+
+    MEDIA_LENGTH = ordered by size of uploaded media.
+    MEDIA_COUNT = ordered by number of uploaded media.
     USER_ID = ordered alphabetically by `user_id`.
+    NAME = ordered alphabetically by `user_id`. This is for compatibility reasons,
+    as the user_id is returned in the name field in the response in list users admin API.
     DISPLAYNAME = ordered alphabetically by `displayname`
+    GUEST = ordered by `is_guest`
+    ADMIN = ordered by `admin`
+    DEACTIVATED = ordered by `deactivated`
+    USER_TYPE = ordered alphabetically by `user_type`
+    AVATAR_URL = ordered alphabetically by `avatar_url`
+    SHADOW_BANNED = ordered by `shadow_banned`
     """

     MEDIA_LENGTH = "media_length"
     MEDIA_COUNT = "media_count"
     USER_ID = "user_id"
+    NAME = "name"
     DISPLAYNAME = "displayname"
+    GUEST = "is_guest"
+    ADMIN = "admin"
+    DEACTIVATED = "deactivated"
+    USER_TYPE = "user_type"
+    AVATAR_URL = "avatar_url"
+    SHADOW_BANNED = "shadow_banned"


 class StatsStore(StateDeltasStore):
synapse/storage/prepare_database.py

@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import imp
+import importlib.util
 import logging
 import os
 import re

@@ -454,8 +454,13 @@ def _upgrade_existing_database(
             )

             module_name = "synapse.storage.v%d_%s" % (v, root_name)
-            with open(absolute_path) as python_file:
-                module = imp.load_source(module_name, absolute_path, python_file)  # type: ignore
+
+            spec = importlib.util.spec_from_file_location(
+                module_name, absolute_path
+            )
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)  # type: ignore

             logger.info("Running script %s", relative_path)
             module.run_create(cur, database_engine)  # type: ignore
             if not is_empty:
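The hunk above replaces the long-deprecated `imp.load_source` with the stdlib-recommended `importlib.util` dance. Shown standalone, with placeholder path and module name:

```python
import importlib.util

def load_module_from_path(module_name: str, path: str):
    # Build a module spec from an arbitrary file path, materialise a module
    # object from it, then execute the file's top-level code in that module.
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError("could not load %s from %s" % (module_name, path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# e.g. module = load_module_from_path("delta_58_foo", "/path/to/58foo.py")
```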
synapse/util/caches/deferred_cache.py

@@ -283,7 +283,9 @@ class DeferredCache(Generic[KT, VT]):
         # we return a new Deferred which will be called before any subsequent observers.
         return observable.observe()

-    def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
+    def prefill(
+        self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
+    ):
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
synapse/util/caches/expiringcache.py

@@ -15,40 +15,50 @@

 import logging
 from collections import OrderedDict
+from typing import Any, Generic, Optional, TypeVar, Union, overload
+
+import attr
+from typing_extensions import Literal

 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
 from synapse.util.caches import register_cache

 logger = logging.getLogger(__name__)


-SENTINEL = object()
+SENTINEL = object()  # type: Any


-class ExpiringCache:
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+
+class ExpiringCache(Generic[KT, VT]):
     def __init__(
         self,
-        cache_name,
-        clock,
-        max_len=0,
-        expiry_ms=0,
-        reset_expiry_on_get=False,
-        iterable=False,
+        cache_name: str,
+        clock: Clock,
+        max_len: int = 0,
+        expiry_ms: int = 0,
+        reset_expiry_on_get: bool = False,
+        iterable: bool = False,
     ):
         """
         Args:
-            cache_name (str): Name of this cache, used for logging.
-            clock (Clock)
-            max_len (int): Max size of dict. If the dict grows larger than this
+            cache_name: Name of this cache, used for logging.
+            clock
+            max_len: Max size of dict. If the dict grows larger than this
                 then the oldest items get automatically evicted. Default is 0,
                 which indicates there is no max limit.
-            expiry_ms (int): How long before an item is evicted from the cache
+            expiry_ms: How long before an item is evicted from the cache
                 in milliseconds. Default is 0, indicating items never get
                 evicted based on time.
-            reset_expiry_on_get (bool): If true, will reset the expiry time for
+            reset_expiry_on_get: If true, will reset the expiry time for
                 an item on access. Defaults to False.
-            iterable (bool): If true, the size is calculated by summing the
+            iterable: If true, the size is calculated by summing the
                 sizes of all entries, rather than the number of entries.
         """
         self._cache_name = cache_name

@@ -62,7 +72,7 @@ class ExpiringCache:
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get

-        self._cache = OrderedDict()
+        self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]

         self.iterable = iterable

@@ -79,12 +89,12 @@ class ExpiringCache:

         self._clock.looping_call(f, self._expiry_ms / 2)

-    def __setitem__(self, key, value):
+    def __setitem__(self, key: KT, value: VT) -> None:
         now = self._clock.time_msec()
         self._cache[key] = _CacheEntry(now, value)
         self.evict()

-    def evict(self):
+    def evict(self) -> None:
         # Evict if there are now too many items
         while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)

@@ -93,7 +103,7 @@ class ExpiringCache:
         else:
             self.metrics.inc_evictions()

-    def __getitem__(self, key):
+    def __getitem__(self, key: KT) -> VT:
         try:
             entry = self._cache[key]
             self.metrics.inc_hits()

@@ -106,7 +116,7 @@ class ExpiringCache:

         return entry.value

-    def pop(self, key, default=SENTINEL):
+    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
         """Removes and returns the value with the given key from the cache.

         If the key isn't in the cache then `default` will be returned if

@@ -115,29 +125,40 @@ class ExpiringCache:
         Identical functionality to `dict.pop(..)`.
         """

-        value = self._cache.pop(key, default)
+        value = self._cache.pop(key, SENTINEL)
+        # The key was not found.
         if value is SENTINEL:
-            raise KeyError(key)
+            if default is SENTINEL:
+                raise KeyError(key)
+            return default

-        return value
+        return value.value

-    def __contains__(self, key):
+    def __contains__(self, key: KT) -> bool:
         return key in self._cache

-    def get(self, key, default=None):
+    @overload
+    def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+        ...
+
+    @overload
+    def get(self, key: KT, default: T) -> Union[VT, T]:
+        ...
+
+    def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
         try:
             return self[key]
         except KeyError:
             return default

-    def setdefault(self, key, value):
+    def setdefault(self, key: KT, value: VT) -> VT:
         try:
             return self[key]
         except KeyError:
             self[key] = value
             return value

-    def _prune_cache(self):
+    def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.

@@ -166,7 +187,7 @@ class ExpiringCache:
             len(self),
         )

-    def __len__(self):
+    def __len__(self) -> int:
         if self.iterable:
             return sum(len(entry.value) for entry in self._cache.values())
         else:

@@ -190,9 +211,7 @@ class ExpiringCache:
         return False


+@attr.s(slots=True)
 class _CacheEntry:
-    __slots__ = ["time", "value"]
-
-    def __init__(self, time, value):
-        self.time = time
-        self.value = value
+    time = attr.ib(type=int)
+    value = attr.ib()
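What the `Generic[KT, VT]` change buys call sites: caches can now be annotated with concrete key/value types, and mypy will flag misuse, as the `# type: ExpiringCache[...]` comments added in earlier hunks do. A toy stand-in for `ExpiringCache` (no expiry logic) showing the same annotation style:

```python
from collections import OrderedDict
from typing import Generic, TypeVar

KT = TypeVar("KT")
VT = TypeVar("VT")

class TypedCache(Generic[KT, VT]):
    def __init__(self) -> None:
        self._cache = OrderedDict()  # type: OrderedDict[KT, VT]

    def __setitem__(self, key: KT, value: VT) -> None:
        self._cache[key] = value

    def __getitem__(self, key: KT) -> VT:
        return self._cache[key]

cache = TypedCache()  # type: TypedCache[str, int]
cache["answer"] = 42
# cache["answer"] = "not an int"  # mypy: incompatible types
```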
tests/api/test_ratelimiting.py

@@ -5,38 +5,25 @@ from synapse.types import create_requester
 from tests import unittest


-class TestRatelimiter(unittest.TestCase):
+class TestRatelimiter(unittest.HomeserverTestCase):
     def test_allowed_via_can_do_action(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
-        self.assertTrue(allowed)
-        self.assertEquals(10.0, time_allowed)
-
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
-        self.assertFalse(allowed)
-        self.assertEquals(10.0, time_allowed)
-
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
-        self.assertTrue(allowed)
-        self.assertEquals(20.0, time_allowed)
-
-    def test_allowed_user_via_can_requester_do_action(self):
-        user_requester = create_requester("@user:example.com")
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(10.0, time_allowed)

-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=5)
         )
         self.assertFalse(allowed)
         self.assertEquals(10.0, time_allowed)

-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=10)
        )
        self.assertTrue(allowed)
        self.assertEquals(20.0, time_allowed)

@@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)

-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(10.0, time_allowed)

-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertFalse(allowed)
         self.assertEquals(10.0, time_allowed)

-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)

@@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)

-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)

-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)

-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)

     def test_allowed_via_ratelimit(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )

         # Shouldn't raise
-        limiter.ratelimit(key="test_id", _time_now_s=0)
+        self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))

         # Should raise
         with self.assertRaises(LimitExceededError) as context:
-            limiter.ratelimit(key="test_id", _time_now_s=5)
+            self.get_success_or_raise(
+                limiter.ratelimit(None, key="test_id", _time_now_s=5)
+            )
         self.assertEqual(context.exception.retry_after_ms, 5000)

         # Shouldn't raise
-        limiter.ratelimit(key="test_id", _time_now_s=10)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key="test_id", _time_now_s=10)
+        )

     def test_allowed_via_can_do_action_and_overriding_parameters(self):
         """Test that we can override options of can_do_action that would otherwise fail
         an action
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )

         # First attempt should be allowed
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",),
-            _time_now_s=0,
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(
+                None,
+                ("test_id",),
+                _time_now_s=0,
+            )
         )
         self.assertTrue(allowed)
         self.assertEqual(10.0, time_allowed)

         # Second attempt, 1s later, will fail
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",),
-            _time_now_s=1,
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(
+                None,
+                ("test_id",),
+                _time_now_s=1,
+            )
         )
         self.assertFalse(allowed)
         self.assertEqual(10.0, time_allowed)

         # But, if we allow 10 actions/sec for this request, we should be allowed
         # to continue.
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",), _time_now_s=1, rate_hz=10.0
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
         )
         self.assertTrue(allowed)
         self.assertEqual(1.1, time_allowed)

         # Similarly if we allow a burst of 10 actions
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",), _time_now_s=1, burst_count=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
         )
         self.assertTrue(allowed)
         self.assertEqual(1.0, time_allowed)

@@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
         fail an action
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )

         # First attempt should be allowed
-        limiter.ratelimit(key=("test_id",), _time_now_s=0)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
+        )

         # Second attempt, 1s later, will fail
         with self.assertRaises(LimitExceededError) as context:
|
with self.assertRaises(LimitExceededError) as context:
|
||||||
limiter.ratelimit(key=("test_id",), _time_now_s=1)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
|
||||||
|
)
|
||||||
self.assertEqual(context.exception.retry_after_ms, 9000)
|
self.assertEqual(context.exception.retry_after_ms, 9000)
|
||||||
|
|
||||||
# But, if we allow 10 actions/sec for this request, we should be allowed
|
# But, if we allow 10 actions/sec for this request, we should be allowed
|
||||||
# to continue.
|
# to continue.
|
||||||
limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
|
||||||
|
)
|
||||||
|
|
||||||
# Similarly if we allow a burst of 10 actions
|
# Similarly if we allow a burst of 10 actions
|
||||||
limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
|
self.get_success_or_raise(
|
||||||
|
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
|
||||||
|
)
|
||||||
|
|
||||||
def test_pruning(self):
|
def test_pruning(self):
|
||||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
limiter = Ratelimiter(
|
||||||
limiter.can_do_action(key="test_id_1", _time_now_s=0)
|
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
|
||||||
|
)
|
||||||
|
self.get_success_or_raise(
|
||||||
|
limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertIn("test_id_1", limiter.actions)
|
self.assertIn("test_id_1", limiter.actions)
|
||||||
|
|
||||||
limiter.can_do_action(key="test_id_2", _time_now_s=10)
|
self.get_success_or_raise(
|
||||||
|
limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertNotIn("test_id_1", limiter.actions)
|
self.assertNotIn("test_id_1", limiter.actions)
|
||||||
|
|
||||||
|
def test_db_user_override(self):
|
||||||
|
"""Test that users that have ratelimiting disabled in the DB aren't
|
||||||
|
ratelimited.
|
||||||
|
"""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
|
||||||
|
user_id = "@user:test"
|
||||||
|
requester = create_requester(user_id)
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
store.db_pool.simple_insert(
|
||||||
|
table="ratelimit_override",
|
||||||
|
values={
|
||||||
|
"user_id": user_id,
|
||||||
|
"messages_per_second": None,
|
||||||
|
"burst_count": None,
|
||||||
|
},
|
||||||
|
desc="test_db_user_override",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
|
||||||
|
|
||||||
|
# Shouldn't raise
|
||||||
|
for _ in range(20):
|
||||||
|
self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
|
||||||
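The changes above track the new `Ratelimiter` API: the constructor now takes a `store` (so per-user overrides can be read from the database) and `can_do_action`/`ratelimit` are now coroutines, hence the `self.get_success_or_raise(...)` wrappers throughout the tests. A minimal sketch of the calling convention in non-test code; the surrounding handler and its attribute names are hypothetical:

```python
# Sketch only: `limiter` and `_do_send` are hypothetical attributes of some
# handler; `requester` is a synapse.types.Requester, as in the tests above.
async def send_message(self, requester, content):
    # Raises LimitExceededError when the token bucket is exhausted, unless a
    # row in the ratelimit_override table disables limits for this user.
    await self.limiter.ratelimit(requester)
    return await self._do_send(requester, content)
```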
@@ -20,6 +20,7 @@ from io import StringIO

 import yaml

+from synapse.config import ConfigError
 from synapse.config.homeserver import HomeServerConfig

 from tests import unittest

@@ -35,9 +36,9 @@ class ConfigLoadingTestCase(unittest.TestCase):

     def test_load_fails_if_server_name_missing(self):
         self.generate_config_and_remove_lines_containing("server_name")
-        with self.assertRaises(Exception):
+        with self.assertRaises(ConfigError):
             HomeServerConfig.load_config("", ["-c", self.file])
-        with self.assertRaises(Exception):
+        with self.assertRaises(ConfigError):
             HomeServerConfig.load_or_generate_config("", ["-c", self.file])

     def test_generates_and_loads_macaroon_secret_key(self):
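Asserting the specific `ConfigError` rather than a bare `Exception` keeps the test from passing on unrelated failures. The pattern, as a short sketch using the names from the diff:

```python
from synapse.config import ConfigError

# Fails the test unless config loading raises ConfigError specifically.
with self.assertRaises(ConfigError) as ctx:
    HomeServerConfig.load_config("", ["-c", self.file])
# ctx.exception is available for further checks on the message if needed.
```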
@@ -16,6 +16,7 @@ import time

 from mock import Mock

+import attr
 import canonicaljson
 import signedjson.key
 import signedjson.sign

@@ -68,6 +69,11 @@ class MockPerspectiveServer:
         signedjson.sign.sign_json(res, self.server_name, self.key)


+@attr.s(slots=True)
+class FakeRequest:
+    id = attr.ib()
+
+
 @logcontext_clean
 class KeyringTestCase(unittest.HomeserverTestCase):
     def check_context(self, val, expected):

@@ -89,7 +95,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         first_lookup_deferred = Deferred()

         async def first_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request, "context_11")
+            self.assertEquals(current_context().request.id, "context_11")
             self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})

             await make_deferred_yieldable(first_lookup_deferred)

@@ -102,9 +108,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         mock_fetcher.get_keys.side_effect = first_lookup_fetch

         async def first_lookup():
-            with LoggingContext("context_11") as context_11:
-                context_11.request = "context_11"
+            with LoggingContext("context_11", request=FakeRequest("context_11")):

                 res_deferreds = kr.verify_json_objects_for_server(
                     [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
                 )

@@ -130,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # should block rather than start a second call

         async def second_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request, "context_12")
+            self.assertEquals(current_context().request.id, "context_12")
             return {
                 "server10": {
                     get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)

@@ -142,9 +146,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         second_lookup_state = [0]

         async def second_lookup():
-            with LoggingContext("context_12") as context_12:
-                context_12.request = "context_12"
+            with LoggingContext("context_12", request=FakeRequest("context_12")):

                 res_deferreds_2 = kr.verify_json_objects_for_server(
                     [("server10", json1, 0, "test")]
                 )

@@ -589,10 +591,7 @@ def get_key_id(key):

 @defer.inlineCallbacks
 def run_in_context(f, *args, **kwargs):
-    with LoggingContext("testctx") as ctx:
-        # we set the "request" prop to make it easier to follow what's going on in the
-        # logs.
-        ctx.request = "testctx"
+    with LoggingContext("testctx"):
         rv = yield f(*args, **kwargs)
     return rv

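With this change a logging context's `request` is expected to be an object carrying an `id` attribute rather than a plain string, which is what the new `FakeRequest` attrs class provides. A minimal sketch of the pattern, assuming the `request=` keyword shown in the diff:

```python
import attr

from synapse.logging.context import LoggingContext, current_context


@attr.s(slots=True)
class FakeRequest:
    id = attr.ib()


with LoggingContext("ctx", request=FakeRequest("req-1")):
    # Log lines emitted here can be correlated via the request's id.
    assert current_context().request.id == "req-1"
```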
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+
+from mock import Mock
+
+import attr
+
+from synapse.api.constants import EduTypes
+from synapse.events.presence_router import PresenceRouter
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, presence, room
+from synapse.types import JsonDict, StreamToken, create_requester
+
+from tests.handlers.test_sync import generate_sync_config
+from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+
+
+@attr.s
+class PresenceRouterTestConfig:
+    users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
+
+
+class PresenceRouterTestModule:
+    def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
+        self._config = config
+        self._module_api = module_api
+
+    async def get_users_for_states(
+        self, state_updates: Iterable[UserPresenceState]
+    ) -> Dict[str, Set[UserPresenceState]]:
+        users_to_state = {
+            user_id: set(state_updates)
+            for user_id in self._config.users_who_should_receive_all_presence
+        }
+        return users_to_state
+
+    async def get_interested_users(
+        self, user_id: str
+    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+        if user_id in self._config.users_who_should_receive_all_presence:
+            return PresenceRouter.ALL_USERS
+
+        return set()
+
+    @staticmethod
+    def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+        """Parse a configuration dictionary from the homeserver config, do
+        some validation and return a typed PresenceRouterConfig.
+
+        Args:
+            config_dict: The configuration dictionary.
+
+        Returns:
+            A validated config object.
+        """
+        # Initialise a typed config object
+        config = PresenceRouterTestConfig()
+
+        config.users_who_should_receive_all_presence = config_dict.get(
+            "users_who_should_receive_all_presence"
+        )
+
+        return config
+
+
+class PresenceRouterTestCase(FederatingHomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        presence.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(
+            federation_transport_client=Mock(spec=["send_transaction"]),
+        )
+
+    def prepare(self, reactor, clock, homeserver):
+        self.sync_handler = self.hs.get_sync_handler()
+        self.module_api = homeserver.get_module_api()
+
+    @override_config(
+        {
+            "presence": {
+                "presence_router": {
+                    "module": __name__ + ".PresenceRouterTestModule",
+                    "config": {
+                        "users_who_should_receive_all_presence": [
+                            "@presence_gobbler:test",
+                        ]
+                    },
+                }
+            },
+            "send_federation": True,
+        }
+    )
+    def test_receiving_all_presence(self):
+        """Test that a user that does not share a room with another user can receive
+        presence for them, due to presence routing.
+        """
+        # Create a user who should receive all presence of others
+        self.presence_receiving_user_id = self.register_user(
+            "presence_gobbler", "monkey"
+        )
+        self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey")
+
+        # And two users who should not have any special routing
+        self.other_user_one_id = self.register_user("other_user_one", "monkey")
+        self.other_user_one_tok = self.login("other_user_one", "monkey")
+        self.other_user_two_id = self.register_user("other_user_two", "monkey")
+        self.other_user_two_tok = self.login("other_user_two", "monkey")
+
+        # Put the other two users in a room with each other
+        room_id = self.helper.create_room_as(
+            self.other_user_one_id, tok=self.other_user_one_tok
+        )
+
+        self.helper.invite(
+            room_id,
+            self.other_user_one_id,
+            self.other_user_two_id,
+            tok=self.other_user_one_tok,
+        )
+        self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok)
+        # User one sends some presence
+        send_presence_update(
+            self,
+            self.other_user_one_id,
+            self.other_user_one_tok,
+            "online",
+            "boop",
+        )
+
+        # Check that the presence receiving user gets user one's presence when syncing
+        presence_updates, sync_token = sync_presence(
+            self, self.presence_receiving_user_id
+        )
+        self.assertEqual(len(presence_updates), 1)
+
+        presence_update = presence_updates[0]  # type: UserPresenceState
+        self.assertEqual(presence_update.user_id, self.other_user_one_id)
+        self.assertEqual(presence_update.state, "online")
+        self.assertEqual(presence_update.status_msg, "boop")
+
+        # Have all three users send presence
+        send_presence_update(
+            self,
+            self.other_user_one_id,
+            self.other_user_one_tok,
+            "online",
+            "user_one",
+        )
+        send_presence_update(
+            self,
+            self.other_user_two_id,
+            self.other_user_two_tok,
+            "online",
+            "user_two",
+        )
+        send_presence_update(
+            self,
+            self.presence_receiving_user_id,
+            self.presence_receiving_user_tok,
+            "online",
+            "presence_gobbler",
+        )
+
+        # Check that the presence receiving user gets everyone's presence
+        presence_updates, _ = sync_presence(
+            self, self.presence_receiving_user_id, sync_token
+        )
+        self.assertEqual(len(presence_updates), 3)
+
+        # But User One should only get its own and User Two's presence
+        presence_updates, _ = sync_presence(self, self.other_user_one_id)
+        self.assertEqual(len(presence_updates), 2)
+
+        found = False
+        for update in presence_updates:
+            if update.user_id == self.other_user_two_id:
+                self.assertEqual(update.state, "online")
+                self.assertEqual(update.status_msg, "user_two")
+                found = True
+
+        self.assertTrue(found)
+
+    @override_config(
+        {
+            "presence": {
+                "presence_router": {
+                    "module": __name__ + ".PresenceRouterTestModule",
+                    "config": {
+                        "users_who_should_receive_all_presence": [
+                            "@presence_gobbler1:test",
+                            "@presence_gobbler2:test",
+                            "@far_away_person:island",
+                        ]
+                    },
+                }
+            },
+            "send_federation": True,
+        }
+    )
+    def test_send_local_online_presence_to_with_module(self):
+        """Tests that send_local_online_presence_to sends local online presence to a set
+        of specified local and remote users, with a custom PresenceRouter module enabled.
+        """
+        # Create a user who will send presence updates
+        self.other_user_id = self.register_user("other_user", "monkey")
+        self.other_user_tok = self.login("other_user", "monkey")
+
+        # And another two users that will also send out presence updates, as well as receive
+        # theirs and everyone else's
+        self.presence_receiving_user_one_id = self.register_user(
+            "presence_gobbler1", "monkey"
+        )
+        self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey")
+        self.presence_receiving_user_two_id = self.register_user(
+            "presence_gobbler2", "monkey"
+        )
+        self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey")
+
+        # Have all three users send some presence updates
+        send_presence_update(
+            self,
+            self.other_user_id,
+            self.other_user_tok,
+            "online",
+            "I'm online!",
+        )
+        send_presence_update(
+            self,
+            self.presence_receiving_user_one_id,
+            self.presence_receiving_user_one_tok,
+            "online",
+            "I'm also online!",
+        )
+        send_presence_update(
+            self,
+            self.presence_receiving_user_two_id,
+            self.presence_receiving_user_two_tok,
+            "unavailable",
+            "I'm in a meeting!",
+        )
+
+        # Mark each presence-receiving user for receiving all user presence
+        self.get_success(
+            self.module_api.send_local_online_presence_to(
+                [
+                    self.presence_receiving_user_one_id,
+                    self.presence_receiving_user_two_id,
+                ]
+            )
+        )
+
+        # Perform a sync for each user
+
+        # The other user should only receive their own presence
+        presence_updates, _ = sync_presence(self, self.other_user_id)
+        self.assertEqual(len(presence_updates), 1)
+
+        presence_update = presence_updates[0]  # type: UserPresenceState
+        self.assertEqual(presence_update.user_id, self.other_user_id)
+        self.assertEqual(presence_update.state, "online")
+        self.assertEqual(presence_update.status_msg, "I'm online!")
+
+        # Whereas both presence receiving users should receive everyone's presence updates
+        presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id)
+        self.assertEqual(len(presence_updates), 3)
+        presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id)
+        self.assertEqual(len(presence_updates), 3)
+
+        # Test that sending to a remote user works
+        remote_user_id = "@far_away_person:island"
+
+        # Note that due to the remote user being in our module's
+        # users_who_should_receive_all_presence config, they would have
+        # received user presence updates already.
+        #
+        # Thus we reset the mock, and try sending all online local user
+        # presence again
+        self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+        # Broadcast local user online presence
+        self.get_success(
+            self.module_api.send_local_online_presence_to([remote_user_id])
+        )
+
+        # Check that the expected presence updates were sent
+        expected_users = [
+            self.other_user_id,
+            self.presence_receiving_user_one_id,
+            self.presence_receiving_user_two_id,
+        ]
+
+        calls = (
+            self.hs.get_federation_transport_client().send_transaction.call_args_list
+        )
+        for call in calls:
+            federation_transaction = call.args[0]  # type: Transaction
+
+            # Get the sent EDUs in this transaction
+            edus = federation_transaction.get_dict()["edus"]
+
+            for edu in edus:
+                # Make sure we're only checking presence-type EDUs
+                if edu["edu_type"] != EduTypes.Presence:
+                    continue
+
+                # EDUs can contain multiple presence updates
+                for presence_update in edu["content"]["push"]:
+                    # Check for presence updates that contain the user IDs we're after
+                    expected_users.remove(presence_update["user_id"])
+
+                    # Ensure that no offline states are being sent out
+                    self.assertNotEqual(presence_update["presence"], "offline")
+
+        self.assertEqual(len(expected_users), 0)
+
+
+def send_presence_update(
+    testcase: TestCase,
+    user_id: str,
+    access_token: str,
+    presence_state: str,
+    status_message: Optional[str] = None,
+) -> JsonDict:
+    # Build the presence body
+    body = {"presence": presence_state}
+    if status_message:
+        body["status_msg"] = status_message
+
+    # Update the user's presence state
+    channel = testcase.make_request(
+        "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token
+    )
+    testcase.assertEqual(channel.code, 200)
+
+    return channel.json_body
+
+
+def sync_presence(
+    testcase: TestCase,
+    user_id: str,
+    since_token: Optional[StreamToken] = None,
+) -> Tuple[List[UserPresenceState], StreamToken]:
+    """Perform a sync request for the given user and return the user presence updates
+    they've received, as well as the next_batch token.
+
+    This method assumes testcase.sync_handler points to the homeserver's sync handler.
+
+    Args:
+        testcase: The testcase that is currently being run.
+        user_id: The ID of the user to generate a sync response for.
+        since_token: An optional token indicating the point to sync from.
+
+    Returns:
+        A tuple containing a list of presence updates, and the sync response's
+        next_batch token.
+    """
+    requester = create_requester(user_id)
+    sync_config = generate_sync_config(requester.user.to_string())
+    sync_result = testcase.get_success(
+        testcase.sync_handler.wait_for_sync_for_user(
+            requester, sync_config, since_token
+        )
+    )
+
+    return sync_result.presence, sync_result.next_batch
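The new test file above also serves as a working example of a custom `PresenceRouter` module. Outside of tests, the same module would be enabled through the homeserver configuration; a sketch of the equivalent settings follows, with an illustrative module path:

```python
# Expressed as the Python dict the tests pass to override_config; in a real
# deployment the same keys live in homeserver.yaml. The module path
# "my_module.ExamplePresenceRouter" is illustrative, not shipped with Synapse.
presence_router_config = {
    "presence": {
        "presence_router": {
            "module": "my_module.ExamplePresenceRouter",
            "config": {
                "users_who_should_receive_all_presence": ["@admin:example.org"],
            },
        }
    },
    "send_federation": True,
}
```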
@@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
     def test_wait_for_sync_for_user_auth_blocking(self):
         user_id1 = "@user1:test"
         user_id2 = "@user2:test"
-        sync_config = self._generate_sync_config(user_id1)
+        sync_config = generate_sync_config(user_id1)
         requester = create_requester(user_id1)

         self.reactor.advance(100)  # So we get not 0 time

@@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):

         self.auth_blocking._hs_disabled = False

-        sync_config = self._generate_sync_config(user_id2)
+        sync_config = generate_sync_config(user_id2)
         requester = create_requester(user_id2)

         e = self.get_failure(

@@ -69,7 +69,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         )
         self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

-    def _generate_sync_config(self, user_id):
+
+def generate_sync_config(user_id: str) -> SyncConfig:
     return SyncConfig(
         user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
         filter_collection=DEFAULT_FILTER_COLLECTION,
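Hoisting `_generate_sync_config` out of the test class into a module-level `generate_sync_config` lets other test modules import it; the presence-router tests above do exactly that. A minimal usage sketch:

```python
from tests.handlers.test_sync import generate_sync_config

# Build a default SyncConfig for a test user; no test-case instance needed.
sync_config = generate_sync_config("@alice:test")
```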
@@ -12,15 +12,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import json
 import logging
-from io import StringIO
+from io import BytesIO, StringIO

+from mock import Mock, patch
+
+from twisted.web.server import Request
+
+from synapse.http.site import SynapseRequest
 from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
 from synapse.logging.context import LoggingContext, LoggingContextFilter

 from tests.logging import LoggerCleanupMixin
+from tests.server import FakeChannel
 from tests.unittest import TestCase

@@ -120,7 +125,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         handler.addFilter(LoggingContextFilter())
         logger = self.get_logger(handler)

-        with LoggingContext(request="test"):
+        with LoggingContext("name"):
             logger.info("Hello there, %s!", "wally")

         log = self.get_log_line()

@@ -134,4 +139,61 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         ]
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
-        self.assertEqual(log["request"], "test")
+        self.assertTrue(log["request"].startswith("name@"))
+
+    def test_with_request_context(self):
+        """
+        Information from the logging context request should be added to the JSON response.
+        """
+        handler = logging.StreamHandler(self.output)
+        handler.setFormatter(JsonFormatter())
+        handler.addFilter(LoggingContextFilter())
+        logger = self.get_logger(handler)
+
+        # A full request isn't needed here.
+        site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
+        site.site_tag = "test-site"
+        site.server_version_string = "Server v1"
+        request = SynapseRequest(FakeChannel(site, None))
+        # Call requestReceived to finish instantiating the object.
+        request.content = BytesIO()
+        # Partially skip some of the internal processing of SynapseRequest.
+        request._started_processing = Mock()
+        request.request_metrics = Mock(spec=["name"])
+        with patch.object(Request, "render"):
+            request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
+
+        # Also set the requester to ensure the processing works.
+        request.requester = "@foo:test"
+
+        with LoggingContext(parent_context=request.logcontext):
+            logger.info("Hello there, %s!", "wally")
+
+        log = self.get_log_line()
+
+        # The terse logger includes additional request information, if possible.
+        expected_log_keys = [
+            "log",
+            "level",
+            "namespace",
+            "request",
+            "ip_address",
+            "site_tag",
+            "requester",
+            "authenticated_entity",
+            "method",
+            "url",
+            "protocol",
+            "user_agent",
+        ]
+        self.assertCountEqual(log.keys(), expected_log_keys)
+        self.assertEqual(log["log"], "Hello there, wally!")
+        self.assertTrue(log["request"].startswith("POST-"))
+        self.assertEqual(log["ip_address"], "127.0.0.1")
+        self.assertEqual(log["site_tag"], "test-site")
+        self.assertEqual(log["requester"], "@foo:test")
+        self.assertEqual(log["authenticated_entity"], "@foo:test")
+        self.assertEqual(log["method"], "POST")
+        self.assertEqual(log["url"], "/_matrix/client/versions")
+        self.assertEqual(log["protocol"], "1.1")
+        self.assertEqual(log["user_agent"], "")
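The new `test_with_request_context` builds a real `SynapseRequest` on top of mocked transport pieces so that `LoggingContextFilter` can copy request metadata onto every JSON log record. The general mechanism can be sketched as a standalone logging filter; this is an illustrative simplification, not Synapse's actual implementation:

```python
import logging


class ContextFilter(logging.Filter):
    """Copy attributes from an ambient request context onto each log record."""

    def __init__(self, get_context):
        super().__init__()
        self._get_context = get_context  # callable returning the current context

    def filter(self, record: logging.LogRecord) -> bool:
        context = self._get_context()
        # Attach request metadata so a JSON formatter can emit it as a field.
        record.request = getattr(context, "request", None)
        return True  # never drop the record, only annotate it
```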
@@ -14,25 +14,37 @@
 # limitations under the License.
 from mock import Mock

+from synapse.api.constants import EduTypes
 from synapse.events import EventBase
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
 from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v1 import login, presence, room
 from synapse.types import create_requester

-from tests.unittest import HomeserverTestCase
+from tests.events.test_presence_router import send_presence_update, sync_presence
+from tests.test_utils.event_injection import inject_member_event
+from tests.unittest import FederatingHomeserverTestCase, override_config


-class ModuleApiTestCase(HomeserverTestCase):
+class ModuleApiTestCase(FederatingHomeserverTestCase):
     servlets = [
         admin.register_servlets,
         login.register_servlets,
         room.register_servlets,
+        presence.register_servlets,
     ]

     def prepare(self, reactor, clock, homeserver):
         self.store = homeserver.get_datastore()
         self.module_api = homeserver.get_module_api()
         self.event_creation_handler = homeserver.get_event_creation_handler()
+        self.sync_handler = homeserver.get_sync_handler()
+
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(
+            federation_transport_client=Mock(spec=["send_transaction"]),
+        )

     def test_can_register_user(self):
         """Tests that an external module can register a user"""

@@ -205,3 +217,160 @@ class ModuleApiTestCase(HomeserverTestCase):
             )
         )
         self.assertFalse(is_in_public_rooms)
+
+    # The ability to send federation is required by send_local_online_presence_to.
+    @override_config({"send_federation": True})
+    def test_send_local_online_presence_to(self):
+        """Tests that send_local_online_presence_to sends local online presence to local users."""
+        # Create a user who will send presence updates
+        self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
+        self.presence_receiver_tok = self.login("presence_receiver", "monkey")
+
+        # And another user that will send presence updates out
+        self.presence_sender_id = self.register_user("presence_sender", "monkey")
+        self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+        # Put them in a room together so they will receive each other's presence updates
+        room_id = self.helper.create_room_as(
+            self.presence_receiver_id,
+            tok=self.presence_receiver_tok,
+        )
+        self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
+
+        # Presence sender comes online
+        send_presence_update(
+            self,
+            self.presence_sender_id,
+            self.presence_sender_tok,
+            "online",
+            "I'm online!",
+        )
+
+        # Presence receiver should have received it
+        presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
+        self.assertEqual(len(presence_updates), 1)
+
+        presence_update = presence_updates[0]  # type: UserPresenceState
+        self.assertEqual(presence_update.user_id, self.presence_sender_id)
+        self.assertEqual(presence_update.state, "online")
+
+        # Syncing again should result in no presence updates
+        presence_updates, sync_token = sync_presence(
+            self, self.presence_receiver_id, sync_token
+        )
+        self.assertEqual(len(presence_updates), 0)
+
+        # Trigger sending local online presence
+        self.get_success(
+            self.module_api.send_local_online_presence_to(
+                [
+                    self.presence_receiver_id,
+                ]
+            )
+        )
+
+        # Presence receiver should have received online presence again
+        presence_updates, sync_token = sync_presence(
+            self, self.presence_receiver_id, sync_token
+        )
+        self.assertEqual(len(presence_updates), 1)
+
+        presence_update = presence_updates[0]  # type: UserPresenceState
+        self.assertEqual(presence_update.user_id, self.presence_sender_id)
+        self.assertEqual(presence_update.state, "online")
+
+        # Presence sender goes offline
+        send_presence_update(
+            self,
+            self.presence_sender_id,
+            self.presence_sender_tok,
+            "offline",
+            "I slink back into the darkness.",
+        )
+
+        # Trigger sending local online presence
+        self.get_success(
+            self.module_api.send_local_online_presence_to(
+                [
+                    self.presence_receiver_id,
+                ]
+            )
+        )
+
+        # Presence receiver should *not* have received offline state
+        presence_updates, sync_token = sync_presence(
+            self, self.presence_receiver_id, sync_token
+        )
+        self.assertEqual(len(presence_updates), 0)
+
+    @override_config({"send_federation": True})
+    def test_send_local_online_presence_to_federation(self):
+        """Tests that send_local_online_presence_to sends local online presence to remote users."""
+        # Create a user who will send presence updates
+        self.presence_sender_id = self.register_user("presence_sender", "monkey")
+        self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+        # And a room they're a part of
+        room_id = self.helper.create_room_as(
+            self.presence_sender_id,
+            tok=self.presence_sender_tok,
+        )
+
+        # Mark them as online
+        send_presence_update(
+            self,
+            self.presence_sender_id,
+            self.presence_sender_tok,
+            "online",
+            "I'm online!",
+        )
+
+        # Make up a remote user to send presence to
+        remote_user_id = "@far_away_person:island"
+
+        # Create a join membership event for the remote user into the room.
+        # This allows presence information to flow from one user to the other.
+        self.get_success(
+            inject_member_event(
+                self.hs,
+                room_id,
+                sender=remote_user_id,
+                target=remote_user_id,
+                membership="join",
+            )
+        )
+
+        # The remote user would have received the existing room members' presence
+        # when they joined the room.
+        #
+        # Thus we reset the mock, and try sending online local user
+        # presence again
+        self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+        # Broadcast local user online presence
+        self.get_success(
+            self.module_api.send_local_online_presence_to([remote_user_id])
+        )
+
+        # Check that a presence update was sent as part of a federation transaction
+        found_update = False
+        calls = (
+            self.hs.get_federation_transport_client().send_transaction.call_args_list
+        )
+        for call in calls:
+            federation_transaction = call.args[0]  # type: Transaction
+
+            # Get the sent EDUs in this transaction
+            edus = federation_transaction.get_dict()["edus"]
+
+            for edu in edus:
+                # Make sure we're only checking presence-type EDUs
+                if edu["edu_type"] != EduTypes.Presence:
+                    continue
+
+                # EDUs can contain multiple presence updates
+                for presence_update in edu["content"]["push"]:
+                    if presence_update["user_id"] == self.presence_sender_id:
+                        found_update = True
+
+        self.assertTrue(found_update)
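Both presence tests above assert against the mocked federation transport by walking `send_transaction.call_args_list`. A condensed sketch of that pattern follows; the helper name is illustrative, and `EduTypes.Presence` is assumed to be the string "m.presence":

```python
# Sketch: collect the user_ids of all presence EDUs captured by the mocked
# federation transport (the Mock installed via make_homeserver above).
def presence_user_ids_sent(transport):
    user_ids = set()
    for call in transport.send_transaction.call_args_list:
        transaction = call.args[0]
        for edu in transaction.get_dict()["edus"]:
            if edu["edu_type"] == "m.presence":
                for update in edu["content"]["push"]:
                    user_ids.add(update["user_id"])
    return user_ids
```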
@@ -28,7 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID

 from tests import unittest
 from tests.server import FakeSite, make_request

@@ -467,6 +467,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
     url = "/_synapse/admin/v2/users"

     def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")

@@ -634,6 +636,26 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])

+
+        # unknown order_by
+        channel = self.make_request(
+            "GET",
+            self.url + "?order_by=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid search order
+        channel = self.make_request(
+            "GET",
+            self.url + "?dir=bar",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
     def test_limit(self):
         """
         Testing list of users with limit

@@ -759,6 +781,103 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(channel.json_body["users"]), 1)
         self.assertNotIn("next_token", channel.json_body)

+    def test_order_by(self):
+        """
+        Testing order list with parameter `order_by`
+        """
+
+        user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z")
+        user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y")
+
+        # Modify user
+        self.get_success(self.store.set_user_deactivated_status(user1, True))
+        self.get_success(self.store.set_shadow_banned(UserID.from_string(user1), True))
+
+        # Set an avatar URL for all users so that no user has a NULL value, to
+        # avoid a different sort order between SQLite and PostgreSQL
+        self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3"))
+        self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2"))
+        self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1"))
+
+        # order by default (name)
+        self._order_test([self.admin_user, user1, user2], None)
+        self._order_test([self.admin_user, user1, user2], None, "f")
+        self._order_test([user2, user1, self.admin_user], None, "b")
+
+        # order by name
+        self._order_test([self.admin_user, user1, user2], "name")
+        self._order_test([self.admin_user, user1, user2], "name", "f")
+        self._order_test([user2, user1, self.admin_user], "name", "b")
+
+        # order by displayname
+        self._order_test([user2, user1, self.admin_user], "displayname")
+        self._order_test([user2, user1, self.admin_user], "displayname", "f")
+        self._order_test([self.admin_user, user1, user2], "displayname", "b")
+
+        # order by is_guest
+        # like sort by ascending name, as no guest user here
+        self._order_test([self.admin_user, user1, user2], "is_guest")
+        self._order_test([self.admin_user, user1, user2], "is_guest", "f")
+        self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+        # order by admin
+        self._order_test([user1, user2, self.admin_user], "admin")
+        self._order_test([user1, user2, self.admin_user], "admin", "f")
+        self._order_test([self.admin_user, user1, user2], "admin", "b")
+
+        # order by deactivated
+        self._order_test([self.admin_user, user2, user1], "deactivated")
+        self._order_test([self.admin_user, user2, user1], "deactivated", "f")
+        self._order_test([user1, self.admin_user, user2], "deactivated", "b")
+
+        # order by user_type
+        # like sort by ascending name, as no special user type here
+        self._order_test([self.admin_user, user1, user2], "user_type")
+        self._order_test([self.admin_user, user1, user2], "user_type", "f")
+        self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+        # order by shadow_banned
+        self._order_test([self.admin_user, user2, user1], "shadow_banned")
+        self._order_test([self.admin_user, user2, user1], "shadow_banned", "f")
+        self._order_test([user1, self.admin_user, user2], "shadow_banned", "b")
+
+        # order by avatar_url
+        self._order_test([self.admin_user, user2, user1], "avatar_url")
+        self._order_test([self.admin_user, user2, user1], "avatar_url", "f")
+        self._order_test([user1, user2, self.admin_user], "avatar_url", "b")
+
+    def _order_test(
+        self,
+        expected_user_list: List[str],
+        order_by: Optional[str],
+        dir: Optional[str] = None,
+    ):
+        """Request the list of users in a certain order. Assert that order is what
+        we expect
+        Args:
+            expected_user_list: The list of user_id in the order we expect to get
+                back from the server
+            order_by: The type of ordering to give the server
+            dir: The direction of ordering to give the server
+        """
+
+        url = self.url + "?deactivated=true&"
+        if order_by is not None:
+            url += "order_by=%s&" % (order_by,)
+        if dir is not None and dir in ("b", "f"):
+            url += "dir=%s" % (dir,)
+        channel = self.make_request(
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], len(expected_user_list))
+
+        returned_order = [row["name"] for row in channel.json_body["users"]]
+        self.assertEqual(expected_user_list, returned_order)
+        self._check_fields(channel.json_body["users"])

     def _check_fields(self, content: JsonDict):
         """Checks that the expected user attributes are present in content
         Args:
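The new `test_order_by` exercises the `order_by` and `dir` query parameters of the admin users list endpoint. For reference, a client-side sketch of the same request; the homeserver URL and access token are placeholders:

```python
import requests  # any HTTP client works; requests is used for brevity

# Placeholders: substitute a real homeserver URL and admin access token.
resp = requests.get(
    "https://homeserver.example/_synapse/admin/v2/users",
    params={"deactivated": "true", "order_by": "displayname", "dir": "b"},
    headers={"Authorization": "Bearer <admin_access_token>"},
)
users = resp.json()["users"]  # sorted by display name, descending
```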
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Optional, Union

 from twisted.internet.defer import succeed

@@ -74,7 +74,10 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         return channel

     def recaptcha(
-        self, session: str, expected_post_response: int, post_session: str = None
+        self,
+        session: str,
+        expected_post_response: int,
+        post_session: Optional[str] = None,
     ) -> None:
         """Get and respond to a fallback recaptcha. Returns the second request."""
         if post_session is None:
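The signature change above replaces an implicit-`Optional` default (`post_session: str = None`), which stricter mypy settings reject, with an explicit `Optional[str]`. The general rule, as a sketch:

```python
from typing import Optional

# Rejected by strict checkers: the default None contradicts the `str` annotation.
def bad(post_session: str = None) -> None:
    ...

# Preferred: the annotation states explicitly that None is allowed.
def good(post_session: Optional[str] = None) -> None:
    ...
```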
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -13,32 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from twisted.internet import defer

 import synapse.api.errors

-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase


-class DeviceStoreTestCase(tests.unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.store = None  # type: synapse.storage.DataStore
-
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class DeviceStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()

-    @defer.inlineCallbacks
     def test_store_new_device(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device_id", "display_name")
         )

-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertDictContainsSubset(
             {
                 "user_id": "user_id",

@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             res,
         )

-    @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device2", "display_name 2")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id2", "device3", "display_name 3")
         )

-        res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+        res = self.get_success(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
         self.assertDictContainsSubset(
             {

@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             res["device2"],
         )

-    @defer.inlineCallbacks
     def test_count_devices_by_users(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device2", "display_name 2")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id2", "device3", "display_name 3")
         )

-        res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+        res = self.get_success(self.store.count_devices_by_users())
         self.assertEqual(0, res)

-        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+        res = self.get_success(self.store.count_devices_by_users(["unknown"]))
         self.assertEqual(0, res)

-        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+        res = self.get_success(self.store.count_devices_by_users(["user_id"]))
         self.assertEqual(2, res)

-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.count_devices_by_users(["user_id", "user_id2"])
         )
         self.assertEqual(3, res)

-    @defer.inlineCallbacks
     def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]

         # Add two device updates with a single stream_id
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )

         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield defer.ensureDeferred(
+        now_stream_id, device_updates = self.get_success(
             self.store.get_device_updates_by_remote("somehost", -1, limit=100)
         )

@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         }
         self.assertEqual(received_device_ids, set(expected_device_ids))

-    @defer.inlineCallbacks
     def test_update_device(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device_id", "display_name 1")
         )

-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])

         # do a no-op first
-        yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        self.get_success(self.store.update_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])

         # do the update
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.update_device(
                 "user_id", "device_id", new_display_name="display_name 2"
             )
         )

         # check it worked
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 2", res["display_name"])

-    @defer.inlineCallbacks
     def test_update_unknown_device(self):
-        with self.assertRaises(synapse.api.errors.StoreError) as cm:
-            yield defer.ensureDeferred(
+        exc = self.get_failure(
|
|
||||||
self.store.update_device(
|
self.store.update_device(
|
||||||
"user_id", "unknown_device_id", new_display_name="display_name 2"
|
"user_id", "unknown_device_id", new_display_name="display_name 2"
|
||||||
|
),
|
||||||
|
synapse.api.errors.StoreError,
|
||||||
)
|
)
|
||||||
)
|
self.assertEqual(404, exc.value.code)
|
||||||
self.assertEqual(404, cm.exception.code)
|
|
||||||
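The pattern above, distilled: under HomeserverTestCase, storage coroutines are driven to completion synchronously with get_success(), and expected failures are asserted with get_failure(), which returns a Failure whose .value is the raised exception. A minimal sketch of the converted style follows, assuming Synapse's test harness; the class and test names are illustrative, not part of the diff:

from synapse.api.errors import StoreError

from tests.unittest import HomeserverTestCase


class ExampleDeviceStoreTestCase(HomeserverTestCase):
    def prepare(self, reactor, clock, hs):
        # prepare() replaces the old @defer.inlineCallbacks setUp(); the
        # test homeserver is built by the harness and passed in as `hs`.
        self.store = hs.get_datastore()

    def test_success_path(self):
        # get_success() runs the coroutine on the test reactor and
        # returns its result, failing the test if it errors.
        self.get_success(
            self.store.store_device("user_id", "device_id", "display_name")
        )
        res = self.get_success(self.store.get_device("user_id", "device_id"))
        self.assertEqual("display_name", res["display_name"])

    def test_failure_path(self):
        # get_failure() asserts the awaitable fails with the given
        # exception type and returns the Failure for inspection.
        exc = self.get_failure(
            self.store.update_device(
                "user_id", "unknown_device_id", new_display_name="new name"
            ),
            StoreError,
        )
        self.assertEqual(404, exc.value.code)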
tests/storage/test_directory.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,28 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.types import RoomAlias, RoomID
 
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
 
 
-class DirectoryStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
-
+class DirectoryStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")
 
-    @defer.inlineCallbacks
     def test_room_to_alias(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
@@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             ["#my-room:test"],
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_aliases_for_room(self.room.to_string())
-                )
-            ),
+            (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
         )
 
-    @defer.inlineCallbacks
     def test_alias_to_room(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
@@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
         self.assertObjectHasAttributes(
             {"room_id": self.room.to_string(), "servers": ["test"]},
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_association_from_room_alias(self.alias)
-                )
-            ),
+            (self.get_success(self.store.get_association_from_room_alias(self.alias))),
         )
 
-    @defer.inlineCallbacks
     def test_delete_alias(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
         )
 
-        room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
+        room_id = self.get_success(self.store.delete_room_alias(self.alias))
        self.assertEqual(self.room.to_string(), room_id)
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_association_from_room_alias(self.alias)
-                )
-            )
+            (self.get_success(self.store.get_association_from_room_alias(self.alias)))
         )
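For contrast, a sketch of the pre-conversion style these hunks remove, condensed from the removed lines: each test was a Twisted inlineCallbacks generator that built its own homeserver and wrapped every coroutine in defer.ensureDeferred() so it could be yielded as a Deferred. The class name is illustrative.

from twisted.internet import defer

from synapse.types import RoomAlias, RoomID

from tests import unittest
from tests.utils import setup_test_homeserver


class OldStyleDirectoryStoreTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def setUp(self):
        # The homeserver was constructed by hand inside an
        # inlineCallbacks setUp and yielded on.
        hs = yield setup_test_homeserver(self.addCleanup)
        self.store = hs.get_datastore()
        self.room = RoomID.from_string("!abcde:test")
        self.alias = RoomAlias.from_string("#my-room:test")

    @defer.inlineCallbacks
    def test_delete_alias(self):
        # Coroutines had to be wrapped in defer.ensureDeferred() before
        # the generator could yield on them.
        yield defer.ensureDeferred(
            self.store.create_room_alias_association(
                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
            )
        )
        room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
        self.assertEqual(self.room.to_string(), room_id)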
Some files were not shown because too many files have changed in this diff.