Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

pull/13598/head
Andrew Morgan 2022-08-22 10:47:30 +01:00
commit 80bf6da876
27 changed files with 232 additions and 129 deletions

changelog.d/13543.misc Normal file
View File

@@ -0,0 +1 @@
Reduce the number of tests using legacy TCP replication.

View File

@@ -0,0 +1 @@
Add an experimental implementation for [MSC3852](https://github.com/matrix-org/matrix-spec-proposals/pull/3852).

changelog.d/13549.misc Normal file
View File

@@ -0,0 +1 @@
Allow specifying additional request fields when using the `HomeServerTestCase.login` helper method.

View File

@@ -0,0 +1 @@
Add `org.matrix.msc2716v4` experimental room version with updated content fields.

changelog.d/13558.misc Normal file
View File

@@ -0,0 +1 @@
Make `HomeServerTestCase` load any configured homeserver modules automatically.

changelog.d/13574.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix the `opentracing.force_tracing_for_users` config option not applying to [`/sendToDevice`](https://spec.matrix.org/v1.3/client-server-api/#put_matrixclientv3sendtodeviceeventtypetxnid) and [`/keys/upload`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3keysupload) requests.

View File

@@ -753,6 +753,7 @@ A response body like the following is returned:
"device_id": "QBUAZIFURK",
"display_name": "android",
"last_seen_ip": "1.2.3.4",
"last_seen_user_agent": "Mozilla/5.0 (X11; Linux x86_64; rv:103.0) Gecko/20100101 Firefox/103.0",
"last_seen_ts": 1474491775024,
"user_id": "<user_id>"
},
@@ -760,6 +761,7 @@ A response body like the following is returned:
"device_id": "AUIECTSRND",
"display_name": "ios",
"last_seen_ip": "1.2.3.5",
"last_seen_user_agent": "Mozilla/5.0 (X11; Linux x86_64; rv:103.0) Gecko/20100101 Firefox/103.0",
"last_seen_ts": 1474491775025,
"user_id": "<user_id>"
}
@@ -786,6 +788,8 @@ The following fields are returned in the JSON response body:
Absent if no name has been set.
- `last_seen_ip` - The IP address where this device was last seen.
(May be a few minutes out of date, for efficiency reasons).
- `last_seen_user_agent` - The user agent of the device when it was last seen.
(May be a few minutes out of date, for efficiency reasons).
- `last_seen_ts` - The timestamp (in milliseconds since the unix epoch) when this
devices was last seen. (May be a few minutes out of date, for efficiency reasons).
- `user_id` - Owner of device.
@@ -837,6 +841,7 @@ A response body like the following is returned:
"device_id": "<device_id>",
"display_name": "android",
"last_seen_ip": "1.2.3.4",
"last_seen_user_agent": "Mozilla/5.0 (X11; Linux x86_64; rv:103.0) Gecko/20100101 Firefox/103.0",
"last_seen_ts": 1474491775024,
"user_id": "<user_id>"
}
@@ -858,6 +863,8 @@ The following fields are returned in the JSON response body:
Absent if no name has been set.
- `last_seen_ip` - The IP address where this device was last seen.
(May be a few minutes out of date, for efficiency reasons).
- `last_seen_user_agent` - The user agent of the device when it was last seen.
(May be a few minutes out of date, for efficiency reasons).
- `last_seen_ts` - The timestamp (in milliseconds since the unix epoch) when this
devices was last seen. (May be a few minutes out of date, for efficiency reasons).
- `user_id` - Owner of device.
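For context, a minimal sketch of reading the new field through this Admin API from a script; the endpoint path follows the documentation above, while the homeserver URL and admin token are placeholders:

import requests

BASE_URL = "https://homeserver.example.com"  # placeholder homeserver
ADMIN_TOKEN = "<admin_access_token>"  # placeholder admin access token

def list_devices(user_id: str) -> None:
    # Query the admin devices endpoint documented above.
    resp = requests.get(
        f"{BASE_URL}/_synapse/admin/v2/users/{user_id}/devices",
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    )
    resp.raise_for_status()
    for device in resp.json()["devices"]:
        # `last_seen_user_agent` is the field added by this change; it may be
        # missing or null if the device has not been seen recently.
        print(
            device["device_id"],
            device.get("last_seen_ip"),
            device.get("last_seen_user_agent"),
        )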

View File

@@ -216,11 +216,11 @@ class EventContentFields:
MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
# For "insertion" events to indicate what the next batch ID should be in
# order to connect to it
-MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
MSC2716_NEXT_BATCH_ID: Final = "next_batch_id"
# Used on "batch" events to indicate which insertion event it connects to
-MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
MSC2716_BATCH_ID: Final = "batch_id"
# For "marker" events
-MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
MSC2716_INSERTION_EVENT_REFERENCE: Final = "insertion_event_reference"
# The authorising user for joining a restricted room.
AUTHORISING_USER: Final = "join_authorised_via_users_server"

View File

@@ -269,24 +269,6 @@ class RoomVersions:
msc3787_knock_restricted_join_rule=False,
msc3667_int_only_power_levels=False,
)
MSC2716v3 = RoomVersion(
"org.matrix.msc2716v3",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc3375_redaction_rules=False,
msc2403_knocking=True,
msc2716_historical=True,
msc2716_redactions=True,
msc3787_knock_restricted_join_rule=False,
msc3667_int_only_power_levels=False,
)
MSC3787 = RoomVersion(
"org.matrix.msc3787",
RoomDisposition.UNSTABLE,
@@ -323,6 +305,24 @@
msc3787_knock_restricted_join_rule=True,
msc3667_int_only_power_levels=True,
)
MSC2716v4 = RoomVersion(
"org.matrix.msc2716v4",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc3375_redaction_rules=False,
msc2403_knocking=True,
msc2716_historical=True,
msc2716_redactions=True,
msc3787_knock_restricted_join_rule=False,
msc3667_int_only_power_levels=False,
)
KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
@@ -338,9 +338,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V7,
RoomVersions.V8,
RoomVersions.V9,
-RoomVersions.MSC2716v3,
RoomVersions.MSC3787,
RoomVersions.V10,
RoomVersions.MSC2716v4,
)
}

View File

@@ -90,3 +90,6 @@ class ExperimentalConfig(Config):
# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
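As a rough illustration of how the new flag is consumed, the sketch below mirrors the read above and assumes the same `experimental_features` section of homeserver.yaml that the other MSC flags use; the helper name is made up for illustration:

from typing import Any, Dict

def read_msc3852_flag(config: Dict[str, Any]) -> bool:
    # Hypothetical standalone helper mirroring ExperimentalConfig above.
    experimental = config.get("experimental_features") or {}
    return bool(experimental.get("msc3852_enabled", False))

# A homeserver.yaml containing:
#   experimental_features:
#     msc3852_enabled: true
# makes /_matrix/client/v3/devices responses include the unstable
# org.matrix.msc3852.last_seen_user_agent field (see the servlet change below).
assert read_msc3852_flag({"experimental_features": {"msc3852_enabled": True}}) is True
assert read_msc3852_flag({}) is False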

View File

@@ -161,7 +161,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
add_fields(EventContentFields.MSC2716_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
-add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
add_fields(EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE)
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}

View File

@@ -74,6 +74,7 @@ class DeviceWorkerHandler:
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -747,7 +748,13 @@ def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
-device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
device.update(
{
"last_seen_user_agent": ip.get("user_agent"),
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
}
)
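A small self-contained sketch of what the updated helper produces, with illustrative values; the `client_ips` mapping shape follows the signature in this hunk:

from typing import Any, Dict, Mapping, Tuple

def update_device_from_client_ips(
    device: Dict[str, Any],
    client_ips: Mapping[Tuple[str, str], Mapping[str, Any]],
) -> None:
    # Same logic as above: look up (user_id, device_id) and copy the
    # last-seen metadata, now including the user agent.
    ip = client_ips.get((device["user_id"], device["device_id"]), {})
    device.update(
        {
            "last_seen_user_agent": ip.get("user_agent"),
            "last_seen_ts": ip.get("last_seen"),
            "last_seen_ip": ip.get("ip"),
        }
    )

device = {"user_id": "@alice:example.com", "device_id": "QBUAZIFURK"}
update_device_from_client_ips(
    device,
    {
        ("@alice:example.com", "QBUAZIFURK"): {
            "user_agent": "ExampleClient/1.0",
            "last_seen": 1474491775024,
            "ip": "1.2.3.4",
        }
    },
)
assert device["last_seen_user_agent"] == "ExampleClient/1.0"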
class DeviceListUpdater:

View File

@@ -1384,7 +1384,7 @@ class FederationEventHandler:
logger.debug("_handle_marker_event: received %s", marker_event)
insertion_event_id = marker_event.content.get(
-EventContentFields.MSC2716_MARKER_INSERTION
EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE
)
if insertion_event_id is None:

View File

@@ -42,12 +42,26 @@ class DevicesRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
)
# If MSC3852 is disabled, then the "last_seen_user_agent" field will be
# removed from each device. If it is enabled, then the field name will
# be replaced by the unstable identifier.
#
# When MSC3852 is accepted, this block of code can just be removed to
# expose "last_seen_user_agent" to clients.
for device in devices:
last_seen_user_agent = device["last_seen_user_agent"]
del device["last_seen_user_agent"]
if self._msc3852_enabled:
device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
return 200, {"devices": devices}
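To make the renaming behaviour described in the comment above concrete, here is a standalone illustrative sketch (not part of the change) of the transformation applied to each device entry for both `/devices` and `/devices/{device_id}`:

from typing import Any, Dict

def apply_msc3852_rename(device: Dict[str, Any], msc3852_enabled: bool) -> Dict[str, Any]:
    # Mirrors the servlet logic: the stable name is always stripped, and the
    # unstable org.matrix.msc3852.* name is only exposed when the flag is on.
    device = dict(device)
    last_seen_user_agent = device.pop("last_seen_user_agent")
    if msc3852_enabled:
        device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
    return device

device = {
    "device_id": "QBUAZIFURK",
    "last_seen_ip": "1.2.3.4",
    "last_seen_user_agent": "Mozilla/5.0 ...",
    "last_seen_ts": 1474491775024,
}
assert "last_seen_user_agent" not in apply_msc3852_rename(device, False)
assert "org.matrix.msc3852.last_seen_user_agent" in apply_msc3852_rename(device, True)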
@@ -108,6 +122,7 @@ class DeviceRestServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
async def on_GET(
self, request: SynapseRequest, device_id: str
@@ -118,6 +133,18 @@ class DeviceRestServlet(RestServlet):
)
if device is None:
raise NotFoundError("No device found")
# If MSC3852 is disabled, then the "last_seen_user_agent" field will be
# removed from each device. If it is enabled, then the field name will
# be replaced by the unstable identifier.
#
# When MSC3852 is accepted, this block of code can just be removed to
# expose "last_seen_user_agent" to clients.
last_seen_user_agent = device["last_seen_user_agent"]
del device["last_seen_user_agent"]
if self._msc3852_enabled:
device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
return 200, device
@interactive_auth_handler

View File

@@ -26,7 +26,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.logging.opentracing import log_kv, set_tag
from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
@@ -71,7 +71,6 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
-@trace_with_opname("upload_keys")
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:

View File

@@ -19,7 +19,7 @@ from synapse.http import servlet
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace_with_opname
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
@@ -43,7 +43,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
-@trace_with_opname("sendToDevice")
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:

View File

@@ -141,10 +141,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(
federation_transport_client=fed_transport_client,
)
# Load the modules into the homeserver
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
module(config=config, api=module_api)
load_legacy_presence_router(hs)

View File

@@ -21,7 +21,6 @@ from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
-from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -167,16 +166,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
# Load the modules into the homeserver
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
module(config=config, api=module_api)
load_legacy_password_auth_providers(hs)
return hs
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
self.password_only_auth_provider_login_test_body()

View File

@@ -22,7 +22,6 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
-from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -144,12 +143,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
config=hs_config, federation_client=self.mock_federation_client
)
load_legacy_spam_checkers(hs)
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
module(config=config, api=module_api)
return hs
def prepare(self, reactor, clock, hs):

View File

@@ -14,7 +14,7 @@ from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util import Clock
-from tests.replication._base import RedisMultiWorkerStreamTestCase
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -216,7 +216,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
# - trying to remote-join again.
-class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,

View File

@@ -30,7 +30,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-from tests.utils import USE_POSTGRES_FOR_TESTS
class ModuleApiTestCase(HomeserverTestCase):
@@ -738,11 +737,6 @@ class ModuleApiTestCase(HomeserverTestCase):
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
# Testing stream ID replication from the main to worker processes requires postgres
# (due to needing `MultiWriterIdGenerator`).
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -752,7 +746,6 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
-conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
"presence_writer": {"host": "testserv", "port": 1001},

View File

@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
-ReplicationStreamProtocolFactory,
-ServerReplicationStreamProtocol,
-)
from synapse.replication.tcp.protocol import (
ClientReplicationStreamProtocol,
ServerReplicationStreamProtocol,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
Enables Redis, providing a fake Redis server.
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""
if not hiredis:
skip = "Requires hiredis"
if not USE_POSTGRES_FOR_TESTS:
# Redis replication only takes place on Postgres
skip = "Requires Postgres"
def default_config(self) -> Dict[str, Any]:
"""
Overrides the default config to enable Redis.
Even if the test only uses make_worker_hs, the main process needs Redis
enabled otherwise it won't create a Fake Redis server to listen on the
Redis port and accept fake TCP connections.
"""
base = super().default_config()
base["redis"] = {"enabled": True}
return base
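For orientation, a minimal sketch of a test built on this consolidated base class; the worker app name follows the existing multi-worker tests in this repository, but treat the exact arguments as illustrative:

from tests.replication._base import BaseMultiWorkerStreamTestCase

class ExampleWorkerTestCase(BaseMultiWorkerStreamTestCase):
    # Redis is enabled by default_config() above, and the hiredis/Postgres
    # guards skip this wherever fake-Redis replication cannot run.
    def test_worker_connects_to_replication(self) -> None:
        # make_worker_hs() wires the worker up to the fake Redis server and
        # the main process automatically.
        worker_hs = self.make_worker_hs("synapse.app.generic_worker")
        self.assertIsNotNone(worker_hs.get_replication_command_handler())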
def setUp(self):
super().setUp()
# build a replication server
-self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
-if self.hs.config.redis.redis_enabled:
-# Handle attempts to connect to fake redis server.
-self.reactor.add_tcp_client_callback(
-"localhost",
-6379,
-self.connect_any_redis_attempts,
-)
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
6379,
self.connect_any_redis_attempts,
)
self.hs.get_replication_command_handler().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
store = worker_hs.get_datastores().main
store.db_pool._db_pool = self.database_pool._db_pool
# Set up TCP replication between master and the new worker if we don't
# have Redis support enabled.
if not worker_hs.config.redis.redis_enabled:
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs,
"client",
"test",
self.clock,
repl_handler,
)
server = self.server_factory.buildProtocol(
IPv4Address("TCP", "127.0.0.1", 0)
)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
# Set up a resource for the worker
resource = ReplicationRestResource(worker_hs)
@@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
reactor=self.reactor,
)
-if worker_hs.config.redis.redis_enabled:
-worker_hs.get_replication_command_handler().start_replication(worker_hs)
worker_hs.get_replication_command_handler().start_replication(worker_hs)
return worker_hs
@@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
"""
A test case that enables Redis, providing a fake Redis server.
"""
if not hiredis:
skip = "Requires hiredis"
if not USE_POSTGRES_FOR_TESTS:
# Redis replication only takes place on Postgres
skip = "Requires Postgres"
def default_config(self) -> Dict[str, Any]:
"""
Overrides the default config to enable Redis.
Even if the test only uses make_worker_hs, the main process needs Redis
enabled otherwise it won't create a Fake Redis server to listen on the
Redis port and accept fake TCP connections.
"""
base = super().default_config()
base["redis"] = {"enabled": True}
return base

View File

@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests.replication._base import RedisMultiWorkerStreamTestCase
from tests.replication._base import BaseMultiWorkerStreamTestCase
-class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(

View File

@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
-from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@@ -28,11 +27,6 @@ logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks event persisting sharding works"""
# Event persister sharding requires postgres (due to needing
# `MultiWriterIdGenerator`).
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
-conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
"worker1": {"host": "testserv", "port": 1001},

View File

@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -904,6 +904,96 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
class UserDevicesTestCase(unittest.HomeserverTestCase):
"""
Tests user device management-related Admin APIs.
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
]
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
# Set up an Admin user to query the Admin API with.
self.admin_user_id = self.register_user("admin", "pass", admin=True)
self.admin_user_token = self.login("admin", "pass")
# Set up a test user to query the devices of.
self.other_user_device_id = "TESTDEVICEID"
self.other_user_device_display_name = "My Test Device"
self.other_user_client_ip = "1.2.3.4"
self.other_user_user_agent = "EquestriaTechnology/123.0"
self.other_user_id = self.register_user("user", "pass", displayname="User1")
self.other_user_token = self.login(
"user",
"pass",
device_id=self.other_user_device_id,
additional_request_fields={
"initial_device_display_name": self.other_user_device_display_name,
},
)
# Have the "other user" make a request so that the "last_seen_*" fields are
# populated in the tests below.
channel = self.make_request(
"GET",
"/_matrix/client/v3/sync",
access_token=self.other_user_token,
client_ip=self.other_user_client_ip,
custom_headers=[
("User-Agent", self.other_user_user_agent),
],
)
self.assertEqual(200, channel.code, msg=channel.json_body)
def test_list_user_devices(self) -> None:
"""
Tests that a user's devices and attributes are listed correctly via the Admin API.
"""
# Request all devices of "other user"
channel = self.make_request(
"GET",
f"/_synapse/admin/v2/users/{self.other_user_id}/devices",
access_token=self.admin_user_token,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Double-check we got the single device expected
user_devices = channel.json_body["devices"]
self.assertEqual(len(user_devices), 1)
self.assertEqual(channel.json_body["total"], 1)
# Check that all the attributes of the device reported are as expected.
self._validate_attributes_of_device_response(user_devices[0])
# Request just a single device for "other user" by its ID
channel = self.make_request(
"GET",
f"/_synapse/admin/v2/users/{self.other_user_id}/devices/"
f"{self.other_user_device_id}",
access_token=self.admin_user_token,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that all the attributes of the device reported are as expected.
self._validate_attributes_of_device_response(channel.json_body)
def _validate_attributes_of_device_response(self, response: JsonDict) -> None:
# Check that all device expected attributes are present
self.assertEqual(response["user_id"], self.other_user_id)
self.assertEqual(response["device_id"], self.other_user_device_id)
self.assertEqual(response["display_name"], self.other_user_device_display_name)
self.assertEqual(response["last_seen_ip"], self.other_user_client_ip)
self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent)
self.assertIsInstance(response["last_seen_ts"], int)
self.assertGreater(response["last_seen_ts"], 0)
class DeactivateAccountTestCase(unittest.HomeserverTestCase):
servlets = [

View File

@@ -61,6 +61,10 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
@@ -913,4 +917,14 @@ def setup_test_homeserver(
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
module(config=config, api=module_api)
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
load_legacy_password_auth_providers(hs)
return hs
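With modules now loaded here automatically, a test-only module such as the hypothetical sketch below is instantiated without per-test boilerplate; the `(config, api)` constructor is the standard Synapse module interface, while the class itself is made up for illustration:

from synapse.module_api import ModuleApi

class ExampleTestModule:
    # Hypothetical module: Synapse passes the parsed module config and a
    # ModuleApi handle, exactly as the loop above does for each entry in
    # hs.config.modules.loaded_modules.
    def __init__(self, config: dict, api: ModuleApi) -> None:
        self.config = config
        self.api = api
        # A real module would register callbacks here, e.g.
        # api.register_spam_checker_callbacks(...).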

View File

@@ -677,14 +677,29 @@ class HomeserverTestCase(TestCase):
username: str,
password: str,
device_id: Optional[str] = None,
additional_request_fields: Optional[Dict[str, str]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
) -> str:
"""
Log in a user, and get an access token. Requires the Login API be registered.
Args:
username: The localpart to assign to the new user.
password: The password to assign to the new user.
device_id: An optional device ID to assign to the new device created during
login.
additional_request_fields: A dictionary containing any additional /login
request fields and their values.
custom_headers: Custom HTTP headers and values to add to the /login request.
Returns:
The newly registered user's Matrix ID.
""" """
body = {"type": "m.login.password", "user": username, "password": password}
if device_id:
body["device_id"] = device_id
if additional_request_fields:
body.update(additional_request_fields)
channel = self.make_request(
"POST",