Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

michaelkaye/remove_warning
Erik Johnston 2021-01-26 14:15:26 +00:00
commit 512e313f18
77 changed files with 1809 additions and 442 deletions

View File

@ -10,4 +10,7 @@ apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev x
export LANG="C.UTF-8" export LANG="C.UTF-8"
# Prevent virtualenv from auto-updating pip to an incompatible version
export VIRTUALENV_NO_DOWNLOAD=1
exec tox -e py35-old,combine exec tox -e py35-old,combine

View File

@ -1,3 +1,20 @@
Synapse 1.26.0rc2 (2021-01-25)
==============================
Bugfixes
--------
- Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
- Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))
Internal Changes
----------------
- Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
- Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))
Synapse 1.26.0rc1 (2021-01-20) Synapse 1.26.0rc1 (2021-01-20)
============================== ==============================

1
changelog.d/9045.misc Normal file
View File

@ -0,0 +1 @@
Add tests to `test_user.UsersListTestCase` for List Users Admin API.

1
changelog.d/9062.feature Normal file
View File

@ -0,0 +1 @@
Add admin API for getting and deleting forward extremities for a room.

1
changelog.d/9121.bugfix Normal file
View File

@ -0,0 +1 @@
Fix spurious errors in logs when deleting a non-existent pusher.

1
changelog.d/9129.misc Normal file
View File

@ -0,0 +1 @@
Various improvements to the federation client.

1
changelog.d/9135.doc Normal file
View File

@ -0,0 +1 @@
Add link to Matrix VoIP tester for turn-howto.

1
changelog.d/9163.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled).

1
changelog.d/9164.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding.

1
changelog.d/9165.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where invalid data could cause errors when calculating the presentable room name for push.

1
changelog.d/9176.misc Normal file
View File

@ -0,0 +1 @@
Speed up chain cover calculation when persisting a batch of state events at once.

1
changelog.d/9180.misc Normal file
View File

@ -0,0 +1 @@
Add a `long_description_type` to the package metadata.

1
changelog.d/9181.misc Normal file
View File

@ -0,0 +1 @@
Speed up batch insertion when using PostgreSQL.

1
changelog.d/9184.misc Normal file
View File

@ -0,0 +1 @@
Emit an error at startup if different Identity Providers are configured with the same `idp_id`.

1
changelog.d/9188.misc Normal file
View File

@ -0,0 +1 @@
Speed up batch insertion when using PostgreSQL.

View File

@ -1 +0,0 @@
Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration.

1
changelog.d/9190.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of concurrent use of `StreamIDGenerators`.

1
changelog.d/9191.misc Normal file
View File

@ -0,0 +1 @@
Add some missing source directories to the automatic linting script.

View File

@ -1 +0,0 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

View File

@ -1 +0,0 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

1
changelog.d/9198.misc Normal file
View File

@ -0,0 +1 @@
Precompute joined hosts and store in Redis.

View File

@ -1 +0,0 @@
Bump minimum `psycopg2` version.

1
changelog.d/9209.feature Normal file
View File

@ -0,0 +1 @@
Add an admin API endpoint for shadow-banning users.

View File

@ -1 +0,0 @@
Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1.

1
changelog.d/9217.misc Normal file
View File

@ -0,0 +1 @@
Fix the Python 3.5 old dependencies build.

1
changelog.d/9218.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.

1
changelog.d/9222.misc Normal file
View File

@ -0,0 +1 @@
Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting.

View File

@ -9,6 +9,7 @@
* [Response](#response) * [Response](#response)
* [Undoing room shutdowns](#undoing-room-shutdowns) * [Undoing room shutdowns](#undoing-room-shutdowns)
- [Make Room Admin API](#make-room-admin-api) - [Make Room Admin API](#make-room-admin-api)
- [Forward Extremities Admin API](#forward-extremities-admin-api)
# List Room API # List Room API
@ -511,3 +512,55 @@ optionally be specified, e.g.:
"user_id": "@foo:example.com" "user_id": "@foo:example.com"
} }
``` ```
# Forward Extremities Admin API
Enables querying and deleting forward extremities from rooms. When a lot of forward
extremities accumulate in a room, performance can degrade. For details, see
[#1760](https://github.com/matrix-org/synapse/issues/1760).
## Check for forward extremities
To check the status of forward extremities for a room:
```
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
```
A response like the following will be returned:
```json
{
"count": 1,
"results": [
{
"event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh",
"state_group": 439,
"depth": 123,
"received_ts": 1611263016761
}
]
}
```
## Deleting forward extremities
**WARNING**: Please ensure you know what you're doing and have read
the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
Under no circumstances should this API be executed as an automated maintenance task!
If a room has lots of forward extremities, the excess can be
deleted as follows:
```
DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
```
A response like the following will be returned, indicating the number of forward
extremities that were deleted:
```json
{
"deleted": 1
}
```
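For illustration only, here is a minimal sketch of driving both endpoints from Python with the `requests` library; the homeserver URL, admin access token and room ID below are placeholders, not values from this changeset:

```python
# Minimal sketch: query and then clear a room's forward extremities via the
# admin API. All identifiers below are placeholders.
import requests

HOMESERVER = "https://matrix.example.com"
ADMIN_TOKEN = "<admin access token>"
ROOM_ID = "!abcdefg:example.com"

headers = {"Authorization": f"Bearer {ADMIN_TOKEN}"}
url = f"{HOMESERVER}/_synapse/admin/v1/rooms/{ROOM_ID}/forward_extremities"

# Check how many forward extremities the room currently has.
resp = requests.get(url, headers=headers)
resp.raise_for_status()
print("forward extremities:", resp.json()["count"])

# Delete the excess extremities. Only do this manually, after reading #1760.
resp = requests.delete(url, headers=headers)
resp.raise_for_status()
print("deleted:", resp.json()["deleted"])
```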

View File

@ -760,3 +760,33 @@ The following fields are returned in the JSON response body:
- ``total`` - integer - Number of pushers. - ``total`` - integer - Number of pushers.
See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_ See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
Shadow-banning users
====================
Shadow-banning is a useful tool for moderating malicious or egregiously abusive users.
A shadow-banned user receives successful responses to their client-server API requests,
but the events are not propagated into rooms. This can be an effective tool as it
(hopefully) takes the user longer to realise they are being moderated before they
pivot to another account.
Shadow-banning a user should be used as a tool of last resort and may lead to confusing
or broken behaviour for the client. A shadow-banned user will not receive any
notification and it is generally more appropriate to ban or kick abusive users.
A shadow-banned user will be unable to contact anyone on the server.
The API is::
POST /_synapse/admin/v1/users/<user_id>/shadow_ban
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
An empty JSON dict is returned.
**Parameters**
The following parameters should be set in the URL:
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
be local.
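As an illustration only (not part of this changeset), the call can be made from
Python with the ``requests`` library; the homeserver URL, admin token and target
user below are placeholders::

    import requests
    from urllib.parse import quote

    HOMESERVER = "https://matrix.example.com"   # placeholder
    ADMIN_TOKEN = "<admin access token>"        # placeholder
    USER_ID = "@spammer:example.com"            # placeholder; must be a local user

    resp = requests.post(
        HOMESERVER + "/_synapse/admin/v1/users/" + quote(USER_ID) + "/shadow_ban",
        headers={"Authorization": "Bearer " + ADMIN_TOKEN},
        json={},
    )
    resp.raise_for_status()  # 200 with an empty JSON dict indicates success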

View File

@ -232,6 +232,12 @@ Here are a few things to try:
(Understanding the output is beyond the scope of this document!) (Understanding the output is beyond the scope of this document!)
* You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/.
Note that this test is not fully reliable yet, so don't be discouraged if
the test fails.
The source of the tester lives in [this GitHub repository](https://github.com/matrix-org/voip-tester),
where you can file bug reports.
* There is a WebRTC test tool at * There is a WebRTC test tool at
https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
use it, you will need a username/password for your TURN server. You can use it, you will need a username/password for your TURN server. You can

View File

@ -80,7 +80,8 @@ else
# then lint everything! # then lint everything!
if [[ -z ${files+x} ]]; then if [[ -z ${files+x} ]]; then
# Lint all source code files and directories # Lint all source code files and directories
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark") # Note: this list aims to mirror the one in tox.ini
files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
fi fi
fi fi

View File

@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
# #
# We pin black so that our tests don't start failing on new releases. # We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [ CONDITIONAL_REQUIREMENTS["lint"] = [
"isort==5.0.3", "isort==5.7.0",
"black==19.10b0", "black==19.10b0",
"flake8-comprehensions", "flake8-comprehensions",
"flake8", "flake8",
@ -121,6 +121,7 @@ setup(
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
long_description=long_description, long_description=long_description,
long_description_content_type="text/x-rst",
python_requires="~=3.5", python_requires="~=3.5",
classifiers=[ classifiers=[
"Development Status :: 5 - Production/Stable", "Development Status :: 5 - Production/Stable",

View File

@ -15,13 +15,23 @@
"""Contains *incomplete* type hints for txredisapi. """Contains *incomplete* type hints for txredisapi.
""" """
from typing import Any, List, Optional, Type, Union
from typing import List, Optional, Type, Union
class RedisProtocol: class RedisProtocol:
def publish(self, channel: str, message: bytes): ... def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
self,
key: str,
value: Any,
expire: Optional[int] = None,
pexpire: Optional[int] = None,
only_if_not_exists: bool = False,
only_if_exists: bool = False,
) -> None: ...
async def get(self, key: str) -> Any: ...
class SubscriberProtocol: class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ... def __init__(self, *args, **kwargs): ...
password: Optional[str] password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ... def subscribe(self, channels: Union[str, List[str]]): ...
@ -40,14 +50,13 @@ def lazyConnection(
convertNumbers: bool = ..., convertNumbers: bool = ...,
) -> RedisProtocol: ... ) -> RedisProtocol: ...
class SubscriberFactory:
def buildProtocol(self, addr): ...
class ConnectionHandler: ... class ConnectionHandler: ...
class RedisFactory: class RedisFactory:
continueTrying: bool continueTrying: bool
handler: RedisProtocol handler: RedisProtocol
pool: List[RedisProtocol]
replyTimeout: Optional[int]
def __init__( def __init__(
self, self,
uuid: str, uuid: str,
@ -60,3 +69,7 @@ class RedisFactory:
replyTimeout: Optional[int] = None, replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True, convertNumbers: Optional[int] = True,
): ... ): ...
def buildProtocol(self, addr) -> RedisProtocol: ...
class SubscriberFactory(RedisFactory):
def __init__(self): ...

View File

@ -48,7 +48,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.26.0rc1" __version__ = "1.26.0rc2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View File

@ -18,6 +18,7 @@ from synapse.config import (
password_auth_providers, password_auth_providers,
push, push,
ratelimiting, ratelimiting,
redis,
registration, registration,
repository, repository,
room_directory, room_directory,
@ -79,6 +80,7 @@ class RootConfig:
roomdirectory: room_directory.RoomDirectoryConfig roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig tracer: tracer.TracerConfig
redis: redis.RedisConfig
config_classes: List = ... config_classes: List = ...
def __init__(self) -> None: ... def __init__(self) -> None: ...

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import string import string
from collections import Counter
from typing import Iterable, Optional, Tuple, Type from typing import Iterable, Optional, Tuple, Type
import attr import attr
@ -43,6 +44,16 @@ class OIDCConfig(Config):
except DependencyException as e: except DependencyException as e:
raise ConfigError(e.message) from e raise ConfigError(e.message) from e
# check we don't have any duplicate idp_ids now. (The SSO handler will also
# check for duplicates when the REST listeners get registered, but that happens
# after synapse has forked so doesn't give nice errors.)
c = Counter([i.idp_id for i in self.oidc_providers])
for idp_id, count in c.items():
if count > 1:
raise ConfigError(
"Multiple OIDC providers have the idp_id %r." % idp_id
)
public_baseurl = self.public_baseurl public_baseurl = self.public_baseurl
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"

View File

@ -18,6 +18,7 @@ import copy
import itertools import itertools
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
@ -26,7 +27,6 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence,
Tuple, Tuple,
TypeVar, TypeVar,
Union, Union,
@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.pdu_destination_tried = {} self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
@ -116,33 +119,32 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict self.pdu_destination_tried[event_id] = destination_dict
@log_function @log_function
def make_query( async def make_query(
self, self,
destination, destination: str,
query_type, query_type: str,
args, args: dict,
retry_on_dns_fail=False, retry_on_dns_fail: bool = False,
ignore_backoff=False, ignore_backoff: bool = False,
): ) -> JsonDict:
"""Sends a federation Query to a remote homeserver of the given type """Sends a federation Query to a remote homeserver of the given type
and arguments. and arguments.
Args: Args:
destination (str): Domain name of the remote homeserver destination: Domain name of the remote homeserver
query_type (str): Category of the query type; should match the query_type: Category of the query type; should match the
handler name used in register_query_handler(). handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details args: Mapping of strings to strings containing the details
of the query request. of the query request.
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff: true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
Returns: Returns:
a Awaitable which will eventually yield a JSON object from the The JSON object from the response
response
""" """
sent_queries_counter.labels(query_type).inc() sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query( return await self.transport_layer.make_query(
destination, destination,
query_type, query_type,
args, args,
@ -151,42 +153,52 @@ class FederationClient(FederationBase):
) )
@log_function @log_function
def query_client_keys(self, destination, content, timeout): async def query_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Query device keys for a device hosted on a remote server. """Query device keys for a device hosted on a remote server.
Args: Args:
destination (str): Domain name of the remote homeserver destination: Domain name of the remote homeserver
content (dict): The query content. content: The query content.
Returns: Returns:
an Awaitable which will eventually yield a JSON object from the The JSON object from the response
response
""" """
sent_queries_counter.labels("client_device_keys").inc() sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(destination, content, timeout) return await self.transport_layer.query_client_keys(
destination, content, timeout
)
@log_function @log_function
def query_user_devices(self, destination, user_id, timeout=30000): async def query_user_devices(
self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote """Query the device keys for a list of user ids hosted on a remote
server. server.
""" """
sent_queries_counter.labels("user_devices").inc() sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(destination, user_id, timeout) return await self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@log_function @log_function
def claim_client_keys(self, destination, content, timeout): async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.
Args: Args:
destination (str): Domain name of the remote homeserver destination: Domain name of the remote homeserver
content (dict): The query content. content: The query content.
Returns: Returns:
an Awaitable which will eventually yield a JSON object from the The JSON object from the response
response
""" """
sent_queries_counter.labels("client_one_time_keys").inc() sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(destination, content, timeout) return await self.transport_layer.claim_client_keys(
destination, content, timeout
)
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str] self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@ -195,10 +207,10 @@ class FederationClient(FederationBase):
given destination server. given destination server.
Args: Args:
dest (str): The remote homeserver to ask. dest: The remote homeserver to ask.
room_id (str): The room_id to backfill. room_id: The room_id to backfill.
limit (int): The maximum number of events to return. limit: The maximum number of events to return.
extremities (list): our current backwards extremities, to backfill from extremities: our current backwards extremities, to backfill from
""" """
logger.debug("backfill extrem=%s", extremities) logger.debug("backfill extrem=%s", extremities)
@ -370,7 +382,7 @@ class FederationClient(FederationBase):
for events that have failed their checks for events that have failed their checks
Returns: Returns:
Deferred : A list of PDUs that have valid signatures and hashes. A list of PDUs that have valid signatures and hashes.
""" """
deferreds = self._check_sigs_and_hashes(room_version, pdus) deferreds = self._check_sigs_and_hashes(room_version, pdus)
@ -418,7 +430,9 @@ class FederationClient(FederationBase):
else: else:
return [p for p in valid_pdus if p] return [p for p in valid_pdus if p]
async def get_event_auth(self, destination, room_id, event_id): async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> List[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id) res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
@ -700,18 +714,16 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request) return await self._try_destination_list("send_join", destinations, send_request)
async def _do_send_join(self, destination: str, pdu: EventBase): async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = await self.transport_layer.send_join_v2( return await self.transport_layer.send_join_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@ -769,7 +781,7 @@ class FederationClient(FederationBase):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = await self.transport_layer.send_invite_v2( return await self.transport_layer.send_invite_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -779,7 +791,6 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []), "invite_room_state": pdu.unsigned.get("invite_room_state", []),
}, },
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@ -842,18 +853,16 @@ class FederationClient(FederationBase):
"send_leave", destinations, send_request "send_leave", destinations, send_request
) )
async def _do_send_leave(self, destination, pdu): async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = await self.transport_layer.send_leave_v2( return await self.transport_layer.send_leave_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@ -879,7 +888,7 @@ class FederationClient(FederationBase):
# content. # content.
return resp[1] return resp[1]
def get_public_rooms( async def get_public_rooms(
self, self,
remote_server: str, remote_server: str,
limit: Optional[int] = None, limit: Optional[int] = None,
@ -887,7 +896,7 @@ class FederationClient(FederationBase):
search_filter: Optional[Dict] = None, search_filter: Optional[Dict] = None,
include_all_networks: bool = False, include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None, third_party_instance_id: Optional[str] = None,
): ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver """Get the list of public rooms from a remote homeserver
Args: Args:
@ -901,8 +910,7 @@ class FederationClient(FederationBase):
party instance party instance
Returns: Returns:
Awaitable[Dict[str, Any]]: The response from the remote server, or None if The response from the remote server.
`remote_server` is the same as the local server_name
Raises: Raises:
HttpResponseException: There was an exception returned from the remote server HttpResponseException: There was an exception returned from the remote server
@ -910,7 +918,7 @@ class FederationClient(FederationBase):
requests over federation requests over federation
""" """
return self.transport_layer.get_public_rooms( return await self.transport_layer.get_public_rooms(
remote_server, remote_server,
limit, limit,
since_token, since_token,
@ -923,7 +931,7 @@ class FederationClient(FederationBase):
self, self,
destination: str, destination: str,
room_id: str, room_id: str,
earliest_events_ids: Sequence[str], earliest_events_ids: Iterable[str],
latest_events: Iterable[EventBase], latest_events: Iterable[EventBase],
limit: int, limit: int,
min_depth: int, min_depth: int,
@ -974,7 +982,9 @@ class FederationClient(FederationBase):
return signed_events return signed_events
async def forward_third_party_invite(self, destinations, room_id, event_dict): async def forward_third_party_invite(
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
@ -983,7 +993,7 @@ class FederationClient(FederationBase):
await self.transport_layer.exchange_third_party_invite( await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict destination=destination, room_id=room_id, event_dict=event_dict
) )
return None return
except CodeMessageException: except CodeMessageException:
raise raise
except Exception as e: except Exception as e:
@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
async def get_room_complexity( async def get_room_complexity(
self, destination: str, room_id: str self, destination: str, room_id: str
) -> Optional[dict]: ) -> Optional[JsonDict]:
""" """
Fetch the complexity of a remote room from another server. Fetch the complexity of a remote room from another server.
@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
could not fetch the complexity. could not fetch the complexity.
""" """
try: try:
complexity = await self.transport_layer.get_room_complexity( return await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id destination=destination, room_id=room_id
) )
return complexity
except CodeMessageException as e: except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other # We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us. # servers don't give it to us.

View File

@ -142,6 +142,8 @@ class FederationSender:
self._wake_destinations_needing_catchup, self._wake_destinations_needing_catchup,
) )
self._external_cache = hs.get_external_cache()
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination """Get or create a PerDestinationQueue for the given destination
@ -197,22 +199,40 @@ class FederationSender:
if not event.internal_metadata.should_proactively_send(): if not event.internal_metadata.should_proactively_send():
return return
try: destinations = None # type: Optional[Set[str]]
# Get the state from before the event. if not event.prev_event_ids():
# We need to make sure that this is the state from before # If there are no prev event IDs then the state is empty
# the event and not from after it. # and so no remote servers in the room
# Otherwise if the last member on a server in a room is destinations = set()
# banned then it won't receive the event because it won't else:
# be in the room after the ban. # We check the external cache for the destinations, which is
destinations = await self.state.get_hosts_in_room_at_events( # stored per state group.
event.room_id, event_ids=event.prev_event_ids()
sg = await self._external_cache.get(
"event_to_prev_state_group", event.event_id
) )
except Exception: if sg:
logger.exception( destinations = await self._external_cache.get(
"Failed to calculate hosts in room for event: %s", "get_joined_hosts", str(sg)
event.event_id, )
)
return if destinations is None:
try:
# Get the state from before the event.
# We need to make sure that this is the state from before
# the event and not from after it.
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
"Failed to calculate hosts in room for event: %s",
event.event_id,
)
return
destinations = { destinations = {
d d

View File

@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected: if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event) await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precalculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context return context
async def _check_for_soft_fail( async def _check_for_soft_fail(

View File

@ -432,6 +432,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
self._external_cache = hs.get_external_cache()
async def create_event( async def create_event(
self, self,
requester: Requester, requester: Requester,
@ -939,6 +941,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context) await self.action_generator.handle_push_actions_for_event(event, context)
await self.cache_joined_hosts_for_event(event)
try: try:
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id) writer_instance = self._events_shard_config.get_instance(event.room_id)
@ -978,6 +982,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id) await self.store.remove_push_actions_from_staging(event.event_id)
raise raise
async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
"""Precalculate the joined hosts at the event, when using Redis, so that
external federation senders don't have to recalculate it themselves.
"""
if not self._external_cache.is_enabled():
return
# We actually store two mappings, event ID -> prev state group,
# state group -> joined hosts, which is much more space efficient
# than event ID -> joined hosts.
#
# Note: We have to cache event ID -> prev state group, as we don't
# store that in the DB.
#
# Note: We always set the state group -> joined hosts cache, even if
# we already set it, so that the expiry time is reset.
state_entry = await self.state.resolve_state_groups_for_events(
event.room_id, event_ids=event.prev_event_ids()
)
if state_entry.state_group:
joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
await self._external_cache.set(
"event_to_prev_state_group",
event.event_id,
state_entry.state_group,
expiry_ms=60 * 60 * 1000,
)
await self._external_cache.set(
"get_joined_hosts",
str(state_entry.state_group),
list(joined_hosts),
expiry_ms=60 * 60 * 1000,
)
async def _validate_canonical_alias( async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None: ) -> None:

View File

@ -17,7 +17,7 @@ import logging
import re import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional from typing import TYPE_CHECKING, Dict, Iterable, Optional
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import StateMap from synapse.types import StateMap
@ -63,7 +63,7 @@ async def calculate_room_name(
m_room_name = await store.get_event( m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True room_state_ids[(EventTypes.Name, "")], allow_none=True
) )
if m_room_name and m_room_name.content and m_room_name.content["name"]: if m_room_name and m_room_name.content and m_room_name.content.get("name"):
return m_room_name.content["name"] return m_room_name.content["name"]
# does it have a canonical alias? # does it have a canonical alias?
@ -74,15 +74,11 @@ async def calculate_room_name(
if ( if (
canon_alias canon_alias
and canon_alias.content and canon_alias.content
and canon_alias.content["alias"] and canon_alias.content.get("alias")
and _looks_like_an_alias(canon_alias.content["alias"]) and _looks_like_an_alias(canon_alias.content["alias"])
): ):
return canon_alias.content["alias"] return canon_alias.content["alias"]
# at this point we're going to need to search the state by all state keys
# for an event type, so rearrange the data structure
room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
if not fallback_to_members: if not fallback_to_members:
return None return None
@ -94,7 +90,7 @@ async def calculate_room_name(
if ( if (
my_member_event is not None my_member_event is not None
and my_member_event.content["membership"] == "invite" and my_member_event.content.get("membership") == Membership.INVITE
): ):
if (EventTypes.Member, my_member_event.sender) in room_state_ids: if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = await store.get_event( inviter_member_event = await store.get_event(
@ -111,6 +107,10 @@ async def calculate_room_name(
else: else:
return "Room Invite" return "Room Invite"
# at this point we're going to need to search the state by all state keys
# for an event type, so rearrange the data structure
room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
# we're going to have to generate a name based on who's in the room, # we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user. # so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids: if EventTypes.Member in room_state_bytype_ids:
@ -120,8 +120,8 @@ async def calculate_room_name(
all_members = [ all_members = [
ev ev
for ev in member_events.values() for ev in member_events.values()
if ev.content["membership"] == "join" if ev.content.get("membership") == Membership.JOIN
or ev.content["membership"] == "invite" or ev.content.get("membership") == Membership.INVITE
] ]
# Sort the member events oldest-first so the we name people in the # Sort the member events oldest-first so the we name people in the
# order the joined (it should at least be deterministic rather than # order the joined (it should at least be deterministic rather than
@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
def name_from_member_event(member_event: EventBase) -> str: def name_from_member_event(member_event: EventBase) -> str:
if ( if member_event.content and member_event.content.get("displayname"):
member_event.content
and "displayname" in member_event.content
and member_event.content["displayname"]
):
return member_event.content["displayname"] return member_event.content["displayname"]
return member_event.state_key return member_event.state_key

View File

@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Optional
from prometheus_client import Counter
from synapse.logging.context import make_deferred_yieldable
from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
set_counter = Counter(
"synapse_external_cache_set",
"Number of times we set a cache",
labelnames=["cache_name"],
)
get_counter = Counter(
"synapse_external_cache_get",
"Number of times we get a cache",
labelnames=["cache_name", "hit"],
)
logger = logging.getLogger(__name__)
class ExternalCache:
"""A cache backed by an external Redis. Does nothing if no Redis is
configured.
"""
def __init__(self, hs: "HomeServer"):
self._redis_connection = hs.get_outbound_redis_connection()
def _get_redis_key(self, cache_name: str, key: str) -> str:
return "cache_v1:%s:%s" % (cache_name, key)
def is_enabled(self) -> bool:
"""Whether the external cache is used or not.
It's safe to use the cache when this returns false; the methods will
just no-op. The check is useful, though, for avoiding unnecessary work.
"""
return self._redis_connection is not None
async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
"""Add the key/value to the named cache, with the expiry time given.
"""
if self._redis_connection is None:
return
set_counter.labels(cache_name).inc()
# txredisapi requires the value to be string, bytes or numbers, so we
# encode stuff in JSON.
encoded_value = json_encoder.encode(value)
logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
return await make_deferred_yieldable(
self._redis_connection.set(
self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
)
)
async def get(self, cache_name: str, key: str) -> Optional[Any]:
"""Look up a key/value in the named cache.
"""
if self._redis_connection is None:
return None
result = await make_deferred_yieldable(
self._redis_connection.get(self._get_redis_key(cache_name, key))
)
logger.debug("Got cache result %s %s: %r", cache_name, key, result)
get_counter.labels(cache_name, result is not None).inc()
if not result:
return None
# For some reason integer values come back as ints rather than strings, so there is nothing to JSON-decode here
if isinstance(result, int):
return result
return json_decoder.decode(result)

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Dict, Dict,
@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
TypingStream, TypingStream,
) )
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -88,7 +92,7 @@ class ReplicationCommandHandler:
back out to connections. back out to connections.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler() self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler() self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore() self._store = hs.get_datastore()
@ -282,13 +286,6 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled: if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import ( from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory, RedisDirectTcpReplicationClientFactory,
lazyConnection,
)
logger.info(
"Connecting to redis (host=%r port=%r)",
hs.config.redis_host,
hs.config.redis_port,
) )
# First let's ensure that we have a ReplicationStreamer started. # First let's ensure that we have a ReplicationStreamer started.
@ -299,13 +296,7 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called). # connection after SUBSCRIBE is called).
# First create the connection for sending commands. # First create the connection for sending commands.
outbound_redis_connection = lazyConnection( outbound_redis_connection = hs.get_outbound_redis_connection()
reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
reconnect=True,
)
# Now create the factory/connection for the subscription stream. # Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory( self._factory = RedisDirectTcpReplicationClientFactory(

View File

@ -15,7 +15,7 @@
import logging import logging
from inspect import isawaitable from inspect import isawaitable
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Type, cast
import txredisapi import txredisapi
@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import ( from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext, BackgroundProcessLoggingContext,
run_as_background_process, run_as_background_process,
wrap_as_background_process,
) )
from synapse.replication.tcp.commands import ( from synapse.replication.tcp.commands import (
Command, Command,
@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation. immediately after initialisation.
Attributes: Attributes:
handler: The command handler to handle incoming commands. synapse_handler: The command handler to handle incoming commands.
stream_name: The *redis* stream name to subscribe to and publish from synapse_stream_name: The *redis* stream name to subscribe to and publish
(not anything to do with Synapse replication streams). from (not anything to do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send synapse_outbound_redis_connection: The connection to redis to use to send
commands. commands.
""" """
handler = None # type: ReplicationCommandHandler synapse_handler = None # type: ReplicationCommandHandler
stream_name = None # type: str synapse_stream_name = None # type: str
outbound_redis_connection = None # type: txredisapi.RedisProtocol synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# it's important to make sure that we only send the REPLICATE command once we # it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the # have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end. # POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
await make_deferred_yieldable(self.subscribe(self.stream_name)) await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info( logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command" "Successfully subscribed to redis stream, sending REPLICATE command"
) )
self.handler.new_connection(self) self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand()) await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent") logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the # We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the # other side missed updates. We do this for Redis connections as the
# other side won't know we've connected and so won't issue a REPLICATE. # other side won't know we've connected and so won't issue a REPLICATE.
self.handler.send_positions_to_connection(self) self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str): def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis. """Received a message from redis.
@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd: received command cmd: received command
""" """
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func: if not cmd_func:
logger.warning("Unhandled command: %r", cmd) logger.warning("Unhandled command: %r", cmd)
return return
@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason): def connectionLost(self, reason):
logger.info("Lost connection to redis") logger.info("Lost connection to redis")
super().connectionLost(reason) super().connectionLost(reason)
self.handler.lost_connection(self) self.synapse_handler.lost_connection(self)
# mark the logging context as finished # mark the logging context as finished
self._logging_context.__exit__(None, None, None) self._logging_context.__exit__(None, None, None)
@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
await make_deferred_yieldable( await make_deferred_yieldable(
self.outbound_redis_connection.publish(self.stream_name, encoded_string) self.synapse_outbound_redis_connection.publish(
self.synapse_stream_name, encoded_string
)
) )
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): class SynapseRedisFactory(txredisapi.RedisFactory):
"""A subclass of RedisFactory that periodically sends pings to ensure that
we detect dead connections.
"""
def __init__(
self,
hs: "HomeServer",
uuid: str,
dbid: Optional[int],
poolsize: int,
isLazy: bool = False,
handler: Type = txredisapi.ConnectionHandler,
charset: str = "utf-8",
password: Optional[str] = None,
replyTimeout: int = 30,
convertNumbers: Optional[int] = True,
):
super().__init__(
uuid=uuid,
dbid=dbid,
poolsize=poolsize,
isLazy=isLazy,
handler=handler,
charset=charset,
password=password,
replyTimeout=replyTimeout,
convertNumbers=convertNumbers,
)
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
@wrap_as_background_process("redis_ping")
async def _send_ping(self):
for connection in self.pool:
try:
await make_deferred_yieldable(connection.ping())
except Exception:
logger.warning("Failed to send ping to a redis connection")
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately """This is a reconnecting factory that connects to redis and immediately
subscribes to a stream. subscribes to a stream.
@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
): ):
super().__init__() super().__init__(
hs,
uuid="subscriber",
dbid=None,
poolsize=1,
replyTimeout=30,
password=hs.config.redis.redis_password,
)
# This sets the password on the RedisFactory base class (as self.synapse_handler = hs.get_tcp_replication()
# SubscriberFactory constructor doesn't pass it through). self.synapse_stream_name = hs.hostname
self.password = hs.config.redis.redis_password
self.handler = hs.get_tcp_replication() self.synapse_outbound_redis_connection = outbound_redis_connection
self.stream_name = hs.hostname
self.outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr): def buildProtocol(self, addr):
p = super().buildProtocol(addr) # type: RedisSubscriber p = super().buildProtocol(addr)
p = cast(RedisSubscriber, p)
# We do this here rather than add to the constructor of `RedisSubscriber` # We do this here rather than add to the constructor of `RedisSubscriber`
# as to do so would involve overriding `buildProtocol` entirely, however # as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the # the base method does some other things than just instantiating the
# protocol. # protocol.
p.handler = self.handler p.synapse_handler = self.synapse_handler
p.outbound_redis_connection = self.outbound_redis_connection p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.stream_name = self.stream_name p.synapse_stream_name = self.synapse_stream_name
p.password = self.password
return p return p
def lazyConnection( def lazyConnection(
reactor, hs: "HomeServer",
host: str = "localhost", host: str = "localhost",
port: int = 6379, port: int = 6379,
dbid: Optional[int] = None, dbid: Optional[int] = None,
reconnect: bool = True, reconnect: bool = True,
charset: str = "utf-8",
password: Optional[str] = None, password: Optional[str] = None,
connectTimeout: Optional[int] = None, replyTimeout: int = 30,
replyTimeout: Optional[int] = None,
convertNumbers: bool = True,
) -> txredisapi.RedisProtocol: ) -> txredisapi.RedisProtocol:
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a """Creates a connection to Redis that is lazily set up and reconnects if the
reactor. connection is lost.
""" """
isLazy = True
poolsize = 1
uuid = "%s:%d" % (host, port) uuid = "%s:%d" % (host, port)
factory = txredisapi.RedisFactory( factory = SynapseRedisFactory(
uuid, hs,
dbid, uuid=uuid,
poolsize, dbid=dbid,
isLazy, poolsize=1,
txredisapi.ConnectionHandler, isLazy=True,
charset, handler=txredisapi.ConnectionHandler,
password, password=password,
replyTimeout, replyTimeout=replyTimeout,
convertNumbers,
) )
factory.continueTrying = reconnect factory.continueTrying = reconnect
for x in range(poolsize):
reactor.connectTCP(host, port, factory, connectTimeout) reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, 30)
return factory.handler return factory.handler

View File

@ -1,6 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd # Copyright 2018-2019 New Vector Ltd
# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -36,6 +38,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import ( from synapse.rest.admin.rooms import (
DeleteRoomRestServlet, DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,
JoinRoomAliasServlet, JoinRoomAliasServlet,
ListRoomRestServlet, ListRoomRestServlet,
MakeRoomAdminRestServlet, MakeRoomAdminRestServlet,
@ -51,6 +54,7 @@ from synapse.rest.admin.users import (
PushersRestServlet, PushersRestServlet,
ResetPasswordRestServlet, ResetPasswordRestServlet,
SearchUsersRestServlet, SearchUsersRestServlet,
ShadowBanRestServlet,
UserAdminServlet, UserAdminServlet,
UserMediaRestServlet, UserMediaRestServlet,
UserMembershipRestServlet, UserMembershipRestServlet,
@ -230,6 +234,8 @@ def register_servlets(hs, http_server):
EventReportsRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server)
ShadowBanRestServlet(hs).register(http_server)
ForwardExtremitiesRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server): def register_servlets_for_client_rest_resource(hs, http_server):

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -499,3 +499,60 @@ class MakeRoomAdminRestServlet(RestServlet):
) )
return 200, {} return 200, {}
class ForwardExtremitiesRestServlet(RestServlet):
"""Allows a server admin to get or clear forward extremities.
Clearing does not require restarting the server.
Clear forward extremities:
DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
Get forward_extremities:
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
"""
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.store = hs.get_datastore()
async def resolve_room_id(self, room_identifier: str) -> str:
"""Resolve to a room ID, if necessary."""
if RoomID.is_valid(room_identifier):
resolved_room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
if not resolved_room_id:
raise SynapseError(
400, "Unknown room ID or room alias %s" % room_identifier
)
return resolved_room_id
async def on_DELETE(self, request, room_identifier):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
room_id = await self.resolve_room_id(room_identifier)
deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
return 200, {"deleted": deleted_count}
async def on_GET(self, request, room_identifier):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
room_id = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
return 200, {"count": len(extremities), "results": extremities}

View File

@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users. The parameter `deactivated` can be used to include deactivated users.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100)
if start < 0:
raise SynapseError(
400,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
400,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
user_id = parse_string(request, "user_id", default=None) user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None) name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True) guests = parse_boolean(request, "guests", default=True)
@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated start, limit, user_id, name, guests, deactivated
) )
ret = {"users": users, "total": total} ret = {"users": users, "total": total}
if len(users) >= limit: if (start + limit) < total:
ret["next_token"] = str(start + len(users)) ret["next_token"] = str(start + len(users))
return 200, ret return 200, ret
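
The corrected `next_token` behaviour means its absence is now a reliable end-of-list signal; a paging sketch (again using `requests` purely for illustration, with placeholder URL and token):

    import requests

    URL = "http://localhost:8008/_synapse/admin/v2/users"
    HEADERS = {"Authorization": "Bearer <admin access token>"}

    start, users = 0, []
    while True:
        resp = requests.get(URL, headers=HEADERS, params={"from": start, "limit": 100})
        body = resp.json()
        users.extend(body["users"])
        if "next_token" not in body:
            # Previously an exactly-full final page still carried a next_token,
            # causing one needless extra request; now the token is only present
            # when more rows remain.
            break
        start = int(body["next_token"])
    print(f"fetched {len(users)} of {body['total']} users")
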
@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
) )
return 200, {"access_token": token} return 200, {"access_token": token}
class ShadowBanRestServlet(RestServlet):
"""An admin API for shadow-banning a user.
A shadow-banned user receives successful responses to their client-server
API requests, but the events are not propagated into rooms.
Shadow-banning a user should be used as a tool of last resort and may lead
to confusing or broken behaviour for the client.
Example:
POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
{}
200 OK
{}
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, user_id):
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Only local users can be shadow-banned")
await self.store.set_shadow_banned(UserID.from_string(user_id), True)
return 200, {}
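
A corresponding call from an admin's point of view might look like this (placeholders as above); note the servlet rejects non-local users with a 400:

    import requests

    resp = requests.post(
        "http://localhost:8008/_synapse/admin/v1/users/@spammer:example.com/shadow_ban",
        headers={"Authorization": "Bearer <admin access token>"},
        json={},
    )
    assert resp.status_code == 200  # the response body is an empty JSON object
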

View File

@ -300,6 +300,7 @@ class FileInfo:
thumbnail_height (int) thumbnail_height (int)
thumbnail_method (str) thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png thumbnail_type (str): Content type of thumbnail, e.g. image/png
thumbnail_length (int): The size of the thumbnail, in bytes.
""" """
def __init__( def __init__(
@ -312,6 +313,7 @@ class FileInfo:
thumbnail_height=None, thumbnail_height=None,
thumbnail_method=None, thumbnail_method=None,
thumbnail_type=None, thumbnail_type=None,
thumbnail_length=None,
): ):
self.server_name = server_name self.server_name = server_name
self.file_id = file_id self.file_id = file_id
@ -321,6 +323,7 @@ class FileInfo:
self.thumbnail_height = thumbnail_height self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type self.thumbnail_type = thumbnail_type
self.thumbnail_length = thumbnail_length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:

View File

@ -386,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource):
""" """
Check whether the URL should be downloaded as oEmbed content instead. Check whether the URL should be downloaded as oEmbed content instead.
Params: Args:
url: The URL to check. url: The URL to check.
Returns: Returns:
@ -403,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource):
""" """
Request content from an oEmbed endpoint. Request content from an oEmbed endpoint.
Params: Args:
endpoint: The oEmbed API endpoint. endpoint: The oEmbed API endpoint.
url: The URL to pass to the API. url: The URL to pass to the API.
@ -692,27 +692,51 @@ class PreviewUrlResource(DirectServeJsonResource):
def decode_and_calc_og( def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]: ) -> Dict[str, Optional[str]]:
"""
Calculate metadata for an HTML document.
This uses lxml to parse the HTML document into the OG response. If errors
occur during processing of the document, an empty response is returned.
Args:
body: The HTML document, as bytes.
media_uri: The URI used to download the body.
request_encoding: The character encoding of the body, as a string.
Returns:
The OG response as a dictionary.
"""
# If there's no body, nothing useful is going to be found. # If there's no body, nothing useful is going to be found.
if not body: if not body:
return {} return {}
from lxml import etree from lxml import etree
# Create an HTML parser. If this fails, log and return no metadata.
try: try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding) parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body, parser) except LookupError:
og = _calc_og(tree, media_uri) # blindly consider the encoding as utf-8.
parser = etree.HTMLParser(recover=True, encoding="utf-8")
except Exception as e:
logger.warning("Unable to create HTML parser: %s" % (e,))
return {}
def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
# Attempt to parse the body. If this fails, log and return no metadata.
tree = etree.fromstring(body_attempt, parser)
return _calc_og(tree, media_uri)
# Attempt to parse the body. If this fails, log and return no metadata.
try:
return _attempt_calc_og(body)
except UnicodeDecodeError: except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix # blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com # the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=request_encoding) return _attempt_calc_og(body.decode("utf-8", "ignore"))
tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
og = _calc_og(tree, media_uri)
return og
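
The new control flow, falling back to a utf-8 parser when the declared encoding is unknown and re-decoding lossily on UnicodeDecodeError, can be sketched standalone (only lxml is assumed; the title lookup stands in for the real `_calc_og`):

    from typing import Dict, Optional, Union

    from lxml import etree


    def sketch_decode_og(body: bytes, encoding: Optional[str] = None) -> Dict[str, Optional[str]]:
        if not body:
            return {}
        try:
            parser = etree.HTMLParser(recover=True, encoding=encoding)
        except LookupError:
            # The server advertised an encoding lxml does not know: blindly use utf-8.
            parser = etree.HTMLParser(recover=True, encoding="utf-8")

        def attempt(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
            tree = etree.fromstring(body_attempt, parser)
            if tree is None:
                return {}
            return {"og:title": tree.findtext(".//title")}  # stand-in for _calc_og()

        try:
            return attempt(body)
        except UnicodeDecodeError:
            # Charset mismatches (e.g. mislabelled utf-8 pages): drop undecodable bytes.
            return attempt(body.decode("utf-8", "ignore"))
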
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response. # suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them # if we see any image URLs in the OG response, then spider them

View File

@ -16,7 +16,7 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request from twisted.web.http import Request
@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource):
return return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
await self._select_and_respond_with_thumbnail(
if thumbnail_infos: request,
thumbnail_info = self._select_thumbnail( width,
width, height, method, m_type, thumbnail_infos height,
) method,
m_type,
file_info = FileInfo( thumbnail_infos,
server_name=None, media_id,
file_id=media_id, url_cache=media_info["url_cache"],
url_cache=media_info["url_cache"], server_name=None,
thumbnail=True, )
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
async def _select_or_generate_local_thumbnail( async def _select_or_generate_local_thumbnail(
self, self,
@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails( thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id server_name, media_id
) )
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_info["filesystem_id"],
url_cache=None,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
request: Request,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
url_cache: Optional[str] = None,
server_name: Optional[str] = None,
) -> None:
"""
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
Args:
request: The incoming request.
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: The URL cache value.
server_name: The server name, if this is a remote thumbnail.
"""
if thumbnail_infos: if thumbnail_infos:
thumbnail_info = self._select_thumbnail( file_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos desired_width,
desired_height,
desired_method,
desired_type,
thumbnail_infos,
file_id,
url_cache,
server_name,
) )
file_info = FileInfo( if not file_info:
server_name=server_name, logger.info("Couldn't find a thumbnail matching the desired inputs")
file_id=media_info["filesystem_id"], respond_404(request)
thumbnail=True, return
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length) await respond_with_responder(
request, responder, file_info.thumbnail_type, file_info.thumbnail_length
)
else: else:
logger.info("Failed to find any generated thumbnails") logger.info("Failed to find any generated thumbnails")
respond_404(request) respond_404(request)
@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int, desired_height: int,
desired_method: str, desired_method: str,
desired_type: str, desired_type: str,
thumbnail_infos, thumbnail_infos: List[Dict[str, Any]],
) -> dict: file_id: str,
url_cache: Optional[str],
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
Choose an appropriate thumbnail from the previously generated thumbnails.
Args:
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: The URL cache value.
server_name: The server name, if this is a remote thumbnail.
Returns:
The thumbnail which best matches the desired parameters.
"""
desired_method = desired_method.lower()
# The chosen thumbnail.
thumbnail_info = None
d_w = desired_width d_w = desired_width
d_h = desired_height d_h = desired_height
if desired_method.lower() == "crop": if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = [] crop_info_list = []
# Other thumbnails.
crop_info_list2 = [] crop_info_list2 = []
for info in thumbnail_infos: for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "crop":
continue
t_w = info["thumbnail_width"] t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"] t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"] aspect_quality = abs(d_w * t_h - d_h * t_w)
if t_method == "crop": min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
if crop_info_list:
return min(crop_info_list)[-1]
else:
return min(crop_info_list2)[-1]
else:
info_list = []
info_list2 = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h)) size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"] type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"] length_quality = info["thumbnail_length"]
if t_method == "scale" and (t_w >= d_w or t_h >= d_h): if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
if crop_info_list:
thumbnail_info = min(crop_info_list)[-1]
elif crop_info_list2:
thumbnail_info = min(crop_info_list2)[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
# Other thumbnails.
info_list2 = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "scale":
continue
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info)) info_list.append((size_quality, type_quality, length_quality, info))
elif t_method == "scale": else:
info_list2.append( info_list2.append(
(size_quality, type_quality, length_quality, info) (size_quality, type_quality, length_quality, info)
) )
if info_list: if info_list:
return min(info_list)[-1] thumbnail_info = min(info_list)[-1]
else: elif info_list2:
return min(info_list2)[-1] thumbnail_info = min(info_list2)[-1]
if thumbnail_info:
return FileInfo(
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
thumbnail_length=thumbnail_info["thumbnail_length"],
)
# No matching thumbnail was found.
return None
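
Stripped of the FileInfo plumbing, the selection amounts to ranking candidates by a comparison tuple and taking the minimum, preferring thumbnails at least as large as requested. A minimal sketch of the scale branch (the candidate dicts are invented for illustration):

    from typing import Any, Dict, List, Optional


    def pick_scaled_thumbnail(
        candidates: List[Dict[str, Any]], d_w: int, d_h: int, d_type: str
    ) -> Optional[Dict[str, Any]]:
        larger, smaller = [], []
        for info in candidates:
            # Only consider thumbnails generated with the requested method.
            if info["thumbnail_method"] != "scale":
                continue
            t_w, t_h = info["thumbnail_width"], info["thumbnail_height"]
            rank = (
                abs((d_w - t_w) * (d_h - t_h)),    # closeness to the desired area
                d_type != info["thumbnail_type"],  # exact content-type sorts first (False < True)
                info["thumbnail_length"],          # smaller file breaks remaining ties
            )
            (larger if t_w >= d_w or t_h >= d_h else smaller).append((rank, info))
        pool = larger or smaller
        return min(pool, key=lambda item: item[0])[1] if pool else None


    thumbs = [
        {"thumbnail_method": "scale", "thumbnail_width": 320, "thumbnail_height": 240,
         "thumbnail_type": "image/jpeg", "thumbnail_length": 2034},
        {"thumbnail_method": "scale", "thumbnail_width": 800, "thumbnail_height": 600,
         "thumbnail_type": "image/jpeg", "thumbnail_length": 11337},
    ]
    print(pick_scaled_thumbnail(thumbs, 640, 480, "image/jpeg"))  # picks the 800x600 entry
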

View File

@ -103,6 +103,7 @@ from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from txredisapi import RedisProtocol
from synapse.handlers.oidc_handler import OidcHandler from synapse.handlers.oidc_handler import OidcHandler
from synapse.handlers.saml_handler import SamlHandler from synapse.handlers.saml_handler import SamlHandler
@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_data_handler(self) -> AccountDataHandler: def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self) return AccountDataHandler(self)
@cache_in_self
def get_external_cache(self) -> ExternalCache:
return ExternalCache(self)
@cache_in_self
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
if not self.config.redis.redis_enabled:
return None
# We only want to import redis module if we're using it, as we have
# `txredisapi` as an optional dependency.
from synapse.replication.tcp.redis import lazyConnection
logger.info(
"Connecting to redis (host=%r port=%r) for external cache",
self.config.redis_host,
self.config.redis_port,
)
return lazyConnection(
hs=self,
host=self.config.redis_host,
port=self.config.redis_port,
password=self.config.redis.redis_password,
reconnect=True,
)
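
The shape of this helper, memoise the connection, return None when Redis is disabled, and import the optional client only when it is actually needed, can be sketched outside of Synapse (class and attribute names below are illustrative, and txredisapi is assumed to accept host/port/reconnect keyword arguments and to be installed only where Redis is enabled):

    from typing import Any, Optional


    class RedisConnectionOwner:
        def __init__(self, redis_enabled: bool, host: str = "localhost", port: int = 6379):
            self.redis_enabled = redis_enabled
            self.host = host
            self.port = port
            self._connection: Optional[Any] = None

        def get_outbound_redis_connection(self) -> Optional[Any]:
            if not self.redis_enabled:
                return None
            if self._connection is None:
                # Imported lazily so deployments without Redis never need txredisapi.
                import txredisapi

                self._connection = txredisapi.lazyConnection(
                    host=self.host, port=self.port, reconnect=True
                )
            return self._connection
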
async def remove_pusher(self, app_id: str, push_key: str, user_id: str): async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -310,6 +310,7 @@ class StateHandler:
state_group_before_event = None state_group_before_event = None
state_group_before_event_prev_group = None state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None deltas_to_state_group_before_event = None
entry = None
else: else:
# otherwise, we'll need to resolve the state across the prev_events. # otherwise, we'll need to resolve the state across the prev_events.
@ -340,9 +341,13 @@ class StateHandler:
current_state_ids=state_ids_before_event, current_state_ids=state_ids_before_event,
) )
# XXX: can we update the state cache entry for the new state group? or # Assign the new state group to the cached state entry.
# could we set a flag on resolve_state_groups_for_events to tell it to #
# always make a state group? # Note that this can race in that we could generate multiple state
# groups for the same state entry, but that is just inefficient
# rather than dangerous.
if entry and entry.state_group is None:
entry.state_group = state_group_before_event
# #
# now if it's not a state event, we're done # now if it's not a state event, we're done

View File

@ -262,13 +262,18 @@ class LoggingTransaction:
return self.txn.description return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
"""Similar to `executemany`, except `txn.rowcount` will not be correct
afterwards.
More efficient than `executemany` on PostgreSQL
"""
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else: else:
for val in args: self.executemany(sql, args)
self.execute(sql, val)
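
For reference, the PostgreSQL-side difference this wraps (assuming psycopg2 and a reachable database; the DSN and table are placeholders):

    import psycopg2
    from psycopg2.extras import execute_batch

    conn = psycopg2.connect("dbname=synapse")  # placeholder DSN
    rows = [(i, "device-%d" % i) for i in range(1000)]

    with conn, conn.cursor() as cur:
        # cur.executemany(...) issues one statement per row; execute_batch packs
        # many statements into each network round trip, which is far faster on
        # Postgres. As the docstring above notes, cur.rowcount is not meaningful
        # afterwards.
        execute_batch(cur, "INSERT INTO devices (id, name) VALUES (%s, %s)", rows)
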
def execute_values(self, sql: str, *args: Any) -> List[Tuple]: def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when """Corresponds to psycopg2.extras.execute_values. Only available when
@ -888,7 +893,7 @@ class DatabasePool:
", ".join("?" for _ in keys[0]), ", ".join("?" for _ in keys[0]),
) )
txn.executemany(sql, vals) txn.execute_batch(sql, vals)
async def simple_upsert( async def simple_upsert(
self, self,

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd # Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore from .filtering import FilteringStore
from .group_server import GroupServerStore from .group_server import GroupServerStore
from .keys import KeyStore from .keys import KeyStore
@ -118,6 +119,7 @@ class DataStore(
UIAuthStore, UIAuthStore,
CacheInvalidationWorkerStore, CacheInvalidationWorkerStore,
ServerMetricsStore, ServerMetricsStore,
EventForwardExtremitiesStore,
): ):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs self.hs = hs

View File

@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_last_success DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? WHERE destination = ? AND user_id = ?
""" """
txn.executemany(sql, ((row[0], row[1]) for row in rows)) txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", count) logger.info("Pruned %d device list outbound pokes", count)
@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Delete older entries in the table, as we really only care about # Delete older entries in the table, as we really only care about
# when the latest change happened. # when the latest change happened.
txn.executemany( txn.execute_batch(
""" """
DELETE FROM device_lists_stream DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ? WHERE user_id = ? AND device_id = ? AND stream_id < ?

View File

@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
""" """
txn.executemany( txn.execute_batch(
sql, sql,
( (
_gen_entry(user_id, actions) _gen_entry(user_id, actions)
@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
], ],
) )
txn.executemany( txn.execute_batch(
""" """
UPDATE event_push_summary UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ? SET notif_count = ?, unread_count = ?, stream_ordering = ?

View File

@ -473,8 +473,9 @@ class PersistEventsStore:
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
) )
@staticmethod @classmethod
def _add_chain_cover_index( def _add_chain_cover_index(
cls,
txn, txn,
db_pool: DatabasePool, db_pool: DatabasePool,
event_to_room_id: Dict[str, str], event_to_room_id: Dict[str, str],
@ -614,60 +615,17 @@ class PersistEventsStore:
if not events_to_calc_chain_id_for: if not events_to_calc_chain_id_for:
return return
# We now calculate the chain IDs/sequence numbers for the events. We # Allocate chain ID/sequence numbers to each new event.
# do this by looking at the chain ID and sequence number of any auth new_chain_tuples = cls._allocate_chain_ids(
# event with the same type/state_key and incrementing the sequence txn,
# number by one. If there was no match or the chain ID/sequence db_pool,
# number is already taken we generate a new chain. event_to_room_id,
# event_to_types,
# We need to do this in a topologically sorted order as we want to event_to_auth_chain,
# generate chain IDs/sequence numbers of an event's auth events events_to_calc_chain_id_for,
# before the event itself. chain_map,
chains_tuples_allocated = set() # type: Set[Tuple[int, int]] )
new_chain_tuples = {} # type: Dict[str, Tuple[int, int]] chain_map.update(new_chain_tuples)
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
new_chain_tuple = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
already_allocated = db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
"chain_id": proposed_new_id,
"sequence_number": proposed_new_seq,
},
retcol="event_id",
allow_none=True,
)
if already_allocated:
# Mark it as already allocated so we don't need to hit
# the DB again.
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
else:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
if not new_chain_tuple:
new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
chains_tuples_allocated.add(new_chain_tuple)
chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
db_pool.simple_insert_many_txn( db_pool.simple_insert_many_txn(
txn, txn,
@ -794,6 +752,137 @@ class PersistEventsStore:
], ],
) )
@staticmethod
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
"""Allocates, but does not persist, chain ID/sequence numbers for the
events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
for info on args)
"""
# We now calculate the chain IDs/sequence numbers for the events. We do
# this by looking at the chain ID and sequence number of any auth event
# with the same type/state_key and incrementing the sequence number by
# one. If there was no match or the chain ID/sequence number is already
# taken we generate a new chain.
#
# We try to reduce the number of times that we hit the database by
# batching up calls, to make this more efficient when persisting large
# numbers of state events (e.g. during joins).
#
# We do this by:
# 1. Calculating for each event which auth event will be used to
# inherit the chain ID, i.e. converting the auth chain graph to a
# tree that we can allocate chains on. We also keep track of which
# existing chain IDs have been referenced.
# 2. Fetching the max allocated sequence number for each referenced
# existing chain ID, generating a map from chain ID to the max
# allocated sequence number.
# 3. Iterating over the tree and allocating a chain ID/seq no. to the
# new event, by incrementing the sequence number from the
# referenced event's chain ID/seq no. and checking that the
# incremented sequence number hasn't already been allocated (by
# looking in the map generated in the previous step). We generate a
# new chain if the sequence number has already been allocated.
#
existing_chains = set() # type: Set[int]
tree = [] # type: List[Tuple[str, Optional[str]]]
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events before
# the event itself.
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map.get(auth_id)
if existing_chain_id:
existing_chains.add(existing_chain_id[0])
tree.append((event_id, auth_id))
break
else:
tree.append((event_id, None))
# Fetch the current max sequence number for each existing referenced chain.
sql = """
SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
WHERE %s
GROUP BY chain_id
"""
clause, args = make_in_list_sql_clause(
db_pool.engine, "chain_id", existing_chains
)
txn.execute(sql % (clause,), args)
chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
# Allocate the new events chain ID/sequence numbers.
#
# To reduce the number of calls to the database we don't allocate a
# chain ID number in the loop, instead we use a temporary `object()` for
# each new chain ID. Once we've done the loop we generate the necessary
# number of new chain IDs in one call, replacing all temporary
# objects with real allocated chain IDs.
unallocated_chain_ids = set() # type: Set[object]
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
for event_id, auth_event_id in tree:
# If we reference an auth_event_id we fetch the allocated chain ID,
# either from the existing `chain_map` or the newly generated
# `new_chain_tuples` map.
existing_chain_id = None
if auth_event_id:
existing_chain_id = new_chain_tuples.get(auth_event_id)
if not existing_chain_id:
existing_chain_id = chain_map[auth_event_id]
new_chain_tuple = None # type: Optional[Tuple[Any, int]]
if existing_chain_id:
# We found a chain ID/sequence number candidate, check it's
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
# If we need to start a new chain we allocate a temporary chain ID.
if not new_chain_tuple:
new_chain_tuple = (object(), 1)
unallocated_chain_ids.add(new_chain_tuple[0])
new_chain_tuples[event_id] = new_chain_tuple
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
# Generate new chain IDs for all unallocated chain IDs.
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)
# Map from potentially temporary chain ID to real chain ID
chain_id_to_allocated_map = dict(
zip(unallocated_chain_ids, newly_allocated_chain_ids)
) # type: Dict[Any, int]
chain_id_to_allocated_map.update((c, c) for c in existing_chains)
return {
event_id: (chain_id_to_allocated_map[chain_id], seq)
for event_id, (chain_id, seq) in new_chain_tuples.items()
}
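
The temporary-sentinel trick in step 3, handing out object() placeholders during the loop and swapping in real chain IDs with one bulk allocation at the end, is easier to see in a self-contained sketch (the toy allocator below stands in for `event_chain_id_gen.get_next_mult_txn`):

    import itertools
    from typing import Any, Dict, List, Optional, Set, Tuple

    # Toy stand-in for db_pool.event_chain_id_gen.get_next_mult_txn(txn, n).
    _counter = itertools.count(1)


    def allocate_real_chain_ids(n: int) -> List[int]:
        return [next(_counter) for _ in range(n)]


    def allocate_chains(
        tree: List[Tuple[str, Optional[str]]],
        chain_map: Dict[str, Tuple[int, int]],
        chain_to_max_seq_no: Dict[Any, int],
    ) -> Dict[str, Tuple[int, int]]:
        """tree holds (event_id, auth_event_to_inherit_from_or_None) in topological
        order; extend the parent's chain when the next sequence number is free,
        otherwise start a new (temporary) chain."""
        unallocated: Set[object] = set()
        new_tuples: Dict[str, Tuple[Any, int]] = {}

        for event_id, parent_id in tree:
            parent_chain = None
            if parent_id is not None:
                parent_chain = new_tuples.get(parent_id) or chain_map.get(parent_id)

            chosen = None
            if parent_chain is not None:
                chain_id, seq = parent_chain
                if chain_to_max_seq_no.get(chain_id, 0) < seq + 1:
                    chosen = (chain_id, seq + 1)
            if chosen is None:
                chosen = (object(), 1)  # temporary chain ID, resolved in bulk below
                unallocated.add(chosen[0])

            new_tuples[event_id] = chosen
            chain_to_max_seq_no[chosen[0]] = chosen[1]

        # One call allocates every real chain ID; temporary IDs are then substituted.
        real_ids: Dict[Any, int] = dict(zip(unallocated, allocate_real_chain_ids(len(unallocated))))
        return {
            ev: (real_ids.get(chain_id, chain_id), seq)
            for ev, (chain_id, seq) in new_tuples.items()
        }


    # e.g. two events extending an existing chain 7 whose max sequence number is 3:
    print(allocate_chains([("$b", "$a"), ("$c", "$b")], {"$a": (7, 3)}, {7: 3}))
    # -> {'$b': (7, 4), '$c': (7, 5)}
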
def _persist_transaction_ids_txn( def _persist_transaction_ids_txn(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
@ -876,7 +965,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ? WHERE room_id = ? AND type = ? AND state_key = ?
) )
""" """
txn.executemany( txn.execute_batch(
sql, sql,
( (
( (
@ -895,7 +984,7 @@ class PersistEventsStore:
) )
# Now we actually update the current_state_events table # Now we actually update the current_state_events table
txn.executemany( txn.execute_batch(
"DELETE FROM current_state_events" "DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?", " WHERE room_id = ? AND type = ? AND state_key = ?",
( (
@ -907,7 +996,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do # We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already # a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships. # been inserted into room_memberships.
txn.executemany( txn.execute_batch(
"""INSERT INTO current_state_events """INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership) (room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@ -927,7 +1016,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the # we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it. # room but got, say, state reset out of it.
if to_delete or to_insert: if to_delete or to_insert:
txn.executemany( txn.execute_batch(
"DELETE FROM local_current_membership" "DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?", " WHERE room_id = ? AND user_id = ?",
( (
@ -938,7 +1027,7 @@ class PersistEventsStore:
) )
if to_insert: if to_insert:
txn.executemany( txn.execute_batch(
"""INSERT INTO local_current_membership """INSERT INTO local_current_membership
(room_id, user_id, event_id, membership) (room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@ -1738,7 +1827,7 @@ class PersistEventsStore:
""" """
if events_and_contexts: if events_and_contexts:
txn.executemany( txn.execute_batch(
sql, sql,
( (
( (
@ -1767,7 +1856,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being # Now we delete the staging area for *all* events that were being
# persisted. # persisted.
txn.executemany( txn.execute_batch(
"DELETE FROM event_push_actions_staging WHERE event_id = ?", "DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts), ((event.event_id,) for event, _ in all_events_and_contexts),
) )
@ -1886,7 +1975,7 @@ class PersistEventsStore:
" )" " )"
) )
txn.executemany( txn.execute_batch(
query, query,
[ [
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@ -1900,7 +1989,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities" "DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?" " WHERE event_id = ? AND room_id = ?"
) )
txn.executemany( txn.execute_batch(
query, query,
[ [
(ev.event_id, ev.room_id) (ev.event_id, ev.room_id)

View File

@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0) rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_txn(txn): def reindex_txn(txn):
sql = ( sql = (
"SELECT stream_ordering, event_id, json FROM events" "SELECT stream_ordering, event_id, json FROM events"
@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): txn.execute_batch(sql, update_rows)
clump = update_rows[index : index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = { progress = {
"target_min_stream_id_inclusive": target_min_stream_id, "target_min_stream_id_inclusive": target_min_stream_id,
@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0) rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_search_txn(txn): def reindex_search_txn(txn):
sql = ( sql = (
"SELECT stream_ordering, event_id FROM events" "SELECT stream_ordering, event_id FROM events"
@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): txn.execute_batch(sql, rows_to_update)
clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = { progress = {
"target_min_stream_id_inclusive": target_min_stream_id, "target_min_stream_id_inclusive": target_min_stream_id,

View File

@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
class EventForwardExtremitiesStore(SQLBaseStore):
async def delete_forward_extremities_for_room(self, room_id: str) -> int:
"""Delete any extra forward extremities for a room.
Invalidates the "get_latest_event_ids_in_room" cache if any forward
extremities were deleted.
Returns count deleted.
"""
def delete_forward_extremities_for_room_txn(txn):
# First we need to get the event_id to not delete
sql = """
SELECT event_id FROM event_forward_extremities
INNER JOIN events USING (room_id, event_id)
WHERE room_id = ?
ORDER BY stream_ordering DESC
LIMIT 1
"""
txn.execute(sql, (room_id,))
rows = txn.fetchall()
try:
event_id = rows[0][0]
logger.debug(
"Found event_id %s as the forward extremity to keep for room %s",
event_id,
room_id,
)
except IndexError:
msg = "No forward extremity event found for room %s" % room_id
logger.warning(msg)
raise SynapseError(400, msg)
# Now delete the extra forward extremities
sql = """
DELETE FROM event_forward_extremities
WHERE event_id != ? AND room_id = ?
"""
txn.execute(sql, (event_id, room_id))
logger.info(
"Deleted %s extra forward extremities for room %s",
txn.rowcount,
room_id,
)
if txn.rowcount > 0:
# Invalidate the cache
self._invalidate_cache_and_stream(
txn, self.get_latest_event_ids_in_room, (room_id,),
)
return txn.rowcount
return await self.db_pool.runInteraction(
"delete_forward_extremities_for_room",
delete_forward_extremities_for_room_txn,
)
async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
"""Get list of forward extremities for a room."""
def get_forward_extremities_for_room_txn(txn):
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
INNER JOIN event_to_state_groups USING (event_id)
INNER JOIN events USING (room_id, event_id)
WHERE room_id = ?
"""
txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn)
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
)
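
Put together with the admin servlet earlier in this change, a GET on a room with a stale extremity would therefore produce something shaped like this (every value below is invented for illustration; the fields mirror the SELECT above):

    example_response = {
        "count": 2,
        "results": [
            {"event_id": "$good:example.com", "state_group": 439, "depth": 984,
             "received_ts": 1611605782000},
            {"event_id": "$stale:example.com", "state_group": 421, "depth": 983,
             "received_ts": 1611605700000},
        ],
    }
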

View File

@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?" " WHERE media_origin = ? AND media_id = ?"
) )
txn.executemany( txn.execute_batch(
sql, sql,
( (
(time_ms, media_origin, media_id) (time_ms, media_origin, media_id)
@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?" " WHERE media_id = ?"
) )
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn "update_cached_last_access_time", update_cache_txn
@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn): def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids]) txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn "delete_url_cache", _delete_url_cache_txn
@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn): def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?" sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids]) txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids]) txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn "delete_url_cache_media", _delete_url_cache_media_txn

View File

@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
) )
# Update backward extremities # Update backward extremities
txn.executemany( txn.execute_batch(
"INSERT INTO event_backward_extremities (room_id, event_id)" "INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)", " VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems], [(room_id, event_id) for event_id, in new_backwards_extrems],

View File

@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,) txn, self.get_if_user_has_pusher, (user_id,)
) )
self.db_pool.simple_delete_one_txn( # It is expected that there is exactly one pusher to delete, but
# if it isn't there (or there are multiple) delete them all.
self.db_pool.simple_delete_txn(
txn, txn,
"pushers", "pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},

View File

@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
"""Sets whether a user shadow-banned.
Args:
user: user ID of the user to update
shadow_banned: true iff the user is to be shadow-banned, false otherwise.
"""
def set_shadow_banned_txn(txn):
self.db_pool.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user.to_string()},
updatevalues={"shadow_banned": shadow_banned},
)
# In order for this to apply immediately, clear the cache for this user.
tokens = self.db_pool.simple_select_onecol_txn(
txn,
table="access_tokens",
keyvalues={"user_id": user.to_string()},
retcol="token",
)
for token in tokens:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,)
)
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """ sql = """
SELECT users.name as user_id, SELECT users.name as user_id,
@ -1104,7 +1133,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids FROM user_threepids
""" """
txn.executemany(sql, [(id_server,) for id_server in id_servers]) txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
if id_servers: if id_servers:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(

View File

@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1 "max_stream_id_exclusive", self._stream_order_on_start + 1
) )
INSERT_CLUMP_SIZE = 1000
def add_membership_profile_txn(txn): def add_membership_profile_txn(txn):
sql = """ sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json SELECT stream_ordering, event_id, events.room_id, event_json.json
@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ? UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ? WHERE event_id = ? AND room_id = ?
""" """
for index in range(0, len(to_update), INSERT_CLUMP_SIZE): txn.execute_batch(to_update_sql, to_update)
clump = to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(to_update_sql, clump)
progress = { progress = {
"target_min_stream_id_inclusive": target_min_stream_id, "target_min_stream_id_inclusive": target_min_stream_id,

View File

@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} } # { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {}) ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users: if isinstance(ignored_users, dict) and ignored_users:
cur.executemany(insert_sql, [(user_id, u) for u in ignored_users]) cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
# Add indexes after inserting data for efficiency. # Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table") logger.info("Adding constraints to ignored_users table")

View File

@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries for entry in entries
) )
txn.executemany(sql, args) txn.execute_batch(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
sql = ( sql = (
@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries for entry in entries
) )
txn.executemany(sql, args) txn.execute_batch(sql, args)
else: else:
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")

View File

@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
) )
logger.info("[purge] removing redundant state groups") logger.info("[purge] removing redundant state groups")
txn.executemany( txn.execute_batch(
"DELETE FROM state_groups_state WHERE state_group = ?", "DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete), ((sg,) for sg in state_groups_to_delete),
) )
txn.executemany( txn.execute_batch(
"DELETE FROM state_groups WHERE id = ?", "DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete), ((sg,) for sg in state_groups_to_delete),
) )

View File

@ -15,12 +15,11 @@
import heapq import heapq
import logging import logging
import threading import threading
from collections import deque from collections import OrderedDict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union from typing import Dict, List, Optional, Set, Tuple, Union
import attr import attr
from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import DatabasePool, LoggingTransaction
@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)( self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step) self._current, _load_current_id(db_conn, table, column, step)
) )
self._unfinished_ids = deque() # type: Deque[int]
# We use this as an ordered set, as we want to efficiently append items,
# remove items and get the first item. Since we insert IDs in order, the
# insertion ordering will ensure it's in the correct order.
#
# The key and values are the same, but we never look at the values.
self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self): def get_next(self):
""" """
@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step self._current += self._step
next_id = self._current next_id = self._current
self._unfinished_ids.append(next_id) self._unfinished_ids[next_id] = next_id
@contextmanager @contextmanager
def manager(): def manager():
@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id yield next_id
finally: finally:
with self._lock: with self._lock:
self._unfinished_ids.remove(next_id) self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager()) return _AsyncCtxManagerWrapper(manager())
@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step self._current += n * self._step
for next_id in next_ids: for next_id in next_ids:
self._unfinished_ids.append(next_id) self._unfinished_ids[next_id] = next_id
@contextmanager @contextmanager
def manager(): def manager():
@ -149,7 +154,7 @@ class StreamIdGenerator:
finally: finally:
with self._lock: with self._lock:
for next_id in next_ids: for next_id in next_ids:
self._unfinished_ids.remove(next_id) self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager()) return _AsyncCtxManagerWrapper(manager())
@ -162,7 +167,7 @@ class StreamIdGenerator:
""" """
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:
return self._unfinished_ids[0] - self._step return next(iter(self._unfinished_ids)) - self._step
return self._current return self._current
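
The OrderedDict-as-ordered-set trick is worth seeing on its own: appends, removal of an arbitrary finished ID, and peeking at the oldest in-flight ID are all O(1), whereas deque.remove was O(n):

    from collections import OrderedDict

    unfinished = OrderedDict()  # keys are the in-flight stream IDs; values are ignored

    for next_id in (101, 102, 103):   # IDs are handed out in increasing order,
        unfinished[next_id] = next_id  # so insertion order is already sorted order

    unfinished.pop(102)  # a later ID completes first: O(1) removal by key

    oldest = next(iter(unfinished))
    print(oldest)  # 101: the current token sits one step before the oldest in-flight ID
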

View File

@ -69,6 +69,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""Gets the next ID in the sequence""" """Gets the next ID in the sequence"""
... ...
@abc.abstractmethod
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
"""Get the next `n` IDs in the sequence"""
...
@abc.abstractmethod @abc.abstractmethod
def check_consistency( def check_consistency(
self, self,
@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1 self._current_max_id += 1
return self._current_max_id return self._current_max_id
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
with self._lock:
if self._current_max_id is None:
assert self._callback is not None
self._current_max_id = self._callback(txn)
self._callback = None
first_id = self._current_max_id + 1
self._current_max_id += n
return [first_id + i for i in range(n)]
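
A toy version of the new multi-ID path, reserving a contiguous block under the lock and only consulting the database callback once (all names below are illustrative, not Synapse's real API):

    import threading
    from typing import Callable, List, Optional


    class LocalSeqSketch:
        """Toy version of get_next_mult_txn: reserve n contiguous IDs at once."""

        def __init__(self, load_current_max: Callable[[], int]):
            self._lock = threading.Lock()
            self._load = load_current_max
            self._current_max: Optional[int] = None

        def get_next_mult(self, n: int) -> List[int]:
            with self._lock:
                if self._current_max is None:
                    self._current_max = self._load()  # hit the DB only on first use
                first = self._current_max + 1
                self._current_max += n
                return [first + i for i in range(n)]


    seq = LocalSeqSketch(lambda: 7)
    print(seq.get_next_mult(3))  # [8, 9, 10]
    print(seq.get_next_mult(2))  # [11, 12]
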
def check_consistency( def check_consistency(
self, self,
db_conn: Connection, db_conn: Connection,

View File

@ -0,0 +1,229 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Optional, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.push.presentable_names import calculate_room_name
from synapse.types import StateKey, StateMap
from tests import unittest
class MockDataStore:
"""
A fake data store which stores a mapping of state key to event content.
(I.e. the state key is used as the event ID.)
"""
def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
"""
Args:
events: A state map to event contents.
"""
self._events = {}
for i, (event_id, content) in enumerate(events):
self._events[event_id] = FrozenEvent(
{
"event_id": "$event_id",
"type": event_id[0],
"sender": "@user:test",
"state_key": event_id[1],
"room_id": "#room:test",
"content": content,
"origin_server_ts": i,
},
RoomVersions.V1,
)
async def get_event(
self, event_id: StateKey, allow_none: bool = False
) -> Optional[FrozenEvent]:
assert allow_none, "Mock not configured for allow_none = False"
return self._events.get(event_id)
async def get_events(self, event_ids: Iterable[StateKey]):
# This is cheating since it just returns all events.
return self._events
class PresentableNamesTestCase(unittest.HomeserverTestCase):
USER_ID = "@test:test"
OTHER_USER_ID = "@user:test"
def _calculate_room_name(
self,
events: StateMap[dict],
user_id: str = "",
fallback_to_members: bool = True,
fallback_to_single_member: bool = True,
):
# This isn't 100% accurate, but works with MockDataStore.
room_state_ids = {k[0]: k[0] for k in events}
return self.get_success(
calculate_room_name(
MockDataStore(events),
room_state_ids,
user_id or self.USER_ID,
fallback_to_members,
fallback_to_single_member,
)
)
def test_name(self):
"""A room name event should be used."""
events = [
((EventTypes.Name, ""), {"name": "test-name"}),
]
self.assertEqual("test-name", self._calculate_room_name(events))
# Check if the event content has garbage.
events = [((EventTypes.Name, ""), {"foo": 1})]
self.assertEqual("Empty Room", self._calculate_room_name(events))
events = [((EventTypes.Name, ""), {"name": 1})]
self.assertEqual(1, self._calculate_room_name(events))
def test_canonical_alias(self):
"""An canonical alias should be used."""
events = [
((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
]
self.assertEqual("#test-name:test", self._calculate_room_name(events))
# Check if the event content has garbage.
events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
self.assertEqual("Empty Room", self._calculate_room_name(events))
events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
self.assertEqual("Empty Room", self._calculate_room_name(events))
def test_invite(self):
"""An invite has special behaviour."""
events = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
]
self.assertEqual("Invite from Other User", self._calculate_room_name(events))
self.assertIsNone(
self._calculate_room_name(events, fallback_to_single_member=False)
)
# Ensure this logic is skipped if we don't fallback to members.
self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
# Check if the event content has garbage.
events = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
]
self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
# No member event for sender.
events = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
]
self.assertEqual("Room Invite", self._calculate_room_name(events))
def test_no_members(self):
"""Behaviour of an empty room."""
events = []
self.assertEqual("Empty Room", self._calculate_room_name(events))
# Note that events with invalid (or missing) membership are ignored.
events = [
((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
]
self.assertEqual("Empty Room", self._calculate_room_name(events))
def test_no_other_members(self):
"""Behaviour of a room with no other members in it."""
events = [
(
(EventTypes.Member, self.USER_ID),
{"membership": Membership.JOIN, "displayname": "Me"},
),
]
self.assertEqual("Me", self._calculate_room_name(events))
# Check if the event content has no displayname.
events = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
]
self.assertEqual("@test:test", self._calculate_room_name(events))
# 3pid invite, use the other user (who is set as the sender).
events = [
((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
]
self.assertEqual(
"nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
)
events = [
((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
]
self.assertEqual(
"Inviting email address",
self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
)
def test_one_other_member(self):
"""Behaviour of a room with a single other member."""
events = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
(
(EventTypes.Member, self.OTHER_USER_ID),
{"membership": Membership.JOIN, "displayname": "Other User"},
),
]
self.assertEqual("Other User", self._calculate_room_name(events))
self.assertIsNone(
self._calculate_room_name(events, fallback_to_single_member=False)
)
# Check the behaviour when the other member's event has no displayname and is an invite.
events = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
(
(EventTypes.Member, self.OTHER_USER_ID),
{"membership": Membership.INVITE},
),
]
self.assertEqual("@user:test", self._calculate_room_name(events))
def test_other_members(self):
"""Behaviour of a room with multiple other members."""
# Two other members.
events = [
((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
(
(EventTypes.Member, self.OTHER_USER_ID),
{"membership": Membership.JOIN, "displayname": "Other User"},
),
((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
]
self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
# Three or more other members.
events.append(
((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
)
self.assertEqual("Other User and 2 others", self._calculate_room_name(events))

View File

@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"type": "m.room.history_visibility", "type": "m.room.history_visibility",
"sender": "@user:test", "sender": "@user:test",
"state_key": "", "state_key": "",
"room_id": "@room:test", "room_id": "#room:test",
"content": content, "content": content,
}, },
RoomVersions.V1, RoomVersions.V1,

View File

@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to. # Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer() self._redis_server = FakeRedisPubSubServer()
# We may have an attempt to connect to redis for the external cache already.
self.connect_any_redis_attempts()
store = self.hs.get_datastore() store = self.hs.get_datastore()
self.database_pool = store.db_pool self.database_pool = store.db_pool
@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one. fake one.
""" """
clients = self.reactor.tcpClients clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1) while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost") self.assertEqual(host, "localhost")
self.assertEqual(port, 6379) self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None) server_protocol = self._redis_server.buildProtocol(None)
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol server_protocol, self.reactor, client_protocol
) )
client_protocol.makeConnection(client_to_server_transport) client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport( server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol client_protocol, self.reactor, server_protocol
) )
server_protocol.makeConnection(server_to_client_transport) server_protocol.makeConnection(server_to_client_transport)
return client_to_server_transport, server_to_client_transport
class TestReplicationDataHandler(GenericWorkerReplicationHandler): class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args (channel,) = args
self._server.add_subscriber(self) self._server.add_subscriber(self)
self.send(["subscribe", channel, 1]) self.send(["subscribe", channel, 1])
# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
self.send("OK")
elif command == b"GET":
self.send(None)
else: else:
raise Exception("Unknown command") raise Exception("Unknown command")
@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings. # We assume bytes are just unicode strings.
obj = obj.decode("utf-8") obj = obj.decode("utf-8")
if obj is None:
return "$-1\r\n"
if isinstance(obj, str): if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int): if isinstance(obj, int):
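The `SET`/`GET` no-ops and the `None` handling above rely on the Redis RESP wire format. As a rough illustration only (not Synapse code), a bulk-string encoder consistent with the fake server's behaviour might look like this; the function name is hypothetical:

def encode_resp(obj):
    """Minimal sketch of RESP encoding as used by the fake Redis server above.

    None becomes a null bulk string (a GET miss), strings become bulk strings,
    and integers become RESP integers.
    """
    if obj is None:
        return b"$-1\r\n"  # null bulk string
    if isinstance(obj, str):
        data = obj.encode("utf-8")
        return b"$%d\r\n%s\r\n" % (len(data), data)
    if isinstance(obj, int):
        return b":%d\r\n" % obj
    raise TypeError("unsupported type: %r" % (obj,))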

View File

@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync from synapse.rest.client.v2_alpha import devices, sync
from synapse.types import JsonDict
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
self.user1 = self.register_user(
"user1", "pass1", admin=False, displayname="Name 1"
)
self.user2 = self.register_user(
"user2", "pass2", admin=False, displayname="Name 2"
)
def test_no_auth(self): def test_no_auth(self):
""" """
Try to list users without authentication. Try to list users without authentication.
@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
self._create_users(1)
other_user_token = self.login("user1", "pass1") other_user_token = self.login("user1", "pass1")
channel = self.make_request("GET", self.url, access_token=other_user_token) channel = self.make_request("GET", self.url, access_token=other_user_token)
@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
""" """
List all users, including deactivated users. List all users, including deactivated users.
""" """
self._create_users(2)
channel = self.make_request( channel = self.make_request(
"GET", "GET",
self.url + "?deactivated=true", self.url + "?deactivated=true",
@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, channel.json_body["total"]) self.assertEqual(3, channel.json_body["total"])
# Check that all fields are available # Check that all fields are available
for u in channel.json_body["users"]: self._check_fields(channel.json_body["users"])
self.assertIn("name", u)
self.assertIn("is_guest", u)
self.assertIn("admin", u)
self.assertIn("user_type", u)
self.assertIn("deactivated", u)
self.assertIn("displayname", u)
self.assertIn("avatar_url", u)
def test_search_term(self): def test_search_term(self):
"""Test that searching for a users works correctly""" """Test that searching for a users works correctly"""
@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Check that users were returned # Check that users were returned
self.assertTrue("users" in channel.json_body) self.assertTrue("users" in channel.json_body)
self._check_fields(channel.json_body["users"])
users = channel.json_body["users"] users = channel.json_body["users"]
# Check that the expected number of users were returned # Check that the expected number of users were returned
@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase):
u = users[0] u = users[0]
self.assertEqual(expected_user_id, u["name"]) self.assertEqual(expected_user_id, u["name"])
self._create_users(2)
user1 = "@user1:test"
user2 = "@user2:test"
# Perform search tests # Perform search tests
_search_test(self.user1, "er1") _search_test(user1, "er1")
_search_test(self.user1, "me 1") _search_test(user1, "me 1")
_search_test(self.user2, "er2") _search_test(user2, "er2")
_search_test(self.user2, "me 2") _search_test(user2, "me 2")
_search_test(self.user1, "er1", "user_id") _search_test(user1, "er1", "user_id")
_search_test(self.user2, "er2", "user_id") _search_test(user2, "er2", "user_id")
# Test case insensitive # Test case insensitive
_search_test(self.user1, "ER1") _search_test(user1, "ER1")
_search_test(self.user1, "NAME 1") _search_test(user1, "NAME 1")
_search_test(self.user2, "ER2") _search_test(user2, "ER2")
_search_test(self.user2, "NAME 2") _search_test(user2, "NAME 2")
_search_test(self.user1, "ER1", "user_id") _search_test(user1, "ER1", "user_id")
_search_test(self.user2, "ER2", "user_id") _search_test(user2, "ER2", "user_id")
_search_test(None, "foo") _search_test(None, "foo")
_search_test(None, "bar") _search_test(None, "bar")
@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id") _search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id") _search_test(None, "bar", "user_id")
def test_invalid_parameter(self):
"""
If parameters are invalid, an error is returned.
"""
# negative limit
channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
channel = self.make_request(
"GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid deactivated
channel = self.make_request(
"GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def test_limit(self):
"""
Test listing users with a limit.
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
self._check_fields(channel.json_body["users"])
def test_from(self):
"""
Test listing users from a defined starting point (`from`).
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["users"])
def test_limit_and_from(self):
"""
Test listing users with both a starting point (`from`) and a limit.
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
self._check_fields(channel.json_body["users"])
def test_next_token(self):
"""
Test that `next_token` is only returned while more entries remain.
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
# `next_token` does not appear
# Number of results is the number of entries
channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
# `next_token` does not appear
# Number of max results is larger than the number of entries
channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
# `next_token` does appear
# Number of max results is smaller than the number of entries
channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
# Set `from` to the value of `next_token` to request the remaining entries
# `next_token` does not appear
channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
content: List of user dictionaries to check
"""
for u in content:
self.assertIn("name", u)
self.assertIn("is_guest", u)
self.assertIn("admin", u)
self.assertIn("user_type", u)
self.assertIn("deactivated", u)
self.assertIn("displayname", u)
self.assertIn("avatar_url", u)
def _create_users(self, number_users: int):
"""
Create a number of users
Args:
number_users: Number of users to be created
"""
for i in range(1, number_users + 1):
self.register_user(
"user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
)
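The pagination tests above pin down the `from`/`limit`/`next_token` contract of the list-users Admin API. As a hedged sketch of how a client could walk every page, assuming the `/_synapse/admin/v2/users` endpoint exercised by `UsersListTestCase` and using the third-party `requests` library (not part of these tests); parameter and response field names are taken from the tests:

import requests

def iter_all_users(base_url, admin_token):
    """Sketch: page through the users Admin API using from/limit/next_token."""
    headers = {"Authorization": "Bearer %s" % admin_token}
    params = {"limit": 10}
    while True:
        resp = requests.get(
            base_url + "/_synapse/admin/v2/users", headers=headers, params=params
        )
        resp.raise_for_status()
        body = resp.json()
        yield from body["users"]
        # `next_token` is omitted on the last page, per test_next_token.
        if "next_token" not in body:
            break
        params["from"] = body["next_token"]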
class DeactivateAccountTestCase(unittest.HomeserverTestCase): class DeactivateAccountTestCase(unittest.HomeserverTestCase):
@ -2211,3 +2380,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body) self.assertIn("devices", channel.json_body)
class ShadowBanRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
self.other_user
)
def test_no_auth(self):
"""
Try to shadow-ban a user without authentication.
"""
channel = self.make_request("POST", self.url)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user", "pass")
channel = self.make_request("POST", self.url, access_token=other_user_token)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
Tests that shadow-banning a user who is not local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
channel = self.make_request("POST", url, access_token=self.admin_user_tok)
self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self):
"""
Shadow-banning should succeed for an admin.
"""
# The user starts off as not shadow-banned.
other_user_token = self.login("user", "pass")
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
self.assertTrue(result.shadow_banned)
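For reference, the endpoint this class exercises can be called the same way outside the test harness. A minimal sketch, using the URL built in `prepare` above and the third-party `requests` library (an assumption, not part of these tests):

import urllib.parse
import requests

def shadow_ban(base_url, admin_token, user_id):
    """Sketch: shadow-ban a local user via the Admin API used in this test."""
    url = "%s/_synapse/admin/v1/users/%s/shadow_ban" % (
        base_url,
        urllib.parse.quote(user_id),
    )
    resp = requests.post(url, headers={"Authorization": "Bearer %s" % admin_token})
    resp.raise_for_status()
    return resp.json()  # {} on success, as asserted in test_success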

View File

@ -18,6 +18,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
from synapse.types import UserID
from tests import unittest from tests import unittest
@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.get_success( self.get_success(
self.store.db_pool.simple_update( self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
table="users",
keyvalues={"name": self.banned_user_id},
updatevalues={"shadow_banned": True},
desc="shadow_ban",
)
) )
self.other_user_id = self.register_user("otheruser", "pass") self.other_user_id = self.register_user("otheruser", "pass")

View File

@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config = self.default_config() config = self.default_config()
config["media_store_path"] = self.media_store_path config["media_store_path"] = self.media_store_path
config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000 config["max_image_pixels"] = 2000000
provider_config = { provider_config = {
@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self): def test_thumbnail_crop(self):
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail( self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found "crop", self.test_image.expected_cropped, self.test_image.expected_found
) )
def test_thumbnail_scale(self): def test_thumbnail_scale(self):
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail( self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found "scale", self.test_image.expected_scaled, self.test_image.expected_found
) )
def test_invalid_type(self):
"""An invalid thumbnail type is never available."""
self._test_thumbnail("invalid", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
)
def test_no_thumbnail_crop(self):
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
self._test_thumbnail("crop", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
)
def test_no_thumbnail_scale(self):
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail("scale", None, False)
def _test_thumbnail(self, method, expected_body, expected_found): def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
channel = make_request( channel = make_request(

View File

@ -261,3 +261,32 @@ class PreviewUrlTestCase(unittest.TestCase):
html = "" html = ""
og = decode_and_calc_og(html, "http://example.com/test.html") og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {}) self.assertEqual(og, {})
def test_invalid_encoding(self):
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(
html, "http://example.com/test.html", "invalid-encoding"
)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self):
"""A body which doesn't match the sent character encoding."""
# Note that this contains an invalid UTF-8 sequence in the title.
html = b"""
<html>
<head><title>\xff\xff Foo</title></head>
<body>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})

18
tox.ini
View File

@ -18,13 +18,16 @@ deps =
# installed on that). # installed on that).
# #
# anyway, make sure that we have a recent enough setuptools. # anyway, make sure that we have a recent enough setuptools.
setuptools>=18.5 setuptools>=18.5 ; python_version >= '3.6'
setuptools>=18.5,<51.0.0 ; python_version < '3.6'
# we also need a semi-recent version of pip, because old ones fail to # we also need a semi-recent version of pip, because old ones fail to
# install the "enum34" dependency of cryptography. # install the "enum34" dependency of cryptography.
pip>=10 pip>=10 ; python_version >= '3.6'
pip>=10,<21.0 ; python_version < '3.6'
# directories/files we run the linters on # directories/files we run the linters on.
# if you update this list, make sure to do the same in scripts-dev/lint.sh
lint_targets = lint_targets =
setup.py setup.py
synapse synapse
@ -103,15 +106,10 @@ usedevelop=true
[testenv:py35-old] [testenv:py35-old]
skip_install=True skip_install=True
deps = deps =
# Ensure a version of setuptools that supports Python 3.5 is installed.
setuptools < 51.0.0
# Old automat version for Twisted # Old automat version for Twisted
Automat == 0.3.0 Automat == 0.3.0
lxml lxml
coverage {[base]deps}
coverage-enable-subprocess==1.0
commands = commands =
# Make all greater-thans equals so we test the oldest version of our direct # Make all greater-thans equals so we test the oldest version of our direct
@ -168,6 +166,8 @@ commands = {toxinidir}/scripts-dev/generate_sample_config --check
skip_install = True skip_install = True
deps = deps =
coverage coverage
pip>=10 ; python_version >= '3.6'
pip>=10,<21.0 ; python_version < '3.6'
commands= commands=
coverage combine coverage combine
coverage report coverage report