Merge remote-tracking branch 'origin/develop' into anoa/public_rooms_module_api

anoa/public_rooms_module_api
Mathieu Velten 2023-05-19 17:35:54 +02:00
commit 75f9e56c77
96 changed files with 1545 additions and 550 deletions

View File

@ -314,8 +314,9 @@ jobs:
# There aren't wheels for some of the older deps, so we need to install
# their build dependencies
- run: |
sudo apt update
sudo apt-get -qq install build-essential libffi-dev python-dev \
libxml2-dev libxslt-dev xmlsec1 zlib1g-dev libjpeg-dev libwebp-dev
- uses: actions/setup-python@v4
with:

1
changelog.d/15464.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where setting the read marker could fail when using message retention. Contributed by Nick @ Beeper (@fizzadar).

1
changelog.d/15537.misc Normal file
View File

@ -0,0 +1 @@
Add not null constraint to column full_user_id of tables profiles and user_filters.

1
changelog.d/15599.bugfix Normal file
View File

@ -0,0 +1 @@
Print full error and stack-trace of any exception that occurs during startup/initialization.

1
changelog.d/15601.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where the `url_preview_url_blacklist` configuration setting was not applied to oEmbed or image URLs found while previewing a URL.

1
changelog.d/15602.misc Normal file
View File

@ -0,0 +1 @@
Run mypy type checking with the minimum supported Python version to catch new usage that isn't backwards-compatible.

1
changelog.d/15604.misc Normal file
View File

@ -0,0 +1 @@
Fix subscriptable type usage in Python <3.9.

1
changelog.d/15606.misc Normal file
View File

@ -0,0 +1 @@
Update internal terminology.

View File

@ -0,0 +1 @@
Add a new admin API to create a new device for a user.

1
changelog.d/15613.doc Normal file
View File

@ -0,0 +1 @@
Warn users that at least 3.75GB of space is needed for the nix Synapse development environment.

1
changelog.d/15614.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.82.0 where the error message displayed when validation of the `app_service_config_files` config option fails would be incorrectly formatted.

1
changelog.d/15615.misc Normal file
View File

@ -0,0 +1 @@
Re-type config paths in `ConfigError`s to be `StrSequence`s instead of `Iterable[str]`s.

1
changelog.d/15620.misc Normal file
View File

@ -0,0 +1 @@
Update internal terminology.

1
changelog.d/15621.misc Normal file
View File

@ -0,0 +1 @@
Update Mutual Rooms (MSC2666) implementation to match new proposal text.

1
changelog.d/15624.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where deactivated users were still able to login using the custom `org.matrix.login.jwt` login type (if enabled).

1
changelog.d/15625.misc Normal file
View File

@ -0,0 +1 @@
Remove the unstable identifiers from faster joins ([MSC3706](https://github.com/matrix-org/matrix-spec-proposals/pull/3706)).

1
changelog.d/15626.misc Normal file
View File

@ -0,0 +1 @@
Fix the olddeps CI.

1
changelog.d/15630.misc Normal file
View File

@ -0,0 +1 @@
Fix two memory leaks in `trial` test runs.

View File

@ -813,6 +813,33 @@ The following fields are returned in the JSON response body:
- `total` - Total number of user's devices.
### Create a device
Creates a new device for a specific `user_id` and `device_id`. Does nothing if the `device_id`
exists already.
The API is:
```
POST /_synapse/admin/v2/users/<user_id>/devices
{
"device_id": "QBUAZIFURK"
}
```
An empty JSON dict is returned.
**Parameters**
The following parameters should be set in the URL:
- `user_id` - fully qualified: for example, `@user:server.com`.
The following fields are required in the JSON request body:
- `device_id` - The device ID to create.
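As an illustrative sketch (not part of the upstream docs), assuming a homeserver reachable at `http://localhost:8008` and a valid admin access token, the new endpoint could be exercised from Python with the `requests` package:

```python
import requests

# Hypothetical values: substitute your homeserver URL, admin token and user ID.
resp = requests.post(
    "http://localhost:8008/_synapse/admin/v2/users/@user:server.com/devices",
    headers={"Authorization": "Bearer <admin_access_token>"},
    json={"device_id": "QBUAZIFURK"},
)
resp.raise_for_status()  # the servlet responds 201 Created with an empty JSON dict
print(resp.json())       # {}
```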
### Delete multiple devices
Deletes the given devices for a specific `user_id`, and invalidates
any access token associated with them.

View File

@ -30,12 +30,6 @@ minimal.
See [the TCP replication documentation](tcp_replication.md).
### The Slaved DataStore
There are read-only versions of the Synapse storage layer in
`synapse/replication/slave/storage` that use the response of the
replication API to invalidate their caches.
### The TCP Replication Module
Information about how the tcp replication module is structured, including how
the classes interact, can be found in

View File

@ -1,35 +1,30 @@
# A nix flake that sets up a complete Synapse development environment. Dependencies
# A Nix flake that sets up a complete Synapse development environment. Dependencies
# for the SyTest (https://github.com/matrix-org/sytest) and Complement
# (https://github.com/matrix-org/complement) Matrix homeserver test suites are also
# installed automatically.
#
# You must have already installed nix (https://nixos.org) on your system to use this.
# nix can be installed on Linux or MacOS; NixOS is not required. Windows is not
# directly supported, but nix can be installed inside of WSL2 or even Docker
# You must have already installed Nix (https://nixos.org) on your system to use this.
# Nix can be installed on Linux or MacOS; NixOS is not required. Windows is not
# directly supported, but Nix can be installed inside of WSL2 or even Docker
# containers. Please refer to https://nixos.org/download for details.
#
# You must also enable support for flakes in Nix. See the following for how to
# do so permanently: https://nixos.wiki/wiki/Flakes#Enable_flakes
#
# Be warned: you'll need over 3.75 GB of free space to download all the dependencies.
#
# Usage:
#
# With nix installed, navigate to the directory containing this flake and run
# With Nix installed, navigate to the directory containing this flake and run
# `nix develop --impure`. The `--impure` is necessary in order to store state
# locally from "services", such as PostgreSQL and Redis.
#
# You should now be dropped into a new shell with all programs and dependencies
# available to you!
#
# You can start up pre-configured, local PostgreSQL and Redis instances by
# You can start up pre-configured local Synapse, PostgreSQL and Redis instances by
# running: `devenv up`. To stop them, use Ctrl-C.
#
# A PostgreSQL database called 'synapse' will be set up for you, along with
# a PostgreSQL user named 'synapse_user'.
# The 'host' can be found by running `echo $PGHOST` with the development
# shell activated. Use these values to configure your Synapse to connect
# to the local PostgreSQL database. You do not need to specify a password.
# https://matrix-org.github.io/synapse/latest/postgres
#
# All state (the venv, postgres and redis data and config) are stored in
# .devenv/state. Deleting a file from here and then re-entering the shell
# will recreate these files from scratch.
@ -66,7 +61,7 @@
let
pkgs = nixpkgs.legacyPackages.${system};
in {
# Everything is configured via devenv - a nix module for creating declarative
# Everything is configured via devenv - a Nix module for creating declarative
# developer environments. See https://devenv.sh/reference/options/ for a list
# of all possible options.
default = devenv.lib.mkShell {
@ -153,11 +148,39 @@
# Redis is needed in order to run Synapse in worker mode.
services.redis.enable = true;
# Configure and start Synapse. Before starting Synapse, this shell code:
# * generates a default homeserver.yaml config file if one does not exist, and
# * ensures a directory containing two additional homeserver config files exists;
# one to configure using the development environment's PostgreSQL as the
# database backend and another for enabling Redis support.
process.before = ''
python -m synapse.app.homeserver -c homeserver.yaml --generate-config --server-name=synapse.dev --report-stats=no
mkdir -p homeserver-config-overrides.d
cat > homeserver-config-overrides.d/database.yaml << EOF
## Do not edit this file. This file is generated by flake.nix
database:
name: psycopg2
args:
user: synapse_user
database: synapse
host: $PGHOST
cp_min: 5
cp_max: 10
EOF
cat > homeserver-config-overrides.d/redis.yaml << EOF
## Do not edit this file. This file is generated by flake.nix
redis:
enabled: true
EOF
'';
# Start synapse when `devenv up` is run.
processes.synapse.exec = "poetry run python -m synapse.app.homeserver -c homeserver.yaml --config-directory homeserver-config-overrides.d";
# Define the perl modules we require to run SyTest.
#
# This list was compiled by cross-referencing https://metacpan.org/
# with the modules defined in './cpanfile' and then finding the
# corresponding nix packages on https://search.nixos.org/packages.
# corresponding Nix packages on https://search.nixos.org/packages.
#
# This was done until `./install-deps.pl --dryrun` produced no output.
env.PERL5LIB = "${with pkgs.perl536Packages; makePerlPath [

View File

@ -13,6 +13,9 @@ no_implicit_optional = True
disallow_untyped_defs = True
strict_equality = True
warn_redundant_casts = True
# Run mypy type checking with the minimum supported Python version to catch new usage
# that isn't backwards-compatible (types, overloads, etc).
python_version = 3.8
files =
docker/,
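For context on this hunk, a hedged sketch (not from the diff) of the kind of regression this catches; together with the related "Fix subscriptable type usage in Python <3.9" change, the point is that subscripting built-in types is only legal from Python 3.9:

```python
# Under python_version = 3.8, mypy rejects the builtin-generic spelling, which
# would also raise on a 3.8 interpreter when the annotation is evaluated:
#     def head(xs: list[int]) -> int: ...  # TypeError: 'type' object is not subscriptable
from typing import List

def head(xs: List[int]) -> int:  # 3.8-compatible spelling via typing aliases
    return xs[0]
```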

View File

@ -214,7 +214,7 @@ def handle_startup_exception(e: Exception) -> NoReturn:
# the reactor are written to the logs, followed by a summary to stderr.
logger.exception("Exception during startup")
error_string = "".join(traceback.format_exception(e))
error_string = "".join(traceback.format_exception(type(e), e, e.__traceback__))
indented_error_string = indent(error_string, " ")
quit_with_error(
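For background (a sketch assuming only the standard library): the single-argument call `traceback.format_exception(e)` only exists from Python 3.10, so on the minimum supported Python 3.8 the explicit three-argument form used above is required:

```python
import traceback

try:
    raise ValueError("boom")
except ValueError as e:
    # Portable to Python 3.8: pass type, value and traceback explicitly.
    error_string = "".join(traceback.format_exception(type(e), e, e.__traceback__))
    print(error_string)
```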

View File

@ -64,7 +64,7 @@ from synapse.util.logcontext import LoggingContext
logger = logging.getLogger("synapse.app.admin_cmd")
class AdminCmdSlavedStore(
class AdminCmdStore(
FilteringWorkerStore,
ClientIpWorkerStore,
DeviceWorkerStore,
@ -103,7 +103,7 @@ class AdminCmdSlavedStore(
class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdSlavedStore # type: ignore
DATASTORE_CLASS = AdminCmdStore # type: ignore
async def export_data_command(hs: HomeServer, args: argparse.Namespace) -> None:

View File

@ -102,7 +102,7 @@ from synapse.util.httpresourcetree import create_resource_tree
logger = logging.getLogger("synapse.app.generic_worker")
class GenericWorkerSlavedStore(
class GenericWorkerStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
@ -154,7 +154,7 @@ class GenericWorkerSlavedStore(
class GenericWorkerServer(HomeServer):
DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore
DATASTORE_CLASS = GenericWorkerStore # type: ignore
def _listen_http(self, listener_config: ListenerConfig) -> None:
assert listener_config.http_options is not None

View File

@ -44,6 +44,7 @@ import jinja2
import pkg_resources
import yaml
from synapse.types import StrSequence
from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter
logger = logging.getLogger(__name__)
@ -58,7 +59,7 @@ class ConfigError(Exception):
the problem lies.
"""
def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
def __init__(self, msg: str, path: Optional[StrSequence] = None):
self.msg = msg
self.path = path
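A sketch of why `Iterable[str]` was a hazardous type for config paths: a bare `str` is itself an `Iterable[str]`, so passing `"app_service_config_files"` instead of `("app_service_config_files",)` type-checked fine but produced a per-character path (compare the appservice fix later in this diff):

```python
from typing import Iterable

def render_path(path: Iterable[str]) -> str:
    return ".".join(path)

render_path("app_service_config_files")     # "a.p.p._.s..." - accepted by mypy, clearly wrong
render_path(("app_service_config_files",))  # "app_service_config_files" - intended usage
```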

View File

@ -61,9 +61,10 @@ from synapse.config import ( # noqa: F401
voip,
workers,
)
from synapse.types import StrSequence
class ConfigError(Exception):
def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
def __init__(self, msg: str, path: Optional[StrSequence] = None):
self.msg = msg
self.path = path

View File

@ -11,17 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, Type, TypeVar
from typing import Any, Dict, Type, TypeVar
import jsonschema
from pydantic import BaseModel, ValidationError, parse_obj_as
from synapse.config._base import ConfigError
from synapse.types import JsonDict
from synapse.types import JsonDict, StrSequence
def validate_config(
json_schema: JsonDict, config: Any, config_path: Iterable[str]
json_schema: JsonDict, config: Any, config_path: StrSequence
) -> None:
"""Validates a config setting against a JsonSchema definition
@ -45,7 +45,7 @@ def validate_config(
def json_error_to_config_error(
e: jsonschema.ValidationError, config_path: Iterable[str]
e: jsonschema.ValidationError, config_path: StrSequence
) -> ConfigError:
"""Converts a json validation error to a user-readable ConfigError

View File

@ -36,11 +36,10 @@ class AppServiceConfig(Config):
if not isinstance(self.app_service_config_files, list) or not all(
type(x) is str for x in self.app_service_config_files
):
# type-ignore: this function gets arbitrary json value; we do use this path.
raise ConfigError(
"Expected '%s' to be a list of AS config files:"
% (self.app_service_config_files),
"app_service_config_files",
("app_service_config_files",),
)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)

View File

@ -84,18 +84,6 @@ class ExperimentalConfig(Config):
"msc3984_appservice_key_query", False
)
# MSC3706 (server-side support for partial state in /send_join responses)
# Synapse will always serve partial state responses to requests using the stable
# query parameter `omit_members`. If this flag is set, Synapse will also serve
# partial state responses to requests using the unstable query parameter
# `org.matrix.msc3706.partial_state`.
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
# experimental support for faster joins over federation
# (MSC2775, MSC3706, MSC3895)
# requires a target server that can provide a partial join response (MSC3706)
self.faster_joins_enabled: bool = experimental.get("faster_joins", True)
# MSC3720 (Account status endpoint)
self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)

View File

@ -19,7 +19,7 @@ from urllib import parse as urlparse
import attr
import pkg_resources
from synapse.types import JsonDict
from synapse.types import JsonDict, StrSequence
from ._base import Config, ConfigError
from ._util import validate_config
@ -80,7 +80,7 @@ class OembedConfig(Config):
)
def _parse_and_validate_provider(
self, providers: List[JsonDict], config_path: Iterable[str]
self, providers: List[JsonDict], config_path: StrSequence
) -> Iterable[OEmbedEndpointConfig]:
# Ensure it is the proper form.
validate_config(
@ -112,7 +112,7 @@ class OembedConfig(Config):
api_endpoint, patterns, endpoint.get("formats")
)
def _glob_to_pattern(self, glob: str, config_path: Iterable[str]) -> Pattern:
def _glob_to_pattern(self, glob: str, config_path: StrSequence) -> Pattern:
"""
Convert the glob into a sane regular expression to match against. The
rules followed will be slightly different for the domain portion vs.

View File

@ -224,20 +224,20 @@ class ContentRepositoryConfig(Config):
if "http" in proxy_env or "https" in proxy_env:
logger.warning("".join(HTTP_PROXY_SET_WARNING))
# we always blacklist '0.0.0.0' and '::', which are supposed to be
# we always block '0.0.0.0' and '::', which are supposed to be
# unroutable addresses.
self.url_preview_ip_range_blacklist = generate_ip_set(
self.url_preview_ip_range_blocklist = generate_ip_set(
config["url_preview_ip_range_blacklist"],
["0.0.0.0", "::"],
config_path=("url_preview_ip_range_blacklist",),
)
self.url_preview_ip_range_whitelist = generate_ip_set(
self.url_preview_ip_range_allowlist = generate_ip_set(
config.get("url_preview_ip_range_whitelist", ()),
config_path=("url_preview_ip_range_whitelist",),
)
self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
self.url_preview_url_blocklist = config.get("url_preview_url_blacklist", ())
self.url_preview_accept_language = config.get(
"url_preview_accept_language"

View File

@ -27,7 +27,7 @@ from netaddr import AddrFormatError, IPNetwork, IPSet
from twisted.conch.ssh.keys import Key
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.types import JsonDict
from synapse.types import JsonDict, StrSequence
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_server_name
@ -73,7 +73,7 @@ def _6to4(network: IPNetwork) -> IPNetwork:
def generate_ip_set(
ip_addresses: Optional[Iterable[str]],
extra_addresses: Optional[Iterable[str]] = None,
config_path: Optional[Iterable[str]] = None,
config_path: Optional[StrSequence] = None,
) -> IPSet:
"""
Generate an IPSet from a list of IP addresses or CIDRs.
@ -115,7 +115,7 @@ def generate_ip_set(
# IP ranges that are considered private / unroutable / don't make sense.
DEFAULT_IP_RANGE_BLACKLIST = [
DEFAULT_IP_RANGE_BLOCKLIST = [
# Localhost
"127.0.0.0/8",
# Private networks.
@ -501,36 +501,36 @@ class ServerConfig(Config):
# due to resource constraints
self.admin_contact = config.get("admin_contact", None)
ip_range_blacklist = config.get(
"ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST
ip_range_blocklist = config.get(
"ip_range_blacklist", DEFAULT_IP_RANGE_BLOCKLIST
)
# Attempt to create an IPSet from the given ranges
# Always blacklist 0.0.0.0, ::
self.ip_range_blacklist = generate_ip_set(
ip_range_blacklist, ["0.0.0.0", "::"], config_path=("ip_range_blacklist",)
# Always block 0.0.0.0, ::
self.ip_range_blocklist = generate_ip_set(
ip_range_blocklist, ["0.0.0.0", "::"], config_path=("ip_range_blacklist",)
)
self.ip_range_whitelist = generate_ip_set(
self.ip_range_allowlist = generate_ip_set(
config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",)
)
# The federation_ip_range_blacklist is used for backwards-compatibility
# and only applies to federation and identity servers.
if "federation_ip_range_blacklist" in config:
# Always blacklist 0.0.0.0, ::
self.federation_ip_range_blacklist = generate_ip_set(
# Always block 0.0.0.0, ::
self.federation_ip_range_blocklist = generate_ip_set(
config["federation_ip_range_blacklist"],
["0.0.0.0", "::"],
config_path=("federation_ip_range_blacklist",),
)
# 'federation_ip_range_whitelist' was never a supported configuration option.
self.federation_ip_range_whitelist = None
self.federation_ip_range_allowlist = None
else:
# No backwards-compatibility required, as federation_ip_range_blacklist
# is not given. Default to ip_range_blacklist and ip_range_whitelist.
self.federation_ip_range_blacklist = self.ip_range_blacklist
self.federation_ip_range_whitelist = self.ip_range_whitelist
self.federation_ip_range_blocklist = self.ip_range_blocklist
self.federation_ip_range_allowlist = self.ip_range_allowlist
# (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before

View File

@ -739,12 +739,10 @@ class FederationServer(FederationBase):
"event": event_json,
"state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
"members_omitted": caller_supports_partial_state,
}
if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
resp["servers_in_room"] = list(servers_in_room)
return resp

View File

@ -59,7 +59,6 @@ class TransportLayerClient:
def __init__(self, hs: "HomeServer"):
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
self._is_mine_server_name = hs.is_mine_server_name
async def get_room_state_ids(
@ -363,12 +362,8 @@ class TransportLayerClient:
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
query_params: Dict[str, str] = {}
if self._faster_joins_enabled:
# lazy-load state on join
query_params["org.matrix.msc3706.partial_state"] = (
"true" if omit_members else "false"
)
query_params["omit_members"] = "true" if omit_members else "false"
# lazy-load state on join
query_params["omit_members"] = "true" if omit_members else "false"
return await self.client.put_json(
destination=destination,
@ -902,9 +897,7 @@ def _members_omitted_parser(response: SendJoinResponse) -> Generator[None, Any,
while True:
val = yield
if not isinstance(val, bool):
raise TypeError(
"members_omitted (formerly org.matrix.msc370c.partial_state) must be a boolean"
)
raise TypeError("members_omitted must be a boolean")
response.members_omitted = val
@ -964,14 +957,6 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
]
if not v1_api:
self._coros.append(
ijson.items_coro(
_members_omitted_parser(self._response),
"org.matrix.msc3706.partial_state",
use_float="True",
)
)
# The stable field name comes last, so it "wins" if the fields disagree
self._coros.append(
ijson.items_coro(
_members_omitted_parser(self._response),
@ -980,14 +965,6 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
)
)
self._coros.append(
ijson.items_coro(
_servers_in_room_parser(self._response),
"org.matrix.msc3706.servers_in_room",
use_float="True",
)
)
# Again, stable field name comes last
self._coros.append(
ijson.items_coro(
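For context, a minimal sketch (assuming the `ijson` package, which this parser is built on) of how an `items_coro` target receives the stable `members_omitted` field from a streamed response body:

```python
import ijson

values = ijson.sendable_list()  # collects every value matching the prefix
coro = ijson.items_coro(values, "members_omitted")
coro.send(b'{"members_omitted": true, "state": [], "event": {}}')
coro.close()
print(values)  # [True]
```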

View File

@ -440,7 +440,6 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._read_msc3706_query_param = hs.config.experimental.msc3706_enabled
async def on_PUT(
self,
@ -453,16 +452,7 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
partial_state = False
# The stable query parameter wins, if it disagrees with the unstable
# parameter for some reason.
stable_param = parse_boolean_from_args(query, "omit_members", default=None)
if stable_param is not None:
partial_state = stable_param
elif self._read_msc3706_query_param:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
partial_state = parse_boolean_from_args(query, "omit_members", default=False)
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state

View File

@ -148,7 +148,7 @@ class FederationHandler:
self._event_auth_handler = hs.get_event_auth_handler()
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_proxied_blacklisted_http_client()
self.http_client = hs.get_proxied_blocklisted_http_client()
self._replication = hs.get_replication_data_handler()
self._federation_event_handler = hs.get_federation_event_handler()
self._device_handler = hs.get_device_handler()

View File

@ -52,10 +52,10 @@ class IdentityHandler:
# An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
# An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
self._http_client = SimpleHttpClient(
hs,
ip_blacklist=hs.config.server.federation_ip_range_blacklist,
ip_whitelist=hs.config.server.federation_ip_range_whitelist,
ip_blocklist=hs.config.server.federation_ip_range_blocklist,
ip_allowlist=hs.config.server.federation_ip_range_allowlist,
)
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
@ -197,7 +197,7 @@ class IdentityHandler:
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
data = await self.blacklisting_http_client.post_json_get_json(
data = await self._http_client.post_json_get_json(
bind_url, bind_data, headers=headers
)
@ -308,9 +308,7 @@ class IdentityHandler:
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
await self.blacklisting_http_client.post_json_get_json(
url, content, headers
)
await self._http_client.post_json_get_json(url, content, headers)
changed = True
except HttpResponseException as e:
changed = False
@ -579,7 +577,7 @@ class IdentityHandler:
"""
# Check what hashing details are supported by this identity server
try:
hash_details = await self.blacklisting_http_client.get_json(
hash_details = await self._http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
@ -646,7 +644,7 @@ class IdentityHandler:
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
lookup_results = await self.blacklisting_http_client.post_json_get_json(
lookup_results = await self._http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
@ -752,7 +750,7 @@ class IdentityHandler:
url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server)
try:
data = await self.blacklisting_http_client.post_json_get_json(
data = await self._http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},

118
synapse/handlers/jwt.py Normal file
View File

@ -0,0 +1,118 @@
# Copyright 2023 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
from synapse.api.errors import Codes, LoginError, StoreError, UserDeactivatedError
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
class JwtHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self._main_store = hs.get_datastores().main
self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences
async def validate_login(self, login_submission: JsonDict) -> str:
"""
Authenticates the user for the /login API
Args:
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
The user ID that is logging in.
Raises:
LoginError if there was an authentication problem.
"""
token = login_submission.get("token", None)
if token is None:
raise LoginError(
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)
jwt = JsonWebToken([self.jwt_algorithm])
claim_options = {}
if self.jwt_issuer is not None:
claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
if self.jwt_audiences is not None:
claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
try:
claims = jwt.decode(
token,
key=self.jwt_secret,
claims_cls=JWTClaims,
claims_options=claim_options,
)
except BadSignatureError:
# We handle this case separately to provide a better error message
raise LoginError(
403,
"JWT validation failed: Signature verification failed",
errcode=Codes.FORBIDDEN,
)
except JoseError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)
try:
claims.validate(leeway=120) # allows 2 min of clock skew
# Enforce the old behavior, as rolled out on production
# servers: if the JWT contains an 'aud' claim but none is
# configured, the login attempt will fail
if claims.get("aud") is not None:
if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
raise InvalidClaimError("aud")
except JoseError as e:
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)
user = claims.get(self.jwt_subject_claim, None)
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string()
# If the account has been deactivated, do not proceed with the login
# flow.
try:
deactivated = await self._main_store.get_user_deactivated_status(user_id)
except StoreError:
# JWT lazily creates users, so they may not exist in the database yet.
deactivated = False
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
return user_id
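A hedged usage sketch (not from the diff): assuming a config of `jwt_secret: "secret"`, the default `HS256` algorithm and the default `sub` subject claim, a token this handler accepts could be minted with authlib:

```python
from authlib.jose import jwt

# Hypothetical config values: jwt_secret="secret", jwt_algorithm="HS256".
token = jwt.encode({"alg": "HS256"}, {"sub": "alice"}, "secret").decode("ascii")
login_submission = {"type": "org.matrix.login.jwt", "token": token}
# JwtHandler.validate_login(login_submission) would resolve this to @alice:<server_name>.
```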

View File

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.api.constants import ReceiptTypes
from synapse.api.errors import SynapseError
from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
@ -47,12 +48,21 @@ class ReadMarkerHandler:
)
should_update = True
# Get event ordering; this also ensures we know about the event
event_ordering = await self.store.get_event_ordering(event_id)
if existing_read_marker:
# Only update if the new marker is ahead in the stream
should_update = await self.store.is_event_after(
event_id, existing_read_marker["event_id"]
)
try:
old_event_ordering = await self.store.get_event_ordering(
existing_read_marker["event_id"]
)
except SynapseError:
# Old event no longer exists, assume new is ahead. This may
# happen if the old event was removed due to retention.
pass
else:
# Only update if the new marker is ahead in the stream
should_update = event_ordering > old_event_ordering
if should_update:
content = {"event_id": event_id}

View File

@ -204,7 +204,7 @@ class SsoHandler:
self._media_repo = (
hs.get_media_repository() if hs.config.media.can_load_media_repo else None
)
self._http_client = hs.get_proxied_blacklisted_http_client()
self._http_client = hs.get_proxied_blocklisted_http_client()
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.

View File

@ -117,22 +117,22 @@ RawHeaderValue = Union[
]
def check_against_blacklist(
ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
def _is_ip_blocked(
ip_address: IPAddress, allowlist: Optional[IPSet], blocklist: IPSet
) -> bool:
"""
Compares an IP address to allowed and disallowed IP sets.
Args:
ip_address: The IP address to check
ip_whitelist: Allowed IP addresses.
ip_blacklist: Disallowed IP addresses.
allowlist: Allowed IP addresses.
blocklist: Disallowed IP addresses.
Returns:
True if the IP address is in the blacklist and not in the whitelist.
True if the IP address is in the blocklist and not in the allowlist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
if ip_address in blocklist:
if allowlist is None or ip_address not in allowlist:
return True
return False
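An illustrative sketch of the renamed helper's semantics, using `netaddr` as the surrounding module does (assumes the `_is_ip_blocked` definition above):

```python
from netaddr import IPAddress, IPSet

blocklist = IPSet(["10.0.0.0/8"])
allowlist = IPSet(["10.1.0.0/16"])

_is_ip_blocked(IPAddress("10.2.3.4"), allowlist, blocklist)  # True: blocked, not allowlisted
_is_ip_blocked(IPAddress("10.1.3.4"), allowlist, blocklist)  # False: the allowlist wins
_is_ip_blocked(IPAddress("8.8.8.8"), allowlist, blocklist)   # False: not in the blocklist
```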
@ -154,27 +154,27 @@ def _make_scheduler(
return _scheduler
class _IPBlacklistingResolver:
class _IPBlockingResolver:
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
addresses, preventing DNS rebinding attacks on URL preview.
A proxy for reactor.nameResolver which only produces non-blocklisted IP
addresses, preventing DNS rebinding attacks.
"""
def __init__(
self,
reactor: IReactorPluggableNameResolver,
ip_whitelist: Optional[IPSet],
ip_blacklist: IPSet,
ip_allowlist: Optional[IPSet],
ip_blocklist: IPSet,
):
"""
Args:
reactor: The twisted reactor.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
ip_allowlist: IP addresses to allow.
ip_blocklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
self._ip_allowlist = ip_allowlist
self._ip_blocklist = ip_blocklist
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
@ -191,16 +191,13 @@ class _IPBlacklistingResolver:
ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
if _is_ip_blocked(ip_address, self._ip_allowlist, self._ip_blocklist):
logger.info(
"Dropped %s from DNS resolution to %s due to blacklist"
% (ip_address, hostname)
"Blocked %s from DNS resolution to %s" % (ip_address, hostname)
)
has_bad_ip = True
# if we have a blacklisted IP, we'd like to raise an error to block the
# if we have a blocked IP, we'd like to raise an error to block the
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
@ -232,24 +229,24 @@ class _IPBlacklistingResolver:
# ISynapseReactor implies IReactorCore, but explicitly marking it this as an implementer
# of IReactorCore seems to keep mypy-zope happier.
@implementer(IReactorCore, ISynapseReactor)
class BlacklistingReactorWrapper:
class BlocklistingReactorWrapper:
"""
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
A Reactor wrapper which will prevent DNS resolution to blocked IP
addresses, to prevent DNS rebinding.
"""
def __init__(
self,
reactor: IReactorPluggableNameResolver,
ip_whitelist: Optional[IPSet],
ip_blacklist: IPSet,
ip_allowlist: Optional[IPSet],
ip_blocklist: IPSet,
):
self._reactor = reactor
# We need to use a DNS resolver which filters out blacklisted IP
# We need to use a DNS resolver which filters out blocked IP
# addresses, to prevent DNS rebinding.
self._nameResolver = _IPBlacklistingResolver(
self._reactor, ip_whitelist, ip_blacklist
self._nameResolver = _IPBlockingResolver(
self._reactor, ip_allowlist, ip_blocklist
)
def __getattr__(self, attr: str) -> Any:
@ -260,7 +257,7 @@ class BlacklistingReactorWrapper:
return getattr(self._reactor, attr)
class BlacklistingAgentWrapper(Agent):
class BlocklistingAgentWrapper(Agent):
"""
An Agent wrapper which will prevent access to IP addresses being accessed
directly (without an IP address lookup).
@ -269,18 +266,18 @@ class BlacklistingAgentWrapper(Agent):
def __init__(
self,
agent: IAgent,
ip_blacklist: IPSet,
ip_whitelist: Optional[IPSet] = None,
ip_blocklist: IPSet,
ip_allowlist: Optional[IPSet] = None,
):
"""
Args:
agent: The Agent to wrap.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
ip_allowlist: IP addresses to allow.
ip_blocklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
self._ip_allowlist = ip_allowlist
self._ip_blocklist = ip_blocklist
def request(
self,
@ -299,13 +296,9 @@ class BlacklistingAgentWrapper(Agent):
# Not an IP
pass
else:
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(
HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
)
if _is_ip_blocked(ip_address, self._ip_allowlist, self._ip_blocklist):
logger.info("Blocking access to %s" % (ip_address,))
e = SynapseError(HTTPStatus.FORBIDDEN, "IP address blocked")
return defer.fail(Failure(e))
return self._agent.request(
@ -763,10 +756,9 @@ class SimpleHttpClient(BaseHttpClient):
Args:
hs: The HomeServer instance to pass in
treq_args: Extra keyword arguments to be given to treq.request.
ip_blacklist: The IP addresses that are blacklisted that
we may not request.
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
ip_blocklist: The IP addresses that we may not request.
ip_allowlist: The allowed IP addresses, which we may
request even if they would otherwise be caught in the blocklist.
use_proxy: Whether proxy settings should be discovered and used
from conventional environment variables.
"""
@ -775,19 +767,19 @@ class SimpleHttpClient(BaseHttpClient):
self,
hs: "HomeServer",
treq_args: Optional[Dict[str, Any]] = None,
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
ip_allowlist: Optional[IPSet] = None,
ip_blocklist: Optional[IPSet] = None,
use_proxy: bool = False,
):
super().__init__(hs, treq_args=treq_args)
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
self._ip_allowlist = ip_allowlist
self._ip_blocklist = ip_blocklist
if self._ip_blacklist:
# If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding.
self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
self.reactor, self._ip_whitelist, self._ip_blacklist
if self._ip_blocklist:
# If we have an IP blocklist, we need to use a DNS resolver which
# filters out blocked IP addresses, to prevent DNS rebinding.
self.reactor: ISynapseReactor = BlocklistingReactorWrapper(
self.reactor, self._ip_allowlist, self._ip_blocklist
)
# the pusher makes lots of concurrent SSL connections to Sygnal, and tends to
@ -809,14 +801,13 @@ class SimpleHttpClient(BaseHttpClient):
use_proxy=use_proxy,
)
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
# which prevents direct access to IP addresses, that are not caught
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
if self._ip_blocklist:
# If we have an IP blocklist, we then install the Agent which prevents
# direct access to IP addresses, that are not caught by the DNS resolution.
self.agent = BlocklistingAgentWrapper(
self.agent,
ip_blacklist=self._ip_blacklist,
ip_whitelist=self._ip_whitelist,
ip_blocklist=self._ip_blocklist,
ip_allowlist=self._ip_allowlist,
)

View File

@ -36,7 +36,7 @@ from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResp
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import proxyagent
from synapse.http.client import BlacklistingAgentWrapper, BlacklistingReactorWrapper
from synapse.http.client import BlocklistingAgentWrapper, BlocklistingReactorWrapper
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
@ -65,12 +65,12 @@ class MatrixFederationAgent:
user_agent:
The user agent header to use for federation requests.
ip_whitelist: Allowed IP addresses.
ip_allowlist: Allowed IP addresses.
ip_blacklist: Disallowed IP addresses.
ip_blocklist: Disallowed IP addresses.
proxy_reactor: twisted reactor to use for connections to the proxy server
reactor might have some blacklisting applied (i.e. for DNS queries),
reactor might have some blocking applied (i.e. for DNS queries),
but we need unblocked access to the proxy.
_srv_resolver:
@ -87,17 +87,17 @@ class MatrixFederationAgent:
reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
ip_whitelist: Optional[IPSet],
ip_blacklist: IPSet,
ip_allowlist: Optional[IPSet],
ip_blocklist: IPSet,
_srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None,
):
# proxy_reactor is not blacklisted
# proxy_reactor is not blocklisted
proxy_reactor = reactor
# We need to use a DNS resolver which filters out blacklisted IP
# We need to use a DNS resolver which filters out blocked IP
# addresses, to prevent DNS rebinding.
reactor = BlacklistingReactorWrapper(reactor, ip_whitelist, ip_blacklist)
reactor = BlocklistingReactorWrapper(reactor, ip_allowlist, ip_blocklist)
self._clock = Clock(reactor)
self._pool = HTTPConnectionPool(reactor)
@ -120,7 +120,7 @@ class MatrixFederationAgent:
if _well_known_resolver is None:
_well_known_resolver = WellKnownResolver(
reactor,
agent=BlacklistingAgentWrapper(
agent=BlocklistingAgentWrapper(
ProxyAgent(
reactor,
proxy_reactor,
@ -128,7 +128,7 @@ class MatrixFederationAgent:
contextFactory=tls_client_options_factory,
use_proxy=True,
),
ip_blacklist=ip_blacklist,
ip_blocklist=ip_blocklist,
),
user_agent=self.user_agent,
)
@ -256,7 +256,7 @@ class MatrixHostnameEndpoint:
Args:
reactor: twisted reactor to use for underlying requests
proxy_reactor: twisted reactor to use for connections to the proxy server.
'reactor' might have some blacklisting applied (i.e. for DNS queries),
'reactor' might have some blocking applied (i.e. for DNS queries),
but we need unblocked access to the proxy.
tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.

View File

@ -64,7 +64,7 @@ from synapse.api.errors import (
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
BlocklistingAgentWrapper,
BodyExceededMaxSize,
ByteWriteable,
_make_scheduler,
@ -392,15 +392,15 @@ class MatrixFederationHttpClient:
self.reactor,
tls_client_options_factory,
user_agent.encode("ascii"),
hs.config.server.federation_ip_range_whitelist,
hs.config.server.federation_ip_range_blacklist,
hs.config.server.federation_ip_range_allowlist,
hs.config.server.federation_ip_range_blocklist,
)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
# Use a BlocklistingAgentWrapper to prevent circumventing the IP
# blocking via IP literals in server names
self.agent = BlocklistingAgentWrapper(
federation_agent,
ip_blacklist=hs.config.server.federation_ip_range_blacklist,
ip_blocklist=hs.config.server.federation_ip_range_blocklist,
)
self.clock = hs.get_clock()

View File

@ -53,7 +53,7 @@ class ProxyAgent(_AgentBase):
connections.
proxy_reactor: twisted reactor to use for connections to the proxy server
reactor might have some blacklisting applied (i.e. for DNS queries),
reactor might have some blocking applied (i.e. for DNS queries),
but we need unblocked access to the proxy.
contextFactory: A factory for TLS contexts, to control the

View File

@ -105,7 +105,7 @@ class UrlPreviewer:
When Synapse is asked to preview a URL it does the following:
1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the
1. Checks against a URL blocklist (defined as `url_preview_url_blacklist` in the
config).
2. Checks the URL against an in-memory cache and returns the result if it exists. (This
is also used to de-duplicate processing of multiple in-flight requests at once.)
@ -113,7 +113,7 @@ class UrlPreviewer:
1. Checks URL and timestamp against the database cache and returns the result if it
has not expired and was successful (a 2xx return code).
2. Checks if the URL matches an oEmbed (https://oembed.com/) pattern. If it
does, update the URL to download.
does and the new URL is not blocked, update the URL to download.
3. Downloads the URL and stores it into a file via the media storage provider
and saves the local media metadata.
4. If the media is an image:
@ -127,14 +127,14 @@ class UrlPreviewer:
and saves the local media metadata.
2. Convert the oEmbed response to an Open Graph response.
3. Override any Open Graph data from the HTML with data from oEmbed.
4. If an image exists in the Open Graph response:
4. If an image URL exists in the Open Graph response:
1. Downloads the URL and stores it into a file via the media storage
provider and saves the local media metadata.
2. Generates thumbnails.
3. Updates the Open Graph response based on image properties.
6. If the media is JSON and an oEmbed URL was found:
6. If an oEmbed URL was found and the media is JSON:
1. Convert the oEmbed response to an Open Graph response.
2. If a thumbnail or image is in the oEmbed response:
2. If an image URL is in the oEmbed response:
1. Downloads the URL and stores it into a file via the media storage
provider and saves the local media metadata.
2. Generates thumbnails.
@ -144,7 +144,8 @@ class UrlPreviewer:
If any additional requests (e.g. from oEmbed autodiscovery, step 5.3 or
image thumbnailing, step 5.4 or 6.4) fail then the URL preview as a whole
does not fail. As much information as possible is returned.
does not fail. If any of them are blocked, then those additional requests
are skipped. As much information as possible is returned.
The in-memory cache expires after 1 hour.
@ -166,8 +167,8 @@ class UrlPreviewer:
self.client = SimpleHttpClient(
hs,
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.media.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.media.url_preview_ip_range_blacklist,
ip_allowlist=hs.config.media.url_preview_ip_range_allowlist,
ip_blocklist=hs.config.media.url_preview_ip_range_blocklist,
use_proxy=True,
)
self.media_repo = media_repo
@ -185,7 +186,7 @@ class UrlPreviewer:
or instance_running_jobs == hs.get_instance_name()
)
self.url_preview_url_blacklist = hs.config.media.url_preview_url_blacklist
self.url_preview_url_blocklist = hs.config.media.url_preview_url_blocklist
self.url_preview_accept_language = hs.config.media.url_preview_accept_language
# memory cache mapping urls to an ObservableDeferred returning
@ -203,48 +204,14 @@ class UrlPreviewer:
)
async def preview(self, url: str, user: UserID, ts: int) -> bytes:
# XXX: we could move this into _do_preview if we wanted.
url_tuple = urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug(
"Matching attrib '%s' with value '%s' against pattern '%s'",
attrib,
value,
pattern,
)
if value is None:
match = False
continue
# Some attributes might not be parsed as strings by urlsplit (such as the
# port, which is parsed as an int). Because we use match functions that
# expect strings, we want to make sure that's what we give them.
value_str = str(value)
if pattern.startswith("^"):
if not re.match(pattern, value_str):
match = False
continue
else:
if not fnmatch.fnmatch(value_str, pattern):
match = False
continue
if match:
logger.warning("URL %s blocked by url_blacklist entry %s", url, entry)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
)
# the in-memory cache:
# * ensures that only one request is active at a time
# * ensures that only one request to a URL is active at a time
# * takes load off the DB for the thundering herds
# * also caches any failures (unlike the DB) so we don't keep
# requesting the same endpoint
#
# Note that autodiscovered oEmbed URLs and pre-caching of images
# are not captured in the in-memory cache.
observable = self._cache.get(url)
@ -283,7 +250,7 @@ class UrlPreviewer:
og = og.encode("utf8")
return og
# If this URL can be accessed via oEmbed, use that instead.
# If this URL can be accessed via an allowed oEmbed, use that instead.
url_to_download = url
oembed_url = self._oembed.get_oembed_url(url)
if oembed_url:
@ -329,6 +296,7 @@ class UrlPreviewer:
# defer to that.
oembed_url = self._oembed.autodiscover_from_html(tree)
og_from_oembed: JsonDict = {}
# Only download to the oEmbed URL if it is allowed.
if oembed_url:
try:
oembed_info = await self._handle_url(
@ -411,6 +379,59 @@ class UrlPreviewer:
return jsonog.encode("utf8")
def _is_url_blocked(self, url: str) -> bool:
"""
Check whether the URL is allowed to be previewed (according to the homeserver
configuration).
Args:
url: The requested URL.
Return:
True if the URL is blocked, False if it is allowed.
"""
url_tuple = urlsplit(url)
for entry in self.url_preview_url_blocklist:
match = True
# Iterate over each entry. If *all* attributes of that entry match
# the current URL, then reject it.
for attrib, pattern in entry.items():
value = getattr(url_tuple, attrib)
logger.debug(
"Matching attrib '%s' with value '%s' against pattern '%s'",
attrib,
value,
pattern,
)
if value is None:
match = False
break
# Some attributes might not be parsed as strings by urlsplit (such as the
# port, which is parsed as an int). Because we use match functions that
# expect strings, we want to make sure that's what we give them.
value_str = str(value)
# Check the value against the pattern as either a regular expression or
# a glob. If it doesn't match, the entry doesn't match.
if pattern.startswith("^"):
if not re.match(pattern, value_str):
match = False
break
else:
if not fnmatch.fnmatch(value_str, pattern):
match = False
break
# All fields matched, return true (the URL is blocked).
if match:
logger.warning("URL %s blocked by entry %s", url, entry)
return match
# No matches were found, the URL is allowed.
return False
async def _download_url(self, url: str, output_stream: BinaryIO) -> DownloadResult:
"""
Fetches a remote URL and parses the headers.
@ -451,7 +472,7 @@ class UrlPreviewer:
except DNSLookupError:
# DNS lookup returned no results
# Note: This will also be the case if one of the resolved IP
# addresses is blacklisted
# addresses is blocked.
raise SynapseError(
502,
"DNS resolution failure during URL preview generation",
@ -547,8 +568,16 @@ class UrlPreviewer:
Returns:
A MediaInfo object describing the fetched content.
Raises:
SynapseError if the URL is blocked.
"""
if self._is_url_blocked(url):
raise SynapseError(
403, "URL blocked by url pattern blocklist entry", Codes.UNKNOWN
)
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@ -624,7 +653,7 @@ class UrlPreviewer:
return
# The image URL from the HTML might be relative to the previewed page,
# convert it to an URL which can be requested directly.
# convert it to a URL which can be requested directly.
url_parts = urlparse(image_url)
if url_parts.scheme != "data":
image_url = urljoin(media_info.uri, image_url)
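To make the new `_is_url_blocked` matcher above concrete, a sketch with a hypothetical blocklist entry (each key is a `urlsplit` attribute; values are globs, or regexes when prefixed with `^`):

```python
from fnmatch import fnmatch
from urllib.parse import urlsplit

entry = {"scheme": "http", "netloc": "*.example.com"}  # hypothetical blocklist entry
parts = urlsplit("http://media.example.com/cat.gif")

# The URL is blocked only if *every* attribute of the entry matches:
assert parts.scheme == "http"
assert fnmatch(str(parts.netloc), "*.example.com")
```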

View File

@ -137,7 +137,7 @@ from synapse.util.caches.descriptors import CachedFunction, cached as _cached
from synapse.util.frozenutils import freeze
if TYPE_CHECKING:
from synapse.app.generic_worker import GenericWorkerSlavedStore
from synapse.app.generic_worker import GenericWorkerStore
from synapse.server import HomeServer
@ -241,9 +241,7 @@ class ModuleApi:
# TODO: Fix this type hint once the types for the data stores have been ironed
# out.
self._store: Union[
DataStore, "GenericWorkerSlavedStore"
] = hs.get_datastores().main
self._store: Union[DataStore, "GenericWorkerStore"] = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._auth_handler = auth_handler

View File

@ -143,7 +143,7 @@ class HttpPusher(Pusher):
)
self.url = url
self.http_client = hs.get_proxied_blacklisted_http_client()
self.http_client = hs.get_proxied_blocklisted_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]

View File

@ -60,7 +60,7 @@ _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 5
class ReplicationDataHandler:
"""Handles incoming stream updates from replication.
This instance notifies the slave data store about updates. Can be subclassed
This instance notifies the data store about updates. Can be subclassed
to handle updates in additional ways.
"""
@ -91,7 +91,7 @@ class ReplicationDataHandler:
) -> None:
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
By default, this just pokes the data store. Can be overridden in subclasses to
handle more.
Args:

View File

@ -137,6 +137,35 @@ class DevicesRestServlet(RestServlet):
devices = await self.device_handler.get_devices_by_user(target_user.to_string())
return HTTPStatus.OK, {"devices": devices, "total": len(devices)}
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
"""Creates a new device for the user."""
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Can only create devices for local users"
)
u = await self.store.get_user_by_id(target_user.to_string())
if u is None:
raise NotFoundError("Unknown user")
body = parse_json_object_from_request(request)
device_id = body.get("device_id")
if not device_id:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Missing device_id")
if not isinstance(device_id, str):
raise SynapseError(HTTPStatus.BAD_REQUEST, "device_id must be a string")
await self.device_handler.check_device_registered(
user_id=user_id, device_id=device_id
)
return HTTPStatus.CREATED, {}
class DeleteDevicesRestServlet(RestServlet):
"""

View File

@ -87,11 +87,6 @@ class LoginRestServlet(RestServlet):
# JWT configuration variables.
self.jwt_enabled = hs.config.jwt.jwt_enabled
self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences
# SSO configuration.
self.saml2_enabled = hs.config.saml2.saml2_enabled
@ -427,7 +422,7 @@ class LoginRestServlet(RestServlet):
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
) -> LoginResponse:
"""
Handle the final stage of SSO login.
Handle token login.
Args:
login_submission: The JSON request body.
@ -452,72 +447,24 @@ class LoginRestServlet(RestServlet):
async def _do_jwt_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
) -> LoginResponse:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)
"""
Handle the custom JWT login.
from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
Args:
login_submission: The JSON request body.
should_issue_refresh_token: True if this login should issue
a refresh token alongside the access token.
jwt = JsonWebToken([self.jwt_algorithm])
claim_options = {}
if self.jwt_issuer is not None:
claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
if self.jwt_audiences is not None:
claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
try:
claims = jwt.decode(
token,
key=self.jwt_secret,
claims_cls=JWTClaims,
claims_options=claim_options,
)
except BadSignatureError:
# We handle this case separately to provide a better error message
raise LoginError(
403,
"JWT validation failed: Signature verification failed",
errcode=Codes.FORBIDDEN,
)
except JoseError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)
try:
claims.validate(leeway=120) # allows 2 min of clock skew
# Enforce the old behavior, as rolled out on production
# servers: if the JWT contains an 'aud' claim but none is
# configured, the login attempt will fail
if claims.get("aud") is not None:
if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
raise InvalidClaimError("aud")
except JoseError as e:
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)
user = claims.get(self.jwt_subject_claim, None)
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
Returns:
The body of the JSON response.
"""
user_id = await self.hs.get_jwt_handler().validate_login(login_submission)
return await self._complete_login(
user_id,
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
)
return result
def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:

View File

@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.servlet import RestServlet, parse_strings_from_args
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict
from ._base import client_patterns
@ -30,11 +31,11 @@ logger = logging.getLogger(__name__)
class UserMutualRoomsServlet(RestServlet):
"""
GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1
GET /uk.half-shot.msc2666/user/mutual_rooms?user_id={user_id} HTTP/1.1
"""
PATTERNS = client_patterns(
"/uk.half-shot.msc2666/user/mutual_rooms/(?P<user_id>[^/]*)",
"/uk.half-shot.msc2666/user/mutual_rooms$",
releases=(), # This is an unstable feature
)
@ -43,17 +44,35 @@ class UserMutualRoomsServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
UserID.from_string(user_id)
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
args: Dict[bytes, List[bytes]] = request.args # type: ignore
user_ids = parse_strings_from_args(args, "user_id", required=True)
if len(user_ids) > 1:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Duplicate user_id query parameter",
errcode=Codes.INVALID_PARAM,
)
# We don't do batching, so a batch token is illegal by default
if b"batch_token" in args:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown batch_token",
errcode=Codes.INVALID_PARAM,
)
user_id = user_ids[0]
requester = await self.auth.get_user_by_req(request)
if user_id == requester.user.to_string():
raise SynapseError(
code=400,
msg="You cannot request a list of shared rooms with yourself",
errcode=Codes.FORBIDDEN,
HTTPStatus.UNPROCESSABLE_ENTITY,
"You cannot request a list of shared rooms with yourself",
errcode=Codes.INVALID_PARAM,
)
rooms = await self.store.get_mutual_rooms_between_users(
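Editor's note: `parse_strings_from_args` returns every value supplied for a repeated query parameter, which is why the handler must reject duplicates explicitly. A rough standalone sketch of the same shape, using a plain dict in place of the servlet helpers:

    args = {b"user_id": [b"@alice:example.org", b"@bob:example.org"]}
    user_ids = [v.decode("utf8") for v in args.get(b"user_id", [])]
    if not user_ids:
        raise ValueError("Missing user_id query parameter")
    if len(user_ids) > 1:
        raise ValueError("Duplicate user_id query parameter")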
View File
@ -91,7 +91,7 @@ class VersionsRestServlet(RestServlet):
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
# Implements additional endpoints as described in MSC2666
"uk.half-shot.msc2666.mutual_rooms": True,
"uk.half-shot.msc2666.query_mutual_rooms": True,
# Whether new rooms will be set to encrypted or not (based on presets).
"io.element.e2ee_forced.public": self.e2ee_forced_public,
"io.element.e2ee_forced.private": self.e2ee_forced_private,
View File
@ -147,6 +147,7 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from txredisapi import ConnectionHandler
from synapse.handlers.jwt import JwtHandler
from synapse.handlers.oidc import OidcHandler
from synapse.handlers.saml import SamlHandler
@ -453,15 +454,15 @@ class HomeServer(metaclass=abc.ABCMeta):
return SimpleHttpClient(self, use_proxy=True)
@cache_in_self
def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
def get_proxied_blocklisted_http_client(self) -> SimpleHttpClient:
"""
An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
based on the IP range blacklist/whitelist.
An HTTP client that uses configured HTTP(S) proxies and blocks IPs
based on the configured IP ranges.
"""
return SimpleHttpClient(
self,
ip_whitelist=self.config.server.ip_range_whitelist,
ip_blacklist=self.config.server.ip_range_blacklist,
ip_allowlist=self.config.server.ip_range_allowlist,
ip_blocklist=self.config.server.ip_range_blocklist,
use_proxy=True,
)
@ -533,6 +534,12 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_sso_handler(self) -> SsoHandler:
return SsoHandler(self)
@cache_in_self
def get_jwt_handler(self) -> "JwtHandler":
from synapse.handlers.jwt import JwtHandler
return JwtHandler(self)
@cache_in_self
def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)
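Editor's note: `cache_in_self` memoises the handler per `HomeServer` instance, so `get_jwt_handler()` constructs `JwtHandler` once and returns the same object thereafter; the function-local import keeps authlib (an optional dependency, as the guarded import in the tests below shows) unimported unless JWT login is actually used. A rough sketch of the memoisation using functools rather than Synapse's decorator, with a stand-in for the optional module:

    import functools

    class Server:
        @functools.cached_property
        def jwt_handler(self):
            # Function-local import: the optional dependency is only
            # imported the first time the handler is requested.
            import json as optional_dependency  # stand-in, not a real optional dep
            return optional_dependency

    server = Server()
    assert server.jwt_handler is server.jwt_handler  # constructed once, then cached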
View File
@ -13,8 +13,7 @@
# limitations under the License.
import logging
from collections import Counter
from typing import TYPE_CHECKING, Collection, List, Tuple
from typing import TYPE_CHECKING, Collection, Counter, List, Tuple
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
View File
@ -565,9 +565,8 @@ class DatabasePool:
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
# The user_directory_search table is unsafe to use native upserts
# on SQLite because the existing search table does not have an index.
if isinstance(self.engine, Sqlite3Engine):
self._unsafe_to_upsert_tables.add("user_directory_search")
View File
@ -85,13 +85,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
writers=hs.config.worker.writers.account_data,
)
else:
# Multiple writers are not supported for SQLite.
#
# We shouldn't be running in worker mode with SQLite, but it's useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
self._account_data_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
View File
@ -274,11 +274,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
"""Invalidates the cache and adds it to the cache stream so other workers
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
This should only be used to invalidate caches where other workers won't
otherwise have known from other replication streams that the cache should
be invalidated.
"""
cache_func = getattr(self, cache_name, None)
@ -297,11 +297,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
cache_func: CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
"""Invalidates the cache and adds it to the cache stream so other workers
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
This should only be used to invalidate caches where other workers won't
otherwise have known from other replication streams that the cache should
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
@ -310,7 +310,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
"""Invalidates the entire cache and adds it to the cache stream so other workers
will know to invalidate their caches.
"""
View File
@ -105,8 +105,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
is_writer=hs.config.worker.worker_app is None,
)
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
device_list_max = self._device_list_id_gen.get_current_token()
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
View File
@ -213,13 +213,10 @@ class EventsWorkerStore(SQLBaseStore):
writers=hs.config.worker.writers.events,
)
else:
# Multiple writers are not supported for SQLite.
#
# We shouldn't be running in worker mode with SQLite, but it's useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
self._stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
@ -1976,12 +1973,6 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
View File
@ -25,6 +25,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, UserID
from synapse.util.caches.descriptors import cached
@ -40,6 +41,8 @@ class FilteringWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname
self.database_engine = database.engine
self.db_pool.updates.register_background_index_update(
"full_users_filters_unique_idx",
index_name="full_users_unique_idx",
@ -48,6 +51,98 @@ class FilteringWorkerStore(SQLBaseStore):
unique=True,
)
self.db_pool.updates.register_background_update_handler(
"populate_full_user_id_user_filters",
self.populate_full_user_id_user_filters,
)
async def populate_full_user_id_user_filters(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Background update to populate the column `full_user_id` of the table
user_filters from the localpart entries in the column `user_id` of the same table.
"""
lower_bound_id = progress.get("lower_bound_id", "")
def _get_last_id(txn: LoggingTransaction) -> Optional[str]:
sql = """
SELECT user_id FROM user_filters
WHERE user_id > ?
ORDER BY user_id
LIMIT 1 OFFSET 50
"""
txn.execute(sql, (lower_bound_id,))
res = txn.fetchone()
if res:
upper_bound_id = res[0]
return upper_bound_id
else:
return None
def _process_batch(
txn: LoggingTransaction, lower_bound_id: str, upper_bound_id: str
) -> None:
sql = """
UPDATE user_filters
SET full_user_id = '@' || user_id || ?
WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
"""
txn.execute(sql, (f":{self.server_name}", lower_bound_id, upper_bound_id))
def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
sql = """
UPDATE user_filters
SET full_user_id = '@' || user_id || ?
WHERE ? < user_id AND full_user_id IS NULL
"""
txn.execute(
sql,
(
f":{self.server_name}",
lower_bound_id,
),
)
if isinstance(self.database_engine, PostgresEngine):
sql = """
ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null
"""
txn.execute(sql)
upper_bound_id = await self.db_pool.runInteraction(
"populate_full_user_id_user_filters", _get_last_id
)
if upper_bound_id is None:
await self.db_pool.runInteraction(
"populate_full_user_id_user_filters", _final_batch, lower_bound_id
)
await self.db_pool.updates._end_background_update(
"populate_full_user_id_user_filters"
)
return 1
await self.db_pool.runInteraction(
"populate_full_user_id_user_filters",
_process_batch,
lower_bound_id,
upper_bound_id,
)
progress["lower_bound_id"] = upper_bound_id
await self.db_pool.runInteraction(
"populate_full_user_id_user_filters",
self.db_pool.updates._background_update_progress_txn,
"populate_full_user_id_user_filters",
progress,
)
return 50
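Editor's note: the batching scheme above is: find the id 50 rows past the current lower bound, rewrite just that slice, record the new bound as progress, and on the final pass (no more rows) fill in the tail and, on Postgres, validate the deferred constraint. A self-contained sketch of the same loop against an in-memory SQLite table; table contents and server name are illustrative:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE user_filters (user_id TEXT, full_user_id TEXT)")
    conn.executemany(
        "INSERT INTO user_filters (user_id) VALUES (?)",
        [(f"user{i:03d}",) for i in range(120)],
    )

    server_name = "example.com"  # illustrative stand-in for hs.hostname
    lower_bound = ""
    while True:
        # Look 50 rows past the current bound to pick the batch's upper edge.
        row = conn.execute(
            "SELECT user_id FROM user_filters WHERE user_id > ? "
            "ORDER BY user_id LIMIT 1 OFFSET 50",
            (lower_bound,),
        ).fetchone()
        if row is None:
            # Final batch: everything past the bound. (On Postgres, this is
            # also where the NOT NULL constraint would be validated.)
            conn.execute(
                "UPDATE user_filters SET full_user_id = '@' || user_id || ? "
                "WHERE ? < user_id AND full_user_id IS NULL",
                (f":{server_name}", lower_bound),
            )
            break
        conn.execute(
            "UPDATE user_filters SET full_user_id = '@' || user_id || ? "
            "WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL",
            (f":{server_name}", lower_bound, row[0]),
        )
        lower_bound = row[0]  # persisted as the background update's progress

    assert conn.execute(
        "SELECT COUNT(*) FROM user_filters WHERE full_user_id IS NULL"
    ).fetchone()[0] == 0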
@cached(num_args=2)
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
View File
@ -15,9 +15,14 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.types import UserID
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -31,6 +36,8 @@ class ProfileWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname
self.database_engine = database.engine
self.db_pool.updates.register_background_index_update(
"profiles_full_user_id_key_idx",
index_name="profiles_full_user_id_key",
@ -39,6 +46,97 @@ class ProfileWorkerStore(SQLBaseStore):
unique=True,
)
self.db_pool.updates.register_background_update_handler(
"populate_full_user_id_profiles", self.populate_full_user_id_profiles
)
async def populate_full_user_id_profiles(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Background update to populate the column `full_user_id` of the table
profiles from the localpart entries in the column `user_id` of the same table.
"""
lower_bound_id = progress.get("lower_bound_id", "")
def _get_last_id(txn: LoggingTransaction) -> Optional[str]:
sql = """
SELECT user_id FROM profiles
WHERE user_id > ?
ORDER BY user_id
LIMIT 1 OFFSET 50
"""
txn.execute(sql, (lower_bound_id,))
res = txn.fetchone()
if res:
upper_bound_id = res[0]
return upper_bound_id
else:
return None
def _process_batch(
txn: LoggingTransaction, lower_bound_id: str, upper_bound_id: str
) -> None:
sql = """
UPDATE profiles
SET full_user_id = '@' || user_id || ?
WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
"""
txn.execute(sql, (f":{self.server_name}", lower_bound_id, upper_bound_id))
def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
sql = """
UPDATE profiles
SET full_user_id = '@' || user_id || ?
WHERE ? < user_id AND full_user_id IS NULL
"""
txn.execute(
sql,
(
f":{self.server_name}",
lower_bound_id,
),
)
if isinstance(self.database_engine, PostgresEngine):
sql = """
ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
"""
txn.execute(sql)
upper_bound_id = await self.db_pool.runInteraction(
"populate_full_user_id_profiles", _get_last_id
)
if upper_bound_id is None:
await self.db_pool.runInteraction(
"populate_full_user_id_profiles", _final_batch, lower_bound_id
)
await self.db_pool.updates._end_background_update(
"populate_full_user_id_profiles"
)
return 1
await self.db_pool.runInteraction(
"populate_full_user_id_profiles",
_process_batch,
lower_bound_id,
upper_bound_id,
)
progress["lower_bound_id"] = upper_bound_id
await self.db_pool.runInteraction(
"populate_full_user_id_profiles",
self.db_pool.updates._background_update_progress_txn,
"populate_full_user_id_profiles",
progress,
)
return 50
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
View File
@ -85,13 +85,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
else:
self._can_write_to_receipts = True
# Multiple writers are not supported for SQLite.
#
# We shouldn't be running in worker mode with SQLite, but it's useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
self._receipts_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
View File
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
SCHEMA_VERSION = 76 # remember to update the list below when updating
SCHEMA_VERSION = 77 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@ -100,13 +100,19 @@ Changes in SCHEMA_VERSION = 75:
Changes in SCHEMA_VERSION = 76:
- Adds a full_user_id column to tables profiles and user_filters.
Changes in SCHEMA_VERSION = 77:
- (Postgres) Add NOT VALID CHECK (full_user_id IS NOT NULL) to tables profiles and user_filters
"""
SCHEMA_COMPAT_VERSION = (
# Queries against `event_stream_ordering` columns in membership tables must
# be disambiguated.
74
#
# values inserted into the column `full_user_id` of tables profiles and
# user_filters can no longer be null
76
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
View File
@ -21,7 +21,7 @@ from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
# This stream is used to notify replication slaves that some caches have
# This stream is used to notify workers over replication that some caches have
# been invalidated that they cannot infer from the other streams.
CREATE_TABLE = """
CREATE TABLE cache_invalidation_stream (
View File
@ -0,0 +1,16 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE profiles ADD CONSTRAINT full_user_id_not_null CHECK (full_user_id IS NOT NULL) NOT VALID;
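Editor's note: on Postgres this delta is step one of a three-step rollout: add the constraint as NOT VALID (no full-table scan while the lock is held), backfill `full_user_id` via the background update above, then VALIDATE CONSTRAINT once no NULLs remain. Sketched end to end with psycopg2; the connection string is an assumption:

    import psycopg2

    conn = psycopg2.connect("dbname=synapse")  # assumption: local database
    with conn, conn.cursor() as cur:
        # Step 1: constraint applies to new writes only; existing rows unchecked.
        cur.execute(
            "ALTER TABLE profiles ADD CONSTRAINT full_user_id_not_null "
            "CHECK (full_user_id IS NOT NULL) NOT VALID"
        )
    # Step 2: the background update backfills full_user_id in batches.
    with conn, conn.cursor() as cur:
        # Step 3: validation scans the table but takes a weaker lock than
        # adding an already-valid constraint would.
        cur.execute("ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null")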
View File
@ -0,0 +1,16 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE user_filters ADD CONSTRAINT full_user_id_not_null CHECK (full_user_id IS NOT NULL) NOT VALID;
View File
@ -0,0 +1,16 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (7703, 'populate_full_user_id_profiles', '{}');
View File
@ -0,0 +1,16 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (7704, 'populate_full_user_id_user_filters', '{}');
View File
@ -85,7 +85,15 @@ JsonSerializable = object
# Collection[str] that does not include str itself; str being a Sequence[str]
# is very misleading and results in bugs.
#
# StrCollection is an unordered collection of strings. If ordering is important,
# StrSequence can be used instead.
StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]
# Sequence[str] that does not include str itself; str being a Sequence[str]
# is very misleading and results in bugs.
#
# Unlike StrCollection, StrSequence is an ordered collection of strings.
StrSequence = Union[Tuple[str, ...], List[str]]
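Editor's note: the footgun these aliases exclude, in miniature: a bare `str` is itself a `Sequence[str]`, so a parameter typed that way silently accepts a single string and iterates its characters.

    from typing import Sequence

    def render_path(path: Sequence[str]) -> str:
        return "/".join(path)

    render_path(("config", "module"))  # 'config/module'
    render_path("config")              # 'c/o/n/f/i/g' -- accepted by the type checker

With `StrSequence = Union[Tuple[str, ...], List[str]]`, the second call is a type error.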
# Note that this seems to require inheriting *directly* from Interface in order
View File
@ -14,17 +14,17 @@
import importlib
import importlib.util
import itertools
from types import ModuleType
from typing import Any, Iterable, Tuple, Type
from typing import Any, Tuple, Type
import jsonschema
from synapse.config._base import ConfigError
from synapse.config._util import json_error_to_config_error
from synapse.types import StrSequence
def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
def load_module(provider: dict, config_path: StrSequence) -> Tuple[Type, Any]:
"""Loads a synapse module with its config
Args:
@ -39,9 +39,7 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
modulename = provider.get("module")
if not isinstance(modulename, str):
raise ConfigError(
"expected a string", path=itertools.chain(config_path, ("module",))
)
raise ConfigError("expected a string", path=tuple(config_path) + ("module",))
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
@ -55,19 +53,17 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
try:
provider_config = provider_class.parse_config(module_config)
except jsonschema.ValidationError as e:
raise json_error_to_config_error(
e, itertools.chain(config_path, ("config",))
)
raise json_error_to_config_error(e, tuple(config_path) + ("config",))
except ConfigError as e:
raise _wrap_config_error(
"Failed to parse config for module %r" % (modulename,),
prefix=itertools.chain(config_path, ("config",)),
prefix=tuple(config_path) + ("config",),
e=e,
)
except Exception as e:
raise ConfigError(
"Failed to parse config for module %r" % (modulename,),
path=itertools.chain(config_path, ("config",)),
path=tuple(config_path) + ("config",),
) from e
else:
provider_config = module_config
@ -92,9 +88,7 @@ def load_python_module(location: str) -> ModuleType:
return mod
def _wrap_config_error(
msg: str, prefix: Iterable[str], e: ConfigError
) -> "ConfigError":
def _wrap_config_error(msg: str, prefix: StrSequence, e: ConfigError) -> "ConfigError":
"""Wrap a relative ConfigError with a new path
This is useful when we have a ConfigError with a relative path due to a problem
@ -102,7 +96,7 @@ def _wrap_config_error(
"""
path = prefix
if e.path:
path = itertools.chain(prefix, e.path)
path = tuple(prefix) + tuple(e.path)
e1 = ConfigError(msg, path)
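Editor's note: the switch from `itertools.chain` to tuple concatenation is more than style. A chain is a one-shot iterator, so a ConfigError path built with it could only be read once (formatted once, then empty), whereas a tuple is reusable and indexable. Illustration:

    import itertools

    chained = itertools.chain(("modules", "0"), ("config",))
    print(list(chained))  # ['modules', '0', 'config']
    print(list(chained))  # [] -- exhausted; re-formatting the path sees nothing

    concatenated = ("modules", "0") + ("config",)
    print(list(concatenated))  # ['modules', '0', 'config'], every time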
View File
@ -25,10 +25,12 @@ from typing import (
Iterator,
List,
Mapping,
MutableSet,
Optional,
Set,
Tuple,
)
from weakref import WeakSet
from prometheus_client.core import Counter
from typing_extensions import ContextManager
@ -86,7 +88,9 @@ queue_wait_timer = Histogram(
)
_rate_limiter_instances: Set["FederationRateLimiter"] = set()
# This must be a `WeakSet`, otherwise we indirectly hold on to entire `HomeServer`s
# during trial test runs and leak a lot of memory.
_rate_limiter_instances: MutableSet["FederationRateLimiter"] = WeakSet()
# Protects the _rate_limiter_instances set from concurrent access
_rate_limiter_instances_lock = threading.Lock()
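Editor's note: the comment above is the whole story: a plain `set` would keep every `FederationRateLimiter` (and, transitively, its `HomeServer`) alive for the life of the process, while a `WeakSet` drops members as soon as nothing else references them. Demonstrated on a toy class:

    import gc
    from weakref import WeakSet

    class RateLimiter:
        """Stand-in for FederationRateLimiter."""

    instances = WeakSet()
    limiter = RateLimiter()
    instances.add(limiter)
    print(len(instances))  # 1 while a strong reference exists

    del limiter
    gc.collect()           # make collection deterministic for the demo
    print(len(instances))  # 0 -- the registry no longer pins the object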
View File
@ -38,7 +38,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def default_config(self) -> JsonDict:
conf = super().default_config()
# we're using FederationReaderServer, which uses a SlavedStore, so we
# we're using GenericWorkerServer, which uses a GenericWorkerStore, so we
# have to tell the FederationHandler not to try to access stuff that is only
# in the primary store.
conf["worker_app"] = "yes"
View File
@ -63,7 +63,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self) -> None:
def test_blocked_server(self) -> None:
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
logging.info("ACL event: %s", e.content)
View File
@ -86,18 +86,7 @@ class SendJoinParserTestCase(TestCase):
return parsed_response.members_omitted
self.assertTrue(parse({"members_omitted": True}))
self.assertTrue(parse({"org.matrix.msc3706.partial_state": True}))
self.assertFalse(parse({"members_omitted": False}))
self.assertFalse(parse({"org.matrix.msc3706.partial_state": False}))
# If there's a conflict, the stable field wins.
self.assertTrue(
parse({"members_omitted": True, "org.matrix.msc3706.partial_state": False})
)
self.assertFalse(
parse({"members_omitted": False, "org.matrix.msc3706.partial_state": True})
)
def test_servers_in_room(self) -> None:
"""Check that the servers_in_room field is correctly parsed"""
@ -113,28 +102,10 @@ class SendJoinParserTestCase(TestCase):
parsed_response = parser.finish()
return parsed_response.servers_in_room
self.assertEqual(
parse({"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}),
["hs1", "hs2"],
)
self.assertEqual(parse({"servers_in_room": ["example.com"]}), ["example.com"])
# If both are provided, the stable identifier should win
self.assertEqual(
parse(
{
"org.matrix.msc3706.servers_in_room": ["old"],
"servers_in_room": ["new"],
}
),
["new"],
)
# And lastly, we should be able to tell if neither field was present.
self.assertEqual(
parse({}),
None,
)
# We should be able to tell the field is not present.
self.assertEqual(parse({}), None)
def test_errors_closing_coroutines(self) -> None:
"""Check we close all coroutines, even if closing the first raises an Exception.
@ -143,7 +114,7 @@ class SendJoinParserTestCase(TestCase):
assertions about its attributes or type.
"""
parser = SendJoinParser(RoomVersions.V1, False)
response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
response = {"servers_in_room": ["hs1", "hs2"]}
serialisation = json.dumps(response).encode()
# Mock the coroutines managed by this parser.
View File
@ -31,7 +31,7 @@ class TestSSOHandler(unittest.HomeserverTestCase):
self.http_client.get_file.side_effect = mock_get_file
self.http_client.user_agent = b"Synapse Test"
hs = self.setup_test_homeserver(
proxied_blacklisted_http_client=self.http_client
proxied_blocklisted_http_client=self.http_client
)
return hs
View File
@ -269,8 +269,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=cast(ISynapseReactor, self.reactor),
tls_client_options_factory=self.tls_factory,
user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided.
ip_whitelist=IPSet(),
ip_blacklist=IPSet(),
ip_allowlist=IPSet(),
ip_blocklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
@ -997,8 +997,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=tls_factory,
user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
ip_whitelist=IPSet(),
ip_blacklist=IPSet(),
ip_allowlist=IPSet(),
ip_blocklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
cast(ISynapseReactor, self.reactor),
View File
@ -27,8 +27,8 @@ from twisted.web.iweb import UNKNOWN_LENGTH
from synapse.api.errors import SynapseError
from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BlocklistingAgentWrapper,
BlocklistingReactorWrapper,
BodyExceededMaxSize,
_DiscardBodyWithMaxSizeProtocol,
read_body_with_max_size,
@ -140,7 +140,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.assertEqual(result.getvalue(), b"")
class BlacklistingAgentTest(TestCase):
class BlocklistingAgentTest(TestCase):
def setUp(self) -> None:
self.reactor, self.clock = get_clock()
@ -157,16 +157,16 @@ class BlacklistingAgentTest(TestCase):
self.reactor.lookups[domain.decode()] = ip.decode()
self.reactor.lookups[ip.decode()] = ip.decode()
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])
self.ip_allowlist = IPSet([self.allowed_ip.decode()])
self.ip_blocklist = IPSet(["5.0.0.0/8"])
def test_reactor(self) -> None:
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
"""Apply the blocklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
BlocklistingReactorWrapper(
self.reactor,
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
ip_allowlist=self.ip_allowlist,
ip_blocklist=self.ip_blocklist,
),
)
@ -207,11 +207,11 @@ class BlacklistingAgentTest(TestCase):
self.assertEqual(response.code, 200)
def test_agent(self) -> None:
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
"""Apply the blocklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlocklistingAgentWrapper(
Agent(self.reactor),
ip_blacklist=self.ip_blacklist,
ip_whitelist=self.ip_whitelist,
ip_blocklist=self.ip_blocklist,
ip_allowlist=self.ip_allowlist,
)
# The unsafe IPs should be rejected.
View File
@ -231,11 +231,11 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
def test_client_ip_range_blacklist(self) -> None:
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
def test_client_ip_range_blocklist(self) -> None:
"""Ensure that Synapse does not try to connect to blocked IPs"""
# Set up the ip_range blacklist
self.hs.config.server.federation_ip_range_blacklist = IPSet(
# Set up the ip_range blocklist
self.hs.config.server.federation_ip_range_blocklist = IPSet(
["127.0.0.0/8", "fe80::/64"]
)
self.reactor.lookups["internal"] = "127.0.0.1"
@ -243,7 +243,7 @@ class FederationClientTests(HomeserverTestCase):
self.reactor.lookups["fine"] = "10.20.30.40"
cl = MatrixFederationHttpClient(self.hs, None)
# Try making a GET request to a blacklisted IPv4 address
# Try making a GET request to a blocked IPv4 address
# ------------------------------------------------------
# Make the request
d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
@ -261,7 +261,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
# Try making a POST request to a blacklisted IPv6 address
# Try making a POST request to a blocked IPv6 address
# -------------------------------------------------------
# Make the request
d = defer.ensureDeferred(
@ -278,11 +278,11 @@ class FederationClientTests(HomeserverTestCase):
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 0)
# Check that it was due to a blacklisted DNS lookup
# Check that it was due to a blocked DNS lookup
f = self.failureResultOf(d, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
# Try making a GET request to a non-blacklisted IPv4 address
# Try making a GET request to an allowed IPv4 address
# ----------------------------------------------------------
# Make the request
d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
View File
@ -32,7 +32,7 @@ from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
from synapse.http.client import BlacklistingReactorWrapper
from synapse.http.client import BlocklistingReactorWrapper
from synapse.http.connectproxyclient import ProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy
@ -684,11 +684,11 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_http_request_via_proxy_with_blacklist(self) -> None:
# The blacklist includes the configured proxy IP.
def test_http_request_via_proxy_with_blocklist(self) -> None:
# The blocklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
BlocklistingReactorWrapper(
self.reactor, ip_allowlist=None, ip_blocklist=IPSet(["1.0.0.0/8"])
),
self.reactor,
use_proxy=True,
@ -730,11 +730,11 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None:
# The blacklist includes the configured proxy IP.
def test_https_request_via_uppercase_proxy_with_blocklist(self) -> None:
# The blocklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
BlocklistingReactorWrapper(
self.reactor, ip_allowlist=None, ip_blocklist=IPSet(["1.0.0.0/8"])
),
self.reactor,
contextFactory=get_test_https_policy(),
View File
@ -123,17 +123,17 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestTimedOutError)
def test_client_ip_range_blacklist(self) -> None:
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
def test_client_ip_range_blocklist(self) -> None:
"""Ensure that Synapse does not try to connect to blocked IPs"""
# Add some DNS entries we'll blacklist
# Add some DNS entries we'll block
self.reactor.lookups["internal"] = "127.0.0.1"
self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"])
ip_blocklist = IPSet(["127.0.0.0/8", "fe80::/64"])
cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist)
cl = SimpleHttpClient(self.hs, ip_blocklist=ip_blocklist)
# Try making a GET request to a blacklisted IPv4 address
# Try making a GET request to a blocked IPv4 address
# ------------------------------------------------------
# Make the request
d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar"))
@ -145,7 +145,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.failureResultOf(d, DNSLookupError)
# Try making a POST request to a blacklisted IPv6 address
# Try making a POST request to a blocked IPv6 address
# -------------------------------------------------------
# Make the request
d = defer.ensureDeferred(
@ -159,10 +159,10 @@ class SimpleHttpClientTests(HomeserverTestCase):
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 0)
# Check that it was due to a blacklisted DNS lookup
# Check that it was due to a blocked DNS lookup
self.failureResultOf(d, DNSLookupError)
# Try making a GET request to a non-blacklisted IPv4 address
# Try making a GET request to a non-blocked IPv4 address
# ----------------------------------------------------------
# Make the request
d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar"))
View File
@ -0,0 +1,113 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
try:
import lxml
except ImportError:
lxml = None
class URLPreviewTests(unittest.HomeserverTestCase):
if not lxml:
skip = "url preview feature requires lxml"
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["url_preview_enabled"] = True
config["max_spider_size"] = 9999999
config["url_preview_ip_range_blacklist"] = (
"192.168.1.1",
"1.0.0.0/8",
"3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
"2001:800::/21",
)
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_repo_resource = hs.get_media_repository_resource()
preview_url = media_repo_resource.children[b"preview_url"]
self.url_previewer = preview_url._url_previewer
def test_all_urls_allowed(self) -> None:
self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org"))
self.assertFalse(self.url_previewer._is_url_blocked("https://matrix.org"))
self.assertFalse(self.url_previewer._is_url_blocked("http://localhost:8000"))
self.assertFalse(
self.url_previewer._is_url_blocked("http://user:pass@matrix.org")
)
@override_config(
{
"url_preview_url_blacklist": [
{"username": "user"},
{"scheme": "http", "netloc": "matrix.org"},
]
}
)
def test_blocked_url(self) -> None:
# Blocked via scheme and URL.
self.assertTrue(self.url_previewer._is_url_blocked("http://matrix.org"))
# Not blocked because all components must match.
self.assertFalse(self.url_previewer._is_url_blocked("https://matrix.org"))
# Blocked due to the user.
self.assertTrue(
self.url_previewer._is_url_blocked("http://user:pass@example.com")
)
self.assertTrue(self.url_previewer._is_url_blocked("http://user@example.com"))
@override_config({"url_preview_url_blacklist": [{"netloc": "*.example.com"}]})
def test_glob_blocked_url(self) -> None:
# All subdomains are blocked.
self.assertTrue(self.url_previewer._is_url_blocked("http://foo.example.com"))
self.assertTrue(self.url_previewer._is_url_blocked("http://.example.com"))
# The TLD is not blocked.
self.assertFalse(self.url_previewer._is_url_blocked("https://example.com"))
@override_config({"url_preview_url_blacklist": [{"netloc": "^.+\\.example\\.com"}]})
def test_regex_blocked_url(self) -> None:
# All subdomains are blocked.
self.assertTrue(self.url_previewer._is_url_blocked("http://foo.example.com"))
# Requires a non-empty subdomain.
self.assertFalse(self.url_previewer._is_url_blocked("http://.example.com"))
# The TLD is not blocked.
self.assertFalse(self.url_previewer._is_url_blocked("https://example.com"))
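Editor's note: the semantics these tests rely on: each blocklist entry is a dict of URL components, an entry matches only if every component it names matches (with glob support in the netloc), and any single matching entry blocks the URL. A rough standalone sketch of just the glob cases exercised above; the real `_is_url_blocked` also understands `^`-prefixed regular expressions and more URL components, and the blocklist below is illustrative:

    from fnmatch import fnmatch
    from urllib.parse import urlsplit

    blocklist = [  # illustrative entries mirroring the overridden configs above
        {"username": "user"},
        {"scheme": "http", "netloc": "matrix.org"},
        {"netloc": "*.example.com"},
    ]

    def is_url_blocked(url: str) -> bool:
        parts = urlsplit(url)
        values = {
            "scheme": parts.scheme,
            "netloc": parts.hostname or "",
            "username": parts.username or "",
        }
        # An entry blocks the URL only if every component it names matches;
        # any one matching entry is enough to block.
        return any(
            all(fnmatch(values[key], pattern) for key, pattern in entry.items())
            for entry in blocklist
        )

    print(is_url_blocked("http://matrix.org"))             # True: scheme+netloc match
    print(is_url_blocked("https://matrix.org"))            # False: scheme differs
    print(is_url_blocked("http://user:pass@example.org"))  # True: username matches
    print(is_url_blocked("https://foo.example.com"))       # True: glob on subdomains
    print(is_url_blocked("https://example.com"))           # False: bare domain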
View File
@ -52,7 +52,7 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
hs = self.setup_test_homeserver(proxied_blocklisted_http_client=m)
return hs
View File
@ -1,13 +0,0 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
View File
@ -24,7 +24,7 @@ from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase
class BaseSlavedStoreTestCase(BaseStreamTestCase):
class BaseWorkerStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())
@ -34,7 +34,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self.reconnect()
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
self.worker_store = self.worker_hs.get_datastores().main
persistence = hs.get_storage_controllers().persistence
assert persistence is not None
self.persistance = persistence
@ -50,7 +50,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
worker_result = self.get_success(getattr(self.worker_store, method)(*args))
if expected_result is not None:
self.assertEqual(
master_result,
@ -59,14 +59,14 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
% (expected_result, master_result),
)
self.assertEqual(
slaved_result,
worker_result,
expected_result,
"Expected slave result to be %r but was %r"
% (expected_result, slaved_result),
"Expected worker result to be %r but was %r"
% (expected_result, worker_result),
)
self.assertEqual(
master_result,
slaved_result,
"Slave result %r does not match master result %r"
% (slaved_result, master_result),
worker_result,
"Worker result %r does not match master result %r"
% (worker_result, master_result),
)
View File
@ -36,7 +36,7 @@ from synapse.util import Clock
from tests.server import FakeTransport
from ._base import BaseSlavedStoreTestCase
from ._base import BaseWorkerStoreTestCase
USER_ID = "@feeling:test"
USER_ID_2 = "@bright:test"
@ -63,7 +63,7 @@ def patch__eq__(cls: object) -> Callable[[], None]:
return unpatch
class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
STORE_TYPE = EventsWorkerStore
def setUp(self) -> None:
@ -294,7 +294,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
assert j2.internal_metadata.stream_ordering is not None
event_source = RoomEventSource(self.hs)
event_source.store = self.slaved_store
event_source.store = self.worker_store
current_token = event_source.get_current_key()
# gradually stream out the replication
@ -310,12 +310,12 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
#
# First, we get a list of the rooms we are joined to
joined_rooms = self.get_success(
self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
self.worker_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
)
# Then, we get a list of the events since the last sync
membership_changes = self.get_success(
self.slaved_store.get_membership_changes_for_user(
self.worker_store.get_membership_changes_for_user(
USER_ID_2, prev_token, current_token
)
)
View File
@ -93,7 +93,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "pusher1", "pusher_instances": ["pusher1"]},
proxied_blacklisted_http_client=http_client_mock,
proxied_blocklisted_http_client=http_client_mock,
)
event_id = self._create_pusher_and_send_msg("user")
@ -126,7 +126,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
proxied_blacklisted_http_client=http_client_mock1,
proxied_blocklisted_http_client=http_client_mock1,
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@ -140,7 +140,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
proxied_blacklisted_http_client=http_client_mock2,
proxied_blocklisted_http_client=http_client_mock2,
)
# We choose a user name that we know should go to pusher1.
View File
@ -42,7 +42,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
from authlib.jose import jwk, jwt
from authlib.jose import JsonWebKey, jwt
HAS_JWT = True
except ImportError:
@ -1054,6 +1054,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
def test_deactivated_user(self) -> None:
"""Logging in as a deactivated account should error."""
user_id = self.register_user("kermit", "monkey")
self.get_success(
self.hs.get_deactivate_account_handler().deactivate_account(
user_id, erase_data=False, requester=create_requester(user_id)
)
)
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_USER_DEACTIVATED")
self.assertEqual(
channel.json_body["error"], "This account has been deactivated"
)
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RS256, with a public key configured in synapse as "jwt_secret", and tokens
@ -1121,7 +1137,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
header = {"alg": "RS256"}
if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
secret = jwk.dumps(secret, kty="RSA")
secret = JsonWebKey.import_key(secret, {"kty": "RSA"})
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii")
View File
@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from urllib.parse import quote
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@ -44,8 +46,8 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
"GET",
"/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
% other_user,
"/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms"
f"?user_id={quote(other_user)}",
access_token=token,
)
View File
@ -0,0 +1,147 @@
# Copyright 2023 Beeper
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest import admin
from synapse.rest.client import login, read_marker, register, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
ONE_HOUR_MS = 3600000
ONE_DAY_MS = ONE_HOUR_MS * 24
class ReadMarkerTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
register.register_servlets,
read_marker.register_servlets,
room.register_servlets,
synapse.rest.admin.register_servlets,
admin.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# merge this default retention config with anything that was specified in
# @override_config
retention_config = {
"enabled": True,
"allowed_lifetime_min": ONE_DAY_MS,
"allowed_lifetime_max": ONE_DAY_MS * 3,
}
retention_config.update(config.get("retention", {}))
config["retention"] = retention_config
self.hs = self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
self.store = self.hs.get_datastores().main
self.clock = self.hs.get_clock()
def test_send_read_marker(self) -> None:
room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
def send_message() -> str:
res = self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
return res["event_id"]
# Test setting the read marker on the room
event_id_1 = send_message()
channel = self.make_request(
"POST",
"/rooms/!abc:beep/read_markers",
content={
"m.fully_read": event_id_1,
},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
# Test moving the read marker to a newer event
event_id_2 = send_message()
channel = self.make_request(
"POST",
"/rooms/!abc:beep/read_markers",
content={
"m.fully_read": event_id_2,
},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
def test_send_read_marker_missing_previous_event(self) -> None:
"""
Test moving a read marker from an event that previously existed but was
later removed due to retention rules.
"""
room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
# Set retention rule on the room so we remove old events to test this case
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
body={"max_lifetime": ONE_DAY_MS},
tok=self.owner_tok,
)
def send_message() -> str:
res = self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
return res["event_id"]
# Test setting the read marker on the room
event_id_1 = send_message()
channel = self.make_request(
"POST",
"/rooms/!abc:beep/read_markers",
content={
"m.fully_read": event_id_1,
},
access_token=self.owner_tok,
)
# Send a second message (retention never removes a room's most recent event)
send_message()
# And then advance so retention rules remove the first event (where the marker is)
self.reactor.advance(ONE_DAY_MS * 2 / 1000)
event = self.get_success(self.store.get_event(event_id_1, allow_none=True))
assert event is None
# TODO See https://github.com/matrix-org/synapse/issues/13476
self.store.get_event_ordering.invalidate_all()
# Test moving the read marker to a newer event
event_id_2 = send_message()
channel = self.make_request(
"POST",
"/rooms/!abc:beep/read_markers",
content={
"m.fully_read": event_id_2,
},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
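Editor's note: the regression this test pins down is that moving a read marker compares the stream ordering of the old and new marker events, and the old event may have been purged by retention. A hedged sketch of the tolerant comparison; helper names are illustrative, not Synapse's:

    from typing import Optional, Tuple

    def should_update_marker(
        old_ordering: Optional[Tuple[int, int]],
        new_ordering: Tuple[int, int],
    ) -> bool:
        # If the previously marked event no longer exists (e.g. removed by a
        # retention policy), treat the new marker as an advance.
        if old_ordering is None:
            return True
        return new_ordering > old_ordering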
View File
@ -418,9 +418,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
def test_blacklisted_ip_specific(self) -> None:
def test_blocked_ip_specific(self) -> None:
"""
Blacklisted IP addresses, found via DNS, are not spidered.
Blocked IP addresses, found via DNS, are not spidered.
"""
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
@ -439,9 +439,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ip_range(self) -> None:
def test_blocked_ip_range(self) -> None:
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
Blocked IP ranges, IPs found over DNS, are not spidered.
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
@ -458,9 +458,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ip_specific_direct(self) -> None:
def test_blocked_ip_specific_direct(self) -> None:
"""
Blacklisted IP addresses, accessed directly, are not spidered.
Blocked IP addresses, accessed directly, are not spidered.
"""
channel = self.make_request(
"GET", "preview_url?url=http://192.168.1.1", shorthand=False
@ -470,16 +470,13 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpClients), 0)
self.assertEqual(
channel.json_body,
{
"errcode": "M_UNKNOWN",
"error": "IP address blocked by IP blacklist entry",
},
{"errcode": "M_UNKNOWN", "error": "IP address blocked"},
)
self.assertEqual(channel.code, 403)
def test_blacklisted_ip_range_direct(self) -> None:
def test_blocked_ip_range_direct(self) -> None:
"""
Blacklisted IP ranges, accessed directly, are not spidered.
Blocked IP ranges, accessed directly, are not spidered.
"""
channel = self.make_request(
"GET", "preview_url?url=http://1.1.1.2", shorthand=False
@ -488,15 +485,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 403)
self.assertEqual(
channel.json_body,
{
"errcode": "M_UNKNOWN",
"error": "IP address blocked by IP blacklist entry",
},
{"errcode": "M_UNKNOWN", "error": "IP address blocked"},
)
def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
def test_blocked_ip_range_whitelisted_ip(self) -> None:
"""
Blacklisted but then subsequently whitelisted IP addresses can be
Blocked but then subsequently whitelisted IP addresses can be
spidered.
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
@ -527,10 +521,10 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
def test_blacklisted_ip_with_external_ip(self) -> None:
def test_blocked_ip_with_external_ip(self) -> None:
"""
If a hostname resolves a blacklisted IP, even if there's a
non-blacklisted one, it will be rejected.
If a hostname resolves a blocked IP, even if there's a non-blocked one,
it will be rejected.
"""
# Hardcode the URL resolving to the IP we want.
self.lookups["example.com"] = [
@ -550,9 +544,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ipv6_specific(self) -> None:
def test_blocked_ipv6_specific(self) -> None:
"""
Blacklisted IP addresses, found via DNS, are not spidered.
Blocked IP addresses, found via DNS, are not spidered.
"""
self.lookups["example.com"] = [
(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
@ -573,9 +567,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ipv6_range(self) -> None:
def test_blocked_ipv6_range(self) -> None:
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
Blocked IP ranges, IPs found over DNS, are not spidered.
"""
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
@ -653,6 +647,57 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
def test_image(self) -> None:
"""An image should be precached if mentioned in the HTML."""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")]
result = (
b"""<html><body><img src="http://cdn.matrix.org/foo.png"></body></html>"""
)
channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
self.pump()
# Respond with the HTML.
client = self.reactor.tcpClients[0][2].buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
(
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
% (len(result),)
+ result
)
self.pump()
# Respond with the photo.
client = self.reactor.tcpClients[1][2].buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
(
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
b"Content-Type: image/png\r\n\r\n"
)
% (len(SMALL_PNG),)
+ SMALL_PNG
)
self.pump()
# The image should be in the result.
self.assertEqual(channel.code, 200)
self._assert_small_png(channel.json_body)
def test_nonexistent_image(self) -> None:
"""If the preview image doesn't exist, ensure some data is returned."""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
@ -683,9 +728,53 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self.pump()
self.assertEqual(channel.code, 200)
# There should not be a second connection.
self.assertEqual(len(self.reactor.tcpClients), 1)
# The image should not be in the result.
self.assertEqual(channel.code, 200)
self.assertNotIn("og:image", channel.json_body)
@unittest.override_config(
{"url_preview_url_blacklist": [{"netloc": "cdn.matrix.org"}]}
)
def test_image_blocked(self) -> None:
"""If the preview image doesn't exist, ensure some data is returned."""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")]
result = (
b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
)
channel = self.make_request(
"GET",
"preview_url?url=http://matrix.org",
shorthand=False,
await_result=False,
)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
(
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
% (len(result),)
+ result
)
self.pump()
# There should not be a second connection.
self.assertEqual(len(self.reactor.tcpClients), 1)
# The image should not be in the result.
self.assertEqual(channel.code, 200)
self.assertNotIn("og:image", channel.json_body)
def test_oembed_failure(self) -> None:
@ -880,6 +969,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self.pump()
# Double check that the proper host is being connected to. (Note that
# twitter.com can't be resolved so this is already implicitly checked.)
self.assertIn(b"\r\nHost: publish.twitter.com\r\n", server.data)
self.assertEqual(channel.code, 200)
body = channel.json_body
self.assertEqual(
@ -940,6 +1034,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
@unittest.override_config(
{"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]}
)
def test_oembed_blocked(self) -> None:
"""The oEmbed URL should not be downloaded if the oEmbed URL is blocked."""
self.lookups["twitter.com"] = [(IPv4Address, "10.1.2.3")]
channel = self.make_request(
"GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
await_result=False,
)
self.pump()
self.assertEqual(channel.code, 403, channel.result)
def test_oembed_autodiscovery(self) -> None:
"""
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
@ -980,7 +1090,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
% (len(result),)
+ result
)
self.pump()
# The oEmbed response.
@ -1004,7 +1113,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
% (len(oembed_content),)
+ oembed_content
)
self.pump()
# Ensure the URL is what was requested.
@ -1023,7 +1131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
% (len(SMALL_PNG),)
+ SMALL_PNG
)
self.pump()
# Ensure the URL is what was requested.
@ -1036,6 +1143,59 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self._assert_small_png(body)
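# A rough sketch of the autodiscovery step the docstring describes: find the
# "application/json+oembed" <link> in the returned HTML and follow its href.
# (lxml and the helper name are assumptions for illustration.)
from typing import Optional
from lxml import etree

def find_oembed_link(html: bytes) -> Optional[str]:
    tree = etree.fromstring(html, parser=etree.HTMLParser())
    if tree is None:
        return None
    for link in tree.iterfind(".//link"):
        if link.get("type") == "application/json+oembed":
            return link.get("href")
    return None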
@unittest.override_config(
{"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]}
)
def test_oembed_autodiscovery_blocked(self) -> None:
"""
If the discovered oEmbed URL is blocked, it should be discarded.
"""
# This is a little cheesy in that we use the www subdomain (which isn't in
# the list of oEmbed patterns) to get a "raw" HTML response.
self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")]
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.4")]
result = b"""
<title>Test</title>
<link rel="alternate" type="application/json+oembed"
href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json"
title="matrixdotorg" />
"""
channel = self.make_request(
"GET",
"preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
shorthand=False,
await_result=False,
)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
(
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
% (len(result),)
+ result
)
self.pump()
# Ensure there's no additional connections.
self.assertEqual(len(self.reactor.tcpClients), 1)
# Ensure the URL is what was requested.
self.assertIn(b"\r\nHost: www.twitter.com\r\n", server.data)
self.assertEqual(channel.code, 200)
body = channel.json_body
self.assertEqual(body["og:title"], "Test")
self.assertNotIn("og:image", body)
def _download_image(self) -> Tuple[str, str]:
"""Downloads an image into the URL cache.
Returns:
@ -1192,8 +1352,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
@unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]})
def test_blocked_port(self) -> None:
"""Tests that blocking URLs with a port makes previewing such URLs
fail with a 403 error and doesn't impact other previews.
"""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
@ -1230,3 +1390,23 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
@unittest.override_config(
{"url_preview_url_blacklist": [{"netloc": "example.com"}]}
)
def test_blocked_url(self) -> None:
"""Tests that blocking URLs with a host makes previewing such URLs
fail with a 403 error.
"""
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
bad_url = quote("http://example.com/foo")
channel = self.make_request(
"GET",
"preview_url?url=" + bad_url,
shorthand=False,
await_result=False,
)
self.pump()
self.assertEqual(channel.code, 403, channel.result)
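# Note the percent-encoding above: the target URL has to be safe to embed in
# a query string. A quick illustration:
from urllib.parse import quote

assert quote("http://example.com/foo") == "http%3A//example.com/foo"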

View File

@ -14,6 +14,8 @@
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.types import UserID
from synapse.util import Clock
@ -69,3 +71,64 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.assertIsNone(
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
)
def test_profiles_bg_migration(self) -> None:
"""
Test background job that copies entries from column user_id to full_user_id, adding
the hostname in the process.
"""
updater = self.hs.get_datastores().main.db_pool.updates
# drop the constraint so we can insert nulls in full_user_id to populate the test
if isinstance(self.store.database_engine, PostgresEngine):
def f(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE profiles DROP CONSTRAINT full_user_id_not_null"
)
self.get_success(self.store.db_pool.runInteraction("", f))
for i in range(0, 70):
self.get_success(
self.store.db_pool.simple_insert(
"profiles",
{"user_id": f"hello{i:02}"},
)
)
# re-add the constraint so that when it's validated it actually exists
if isinstance(self.store.database_engine, PostgresEngine):
def f(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE profiles ADD CONSTRAINT full_user_id_not_null CHECK (full_user_id IS NOT NULL) NOT VALID"
)
self.get_success(self.store.db_pool.runInteraction("", f))
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
values={
"update_name": "populate_full_user_id_profiles",
"progress_json": "{}",
},
)
)
self.get_success(
updater.run_background_updates(False),
)
expected_values = []
for i in range(0, 70):
expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
res = self.get_success(
self.store.db_pool.execute(
"", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
)
)
self.assertEqual(len(res), len(expected_values))
self.assertEqual(res, expected_values)
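# The background update exercised above effectively rewrites each row's
# full_user_id as "@<user_id>:<hostname>". A minimal sketch of that backfill
# (the real update batches its work and records progress; this one-shot
# version is illustrative only):
def backfill_full_user_ids(txn: LoggingTransaction, hostname: str) -> None:
    # '||' string concatenation works on both SQLite and Postgres.
    txn.execute(
        "UPDATE profiles"
        " SET full_user_id = '@' || user_id || ':' || ?"
        " WHERE full_user_id IS NULL",
        (hostname,),
    )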

View File

@ -0,0 +1,94 @@
# Copyright 2023 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.util import Clock
from tests import unittest
class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
"""
Test background migration that copies entries from column user_id to full_user_id, adding
the hostname in the process.
"""
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
def test_bg_migration(self) -> None:
updater = self.hs.get_datastores().main.db_pool.updates
# drop the constraint so we can insert nulls in full_user_id to populate the test
if isinstance(self.store.database_engine, PostgresEngine):
def f(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE user_filters DROP CONSTRAINT full_user_id_not_null"
)
self.get_success(self.store.db_pool.runInteraction("", f))
for i in range(0, 70):
self.get_success(
self.store.db_pool.simple_insert(
"user_filters",
{
"user_id": f"hello{i:02}",
"filter_id": i,
"filter_json": bytearray(i),
},
)
)
# re-add the constraint so that when it's validated it actually exists
if isinstance(self.store.database_engine, PostgresEngine):
def f(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE user_filters ADD CONSTRAINT full_user_id_not_null CHECK (full_user_id IS NOT NULL) NOT VALID"
)
self.get_success(self.store.db_pool.runInteraction("", f))
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
values={
"update_name": "populate_full_user_id_user_filters",
"progress_json": "{}",
},
)
)
self.get_success(
updater.run_background_updates(False),
)
expected_values = []
for i in range(0, 70):
expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
res = self.get_success(
self.store.db_pool.execute(
"", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
)
)
self.assertEqual(len(res), len(expected_values))
self.assertEqual(res, expected_values)
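# Both tests re-add the constraint with NOT VALID so the deliberately-null
# rows don't make the ALTER fail. Once the backfill has run, Postgres can be
# asked to check the existing rows as well (illustrative; the real migration
# arranges this itself):
def validate_constraint(txn: LoggingTransaction) -> None:
    txn.execute(
        "ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null"
    )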

View File

@ -264,7 +264,7 @@ class StateTestCase(unittest.TestCase):
self.dummy_store.register_events(graph.walk())
context_store: Dict[str, EventContext] = {}
for event in graph.walk():
context = yield defer.ensureDeferred(

View File

@ -229,13 +229,20 @@ class TestCase(unittest.TestCase):
#
# The easiest way to do this would be to do a full GC after each test
# run, but that is very expensive. Instead, we disable GC (above) for
# the duration of the test and only run a gen-0 GC, which is a lot
# quicker. This doesn't clean up everything, since the TestCase
# instance still holds references to objects created during the test,
# such as HomeServers, so we do a full GC every so often.
@around(self)
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
gc.collect(0)
# Run a full GC every 50 gen-0 GCs.
gen0_stats = gc.get_stats()[0]
gen0_collections = gen0_stats["collections"]
if gen0_collections % 50 == 0:
gc.collect()
gc.enable()
set_current_context(SENTINEL_CONTEXT)
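# The scheme above in isolation: gen-0 collections are cheap enough to run
# after every test, while a full pass every 50th collection reclaims cycles
# that survived into older generations (a standalone sketch, not test code):
import gc

def collect_after_test() -> None:
    gc.collect(0)  # cheap: youngest generation only
    # gc.get_stats()[0] is the gen-0 counter dict, with a "collections" key.
    if gc.get_stats()[0]["collections"] % 50 == 0:
        gc.collect()  # full collection, every 50th time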