From 5cbe8d93fe08ff347730596475a8a50b977754d7 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 27 Nov 2020 10:49:38 +0000 Subject: [PATCH 01/19] Add typing to membership Replication class methods (#8809) This PR grew out of #6739, and adds typing to some method arguments You'll notice that there are a lot of `# type: ignores` in here. This is due to the base methods not matching the overloads here. This is necessary to stop mypy complaining, but a better solution is #8828. --- changelog.d/8809.misc | 1 + synapse/replication/http/membership.py | 66 +++++++++++++++++--------- 2 files changed, 45 insertions(+), 22 deletions(-) create mode 100644 changelog.d/8809.misc diff --git a/changelog.d/8809.misc b/changelog.d/8809.misc new file mode 100644 index 0000000000..bbf83cf18d --- /dev/null +++ b/changelog.d/8809.misc @@ -0,0 +1 @@ +Remove unnecessary function arguments and add typing to several membership replication classes. \ No newline at end of file diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index f0c37eaf5e..84e002f934 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple + +from twisted.web.http import Request from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload( - requester, room_id, user_id, remote_room_hosts, content - ): + async def _serialize_payload( # type: ignore + requester: Requester, + room_id: str, + user_id: str, + remote_room_hosts: List[str], + content: JsonDict, + ) -> JsonDict: """ Args: - requester(Requester) - room_id (str) - user_id (str) - remote_room_hosts (list[str]): Servers to try and join via - content(dict): The event content to use for the join event + requester: The user making the request according to the access token + room_id: The ID of the room. + user_id: The ID of the user. + remote_room_hosts: Servers to try and join via + content: The event content to use for the join event + + Returns: + A dict representing the payload of the request. """ return { "requester": requester.serialize(), @@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request(self, request, room_id, user_id): + async def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): txn_id: Optional[str], requester: Requester, content: JsonDict, - ): + ) -> JsonDict: """ Args: - invite_event_id: ID of the invite to be rejected - txn_id: optional transaction ID supplied by the client - requester: user making the rejection request, according to the access token - content: additional content to include in the rejection event. + invite_event_id: The ID of the invite to be rejected. + txn_id: Optional transaction ID supplied by the client + requester: User making the rejection request, according to the access token + content: Additional content to include in the rejection event. Normally an empty dict. + + Returns: + A dict representing the payload of the request. """ return { "txn_id": txn_id, @@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request(self, request, invite_event_id): + async def _handle_request( # type: ignore + self, request: Request, invite_event_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) txn_id = content["txn_id"] @@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): self.distributor = hs.get_distributor() @staticmethod - async def _serialize_payload(room_id, user_id, change): + async def _serialize_payload( # type: ignore + room_id: str, user_id: str, change: str + ) -> JsonDict: """ Args: - room_id (str) - user_id (str) - change (str): "left" + room_id: The ID of the room. + user_id: The ID of the user. + change: "left" + + Returns: + A dict representing the payload of the request. """ assert change == "left" return {} - def _handle_request(self, request, room_id, user_id, change): + def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str, change: str + ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) user = UserID.from_string(user_id) From 856eab606b17bcd78549ff12e9d41a62d441af07 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 27 Nov 2020 14:37:55 +0200 Subject: [PATCH 02/19] Remove special case of pretty printing JSON responses for curl (#8833) * Remove special case of pretty printing JSON responses for curl Signed-off-by: Tulir Asokan --- changelog.d/8833.removal | 1 + synapse/http/server.py | 29 +++++------------------------ 2 files changed, 6 insertions(+), 24 deletions(-) create mode 100644 changelog.d/8833.removal diff --git a/changelog.d/8833.removal b/changelog.d/8833.removal new file mode 100644 index 0000000000..5c2d195f94 --- /dev/null +++ b/changelog.d/8833.removal @@ -0,0 +1 @@ +Disable pretty printing JSON responses for curl. Users who want pretty-printed output should use [jq](https://stedolan.github.io/jq/) in combination with curl. Contributed by @tulir. diff --git a/synapse/http/server.py b/synapse/http/server.py index c0919f8cb7..46cdfad04a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -25,7 +25,7 @@ from io import BytesIO from typing import Any, Callable, Dict, Iterator, List, Tuple, Union import jinja2 -from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json +from canonicaljson import iterencode_canonical_json from zope.interface import implementer from twisted.internet import defer, interfaces @@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: pass else: respond_with_json( - request, - error_code, - error_dict, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), + request, error_code, error_dict, send_cors=True, ) @@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource): code, response_object, send_cors=True, - pretty_print=_request_user_agent_is_curl(request), canonical_json=self.canonical_json, ) @@ -587,7 +582,6 @@ def respond_with_json( code: int, json_object: Any, send_cors: bool = False, - pretty_print: bool = False, canonical_json: bool = True, ): """Sends encoded JSON in response to the given request. @@ -598,8 +592,6 @@ def respond_with_json( json_object: The object to serialize to JSON. send_cors: Whether to send Cross-Origin Resource Sharing headers https://fetch.spec.whatwg.org/#http-cors-protocol - pretty_print: Whether to include indentation and line-breaks in the - resulting JSON bytes. canonical_json: Whether to use the canonicaljson algorithm when encoding the JSON bytes. @@ -615,13 +607,10 @@ def respond_with_json( ) return None - if pretty_print: - encoder = iterencode_pretty_printed_json + if canonical_json: + encoder = iterencode_canonical_json else: - if canonical_json: - encoder = iterencode_canonical_json - else: - encoder = _encode_json_bytes + encoder = _encode_json_bytes request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -759,11 +748,3 @@ def finish_request(request: Request): request.finish() except RuntimeError as e: logger.info("Connection disconnected before response was written: %r", e) - - -def _request_user_agent_is_curl(request: Request) -> bool: - user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) - for user_agent in user_agents: - if b"curl" in user_agent: - return True - return False From a090b86209c01e42415df7736eb26c8adbe2aba0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 30 Nov 2020 16:48:12 +0000 Subject: [PATCH 03/19] Add `force_purge` option to delete-room admin api. (#8843) --- changelog.d/8843.feature | 1 + docs/admin_api/rooms.md | 6 +++++- synapse/handlers/pagination.py | 17 +++++++++++------ synapse/rest/admin/rooms.py | 22 +++++++++++++++++----- 4 files changed, 34 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8843.feature diff --git a/changelog.d/8843.feature b/changelog.d/8843.feature new file mode 100644 index 0000000000..824d46d5aa --- /dev/null +++ b/changelog.d/8843.feature @@ -0,0 +1 @@ +Add `force_purge` option to delete-room admin api. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 0c05b0ed55..004a802e17 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -382,7 +382,7 @@ the new room. Users on other servers will be unaffected. The API is: -```json +``` POST /_synapse/admin/v1/rooms//delete ``` @@ -439,6 +439,10 @@ The following JSON body parameters are available: future attempts to join the room. Defaults to `false`. * `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. Defaults to `true`. +* `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it + will force a purge to go ahead even if there are local users still in the room. Do not + use this unless a regular `purge` operation fails, as it could leave those users' + clients in a confused state. The JSON body must not be empty. The body must be at least `{}`. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 426b58da9e..5372753707 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -299,17 +299,22 @@ class PaginationHandler: """ return self._purges_by_id.get(purge_id) - async def purge_room(self, room_id: str) -> None: - """Purge the given room from the database""" + async def purge_room(self, room_id: str, force: bool = False) -> None: + """Purge the given room from the database. + + Args: + room_id: room to be purged + force: set true to skip checking for joined users. + """ with await self.pagination_lock.write(room_id): # check we know about the room await self.store.get_room_version_id(room_id) # first check that we have no users in this room - joined = await self.store.is_host_joined(room_id, self._server_name) - - if joined: - raise SynapseError(400, "Users are still joined to this room") + if not force: + joined = await self.store.is_host_joined(room_id, self._server_name) + if joined: + raise SynapseError(400, "Users are still joined to this room") await self.storage.purge_events.purge_room(room_id) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 353151169a..25f89e4685 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -70,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet): class DeleteRoomRestServlet(RestServlet): - """Delete a room from server. It is a combination and improvement of - shut down and purge room. + """Delete a room from server. + + It is a combination and improvement of shutdown and purge room. + Shuts down a room by removing all local users from the room. Blocking all future invites and joins to the room is optional. + If desired any local aliases will be repointed to a new room - created by `new_room_user_id` and kicked users will be auto + created by `new_room_user_id` and kicked users will be auto- joined to the new room. - It will remove all trace of a room from the database. + + If 'purge' is true, it will remove all traces of a room from the database. """ PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete$") @@ -110,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet): Codes.BAD_JSON, ) + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + ret = await self.room_shutdown_handler.shutdown_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), @@ -121,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet): # Purge room if purge: - await self.pagination_handler.purge_room(room_id) + await self.pagination_handler.purge_room(room_id, force=force_purge) return (200, ret) From ca60822b34416e524a80c5689734d870ceb6e130 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 30 Nov 2020 19:28:44 +0100 Subject: [PATCH 04/19] Simplify the way the `HomeServer` object caches its internal attributes. (#8565) Changes `@cache_in_self` to use underscore-prefixed attributes. --- changelog.d/8565.misc | 1 + synapse/handlers/identity.py | 3 ++- synapse/rest/client/v2_alpha/account.py | 6 ++--- synapse/rest/client/v2_alpha/register.py | 4 +-- synapse/rest/key/v2/local_key_resource.py | 2 +- synapse/server.py | 27 ++++++++++----------- tests/handlers/test_auth.py | 6 ++--- tests/replication/_base.py | 2 +- tests/rest/client/v1/test_presence.py | 17 +++++++------ tests/rest/client/v2_alpha/test_register.py | 6 ++--- tests/utils.py | 2 +- 11 files changed, 40 insertions(+), 36 deletions(-) create mode 100644 changelog.d/8565.misc diff --git a/changelog.d/8565.misc b/changelog.d/8565.misc new file mode 100644 index 0000000000..7bef422618 --- /dev/null +++ b/changelog.d/8565.misc @@ -0,0 +1 @@ +Simplify the way the `HomeServer` object caches its internal attributes. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index bc3e9607ca..9b3c6b4551 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "An error was encountered when sending the email") token_expires = ( - self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime + self.hs.get_clock().time_msec() + + self.hs.config.email_validation_token_lifetime ) await self.store.start_or_continue_validation_session( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index a54e1011f7..eebee44a44 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index ea68114026..a89ae6ddf9 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index c16280f668..d8e8e48c1c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -66,7 +66,7 @@ class LocalKey(Resource): def __init__(self, hs): self.config = hs.config - self.clock = hs.clock + self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) Resource.__init__(self) diff --git a/synapse/server.py b/synapse/server.py index c82d8f9fad..b017e3489f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -147,7 +147,8 @@ def cache_in_self(builder: T) -> T: "@cache_in_self can only be used on functions starting with `get_`" ) - depname = builder.__name__[len("get_") :] + # get_attr -> _attr + depname = builder.__name__[len("get") :] building = [False] @@ -235,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta): self._instance_id = random_string(5) self._instance_name = config.worker_name or "master" - self.clock = Clock(reactor) - self.distributor = Distributor() - - self.registration_ratelimiter = Ratelimiter( - clock=self.clock, - rate_hz=config.rc_registration.per_second, - burst_count=config.rc_registration.burst_count, - ) - self.version_string = version_string self.datastores = None # type: Optional[Databases] @@ -301,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta): def is_mine_id(self, string: str) -> bool: return string.split(":", 1)[1] == self.hostname + @cache_in_self def get_clock(self) -> Clock: - return self.clock + return Clock(self._reactor) def get_datastore(self) -> DataStore: if not self.datastores: @@ -319,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta): def get_config(self) -> HomeServerConfig: return self.config + @cache_in_self def get_distributor(self) -> Distributor: - return self.distributor + return Distributor() + @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: - return self.registration_ratelimiter + return Ratelimiter( + clock=self.get_clock(), + rate_hz=self.config.rc_registration.per_second, + burst_count=self.config.rc_registration.burst_count, + ) @cache_in_self def get_federation_client(self) -> FederationClient: @@ -687,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: - return FederationRateLimiter(self.clock, config=self.config.rc_federation) + return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation) @cache_in_self def get_module_api(self) -> ModuleApi: diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b5055e018c..e24ce81284 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase): self.fail("some_user was not in %s" % macaroon.inspect()) def test_macaroon_caveats(self): - self.hs.clock.now = 5000 + self.hs.get_clock().now = 5000 token = self.macaroon_generator.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_short_term_login_token_gives_user_id(self): - self.hs.clock.now = 1000 + self.hs.get_clock().now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) user_id = yield defer.ensureDeferred( @@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected - self.hs.clock.now = 6000 + self.hs.get_clock().now = 6000 with self.assertRaises(synapse.api.errors.AuthError): yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id(token) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 516db4c30a..295c5d58a6 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.test_handler = self._build_replication_data_handler() - self.worker_hs.replication_data_handler = self.test_handler + self.worker_hs._replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index b84f86d28c..5d5c24d01c 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -33,12 +33,15 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() - ) + presence_handler = Mock() + presence_handler.set_state.return_value = defer.succeed(None) - hs.presence_handler = Mock() - hs.presence_handler.set_state.return_value = defer.succeed(None) + hs = self.setup_test_homeserver( + "red", + http_client=None, + federation_client=Mock(), + presence_handler=presence_handler, + ) return hs @@ -55,7 +58,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.presence_handler.set_state.call_count, 1) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) def test_put_presence_disabled(self): """ @@ -70,4 +73,4 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.presence_handler.set_state.call_count, 0) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 699a40c3df..8f0c2430e8 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -569,7 +569,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do # nothing. - now = self.hs.clock.time_msec() + now = self.hs.get_clock().time_msec() self.get_success( self.store.user_add_threepid( user_id=user_id, @@ -587,7 +587,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # We need to manually add an email address otherwise the handler will do # nothing. - now = self.hs.clock.time_msec() + now = self.hs.get_clock().time_msec() self.get_success( self.store.user_add_threepid( user_id=user_id, @@ -646,7 +646,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): self.hs.config.account_validity.startup_job_max_delta = self.max_delta - now_ms = self.hs.clock.time_msec() + now_ms = self.hs.get_clock().time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) diff --git a/tests/utils.py b/tests/utils.py index acec74e9e9..c8d3ffbaba 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -271,7 +271,7 @@ def setup_test_homeserver( # Install @cache_in_self attributes for key, val in kwargs.items(): - setattr(hs, key, val) + setattr(hs, "_" + key, val) # Mock TLS hs.tls_server_context_factory = Mock() From 17fa58bdd1c23b9019d080fd98873aa5182f56c0 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 30 Nov 2020 18:43:54 +0000 Subject: [PATCH 05/19] Add a config option to change whether unread push notification counts are per-message or per-room (#8820) This PR adds a new config option to the `push` section of the homeserver config, `group_unread_count_by_room`. By default Synapse will group push notifications by room (so if you have 1000 unread messages, if they lie in 55 rooms, you'll see an unread count on your phone of 55). However, it is also useful to be able to send out the true count of unread messages if desired. If `group_unread_count_by_room` is set to `false`, then with the above example, one would see an unread count of 1000 (email anyone?). --- changelog.d/8820.feature | 1 + docs/sample_config.yaml | 10 +++ synapse/config/push.py | 13 +++ synapse/push/httppusher.py | 13 ++- synapse/push/push_tools.py | 16 ++-- tests/push/test_http.py | 163 ++++++++++++++++++++++++++++++++++++- 6 files changed, 207 insertions(+), 9 deletions(-) create mode 100644 changelog.d/8820.feature diff --git a/changelog.d/8820.feature b/changelog.d/8820.feature new file mode 100644 index 0000000000..9e35861b11 --- /dev/null +++ b/changelog.d/8820.feature @@ -0,0 +1 @@ +Add a config option, `push.group_by_unread_count`, which controls whether unread message counts in push notifications are defined as "the number of rooms with unread messages" or "total unread messages". diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index df0f3e1d8e..394eb9a3ff 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2271,6 +2271,16 @@ push: # #include_content: false + # When a push notification is received, an unread count is also sent. + # This number can either be calculated as the number of unread messages + # for the user, or the number of *rooms* the user has unread messages in. + # + # The default value is "true", meaning push clients will see the number of + # rooms with unread messages in them. Uncomment to instead send the number + # of unread messages. + # + #group_unread_count_by_room: false + # Spam checkers are third-party modules that can block specific actions # of local users, such as creating rooms and registering undesirable diff --git a/synapse/config/push.py b/synapse/config/push.py index a71baac89c..3adbfb73e6 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -23,6 +23,9 @@ class PushConfig(Config): def read_config(self, config, **kwargs): push_config = config.get("push") or {} self.push_include_content = push_config.get("include_content", True) + self.push_group_unread_count_by_room = push_config.get( + "group_unread_count_by_room", True + ) pusher_instances = config.get("pusher_instances") or [] self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) @@ -68,4 +71,14 @@ class PushConfig(Config): # include the event ID and room ID in push notification payloads. # #include_content: false + + # When a push notification is received, an unread count is also sent. + # This number can either be calculated as the number of unread messages + # for the user, or the number of *rooms* the user has unread messages in. + # + # The default value is "true", meaning push clients will see the number of + # rooms with unread messages in them. Uncomment to instead send the number + # of unread messages. + # + #group_unread_count_by_room: false """ diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 793d0db2d9..eff0975b6a 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -75,6 +75,7 @@ class HttpPusher: self.failing_since = pusherdict["failing_since"] self.timed_call = None self._is_processing = False + self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room # This is the highest stream ordering we know it's safe to process. # When new events arrive, we'll be given a window of new events: we @@ -136,7 +137,11 @@ class HttpPusher: async def _update_badge(self): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. - badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count( + self.hs.get_datastore(), + self.user_id, + group_by_room=self._group_unread_count_by_room, + ) await self._send_badge(badge) def on_timer(self): @@ -283,7 +288,11 @@ class HttpPusher: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count( + self.hs.get_datastore(), + self.user_id, + group_by_room=self._group_unread_count_by_room, + ) event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d0145666bf..6e7c880dc0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage +from synapse.storage.databases.main import DataStore -async def get_badge_count(store, user_id): +async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) @@ -34,9 +34,15 @@ async def get_badge_count(store, user_id): room_id, user_id, last_unread_event_id ) ) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + if notifs["notify_count"] == 0: + continue + + if group_by_room: + # return one badge count per conversation + badge += 1 + else: + # increment the badge count by the number of unread messages in the room + badge += notifs["notify_count"] return badge diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 8571924b29..f118430309 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from mock import Mock from twisted.internet.defer import Deferred @@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import receipts -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class HTTPPusherTests(HomeserverTestCase): @@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, + receipts.register_servlets, ] user_id = True hijack_auth = False @@ -499,3 +500,161 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_push_unread_count_group_by_room(self): + """ + The HTTP pusher will group unread count by number of unread rooms. + """ + # Carry out common push count tests and setup + self._test_push_unread_count() + + # Carry out our option-value specific test + # + # This push should still only contain an unread count of 1 (for 1 unread room) + self.assertEqual( + self.push_attempts[5][2]["notification"]["counts"]["unread"], 1 + ) + + @override_config({"push": {"group_unread_count_by_room": False}}) + def test_push_unread_count_message_count(self): + """ + The HTTP pusher will send the total unread message count. + """ + # Carry out common push count tests and setup + self._test_push_unread_count() + + # Carry out our option-value specific test + # + # We're counting every unread message, so there should now be 4 since the + # last read receipt + self.assertEqual( + self.push_attempts[5][2]["notification"]["counts"]["unread"], 4 + ) + + def _test_push_unread_count(self): + """ + Tests that the correct unread count appears in sent push notifications + + Note that: + * Sending messages will cause push notifications to go out to relevant users + * Sending a read receipt will cause a "badge update" notification to go out to + the user that sent the receipt + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("other_user", "pass") + other_access_token = self.login("other_user", "pass") + + # Create a room (as other_user) + room_id = self.helper.create_room_as(other_user_id, tok=other_access_token) + + # The user to get notified joins + self.helper.join(room=room_id, user=user_id, tok=access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + response = self.helper.send( + room_id, body="Hello there!", tok=other_access_token + ) + # To get an unread count, the user who is getting notified has to have a read + # position in the room. We'll set the read position to this event in a moment + first_message_event_id = response["event_id"] + + # Advance time a bit (so the pusher will register something has happened) and + # make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + + # Check that the unread count for the room is 0 + # + # The unread count is zero as the user has no read receipt in the room yet + self.assertEqual( + self.push_attempts[0][2]["notification"]["counts"]["unread"], 0 + ) + + # Now set the user's read receipt position to the first event + # + # This will actually trigger a new notification to be sent out so that + # even if the user does not receive another message, their unread + # count goes down + request, channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), + {}, + access_token=access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Advance time and make the push succeed + self.push_attempts[1][0].callback({}) + self.pump() + + # Unread count is still zero as we've read the only message in the room + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual( + self.push_attempts[1][2]["notification"]["counts"]["unread"], 0 + ) + + # Send another message + self.helper.send( + room_id, body="How's the weather today?", tok=other_access_token + ) + + # Advance time and make the push succeed + self.push_attempts[2][0].callback({}) + self.pump() + + # This push should contain an unread count of 1 as there's now been one + # message since our last read receipt + self.assertEqual(len(self.push_attempts), 3) + self.assertEqual( + self.push_attempts[2][2]["notification"]["counts"]["unread"], 1 + ) + + # Since we're grouping by room, sending more messages shouldn't increase the + # unread count, as they're all being sent in the same room + self.helper.send(room_id, body="Hello?", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[3][0].callback({}) + + self.helper.send(room_id, body="Hello??", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[4][0].callback({}) + + self.helper.send(room_id, body="HELLO???", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[5][0].callback({}) + + self.assertEqual(len(self.push_attempts), 6) From f8d13ca13d9dd0c669a2a1b5bc390d7830c89239 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 30 Nov 2020 18:44:09 +0000 Subject: [PATCH 06/19] Drop (almost) unused index on event_json (#8845) --- changelog.d/8845.misc | 1 + .../storage/databases/main/purge_events.py | 2 +- .../delta/58/24drop_event_json_index.sql | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8845.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql diff --git a/changelog.d/8845.misc b/changelog.d/8845.misc new file mode 100644 index 0000000000..7db1c31520 --- /dev/null +++ b/changelog.d/8845.misc @@ -0,0 +1 @@ +Drop redundant database index on `event_json`. diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index ecfc6717b3..5d668aadb2 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): for table in ( "event_auth", "event_edges", + "event_json", "event_push_actions_staging", "event_reference_hashes", "event_relations", @@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "destination_rooms", "event_backward_extremities", "event_forward_extremities", - "event_json", "event_push_actions", "event_search", "events", diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql new file mode 100644 index 0000000000..8a39d54aed --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this index is essentially redundant. The only time it was ever used was when purging +-- rooms - and Synapse 1.24 will change that. + +DROP INDEX IF EXISTS event_json_room_id; From 9f0f274fe0bd9dd6ccd7e2b53e2536cb6d02ae9b Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Mon, 30 Nov 2020 19:59:29 +0100 Subject: [PATCH 07/19] Allow per-room profile to be used for server notice user (#8799) This applies even if the feature is disabled at the server level with `allow_per_room_profiles`. The server notice not being a real user it doesn't have an user profile. --- changelog.d/8799.bugfix | 1 + synapse/handlers/room_member.py | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8799.bugfix diff --git a/changelog.d/8799.bugfix b/changelog.d/8799.bugfix new file mode 100644 index 0000000000..a7e6b3556d --- /dev/null +++ b/changelog.d/8799.bugfix @@ -0,0 +1 @@ +Allow per-room profiles to be used for the server notice user. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 13a793c05a..c002886324 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -346,7 +346,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # later on. content = dict(content) - if not self.allow_per_room_profiles or requester.shadow_banned: + # allow the server notices mxid to set room-level profile + is_requester_server_notices_user = ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ) + + if ( + not self.allow_per_room_profiles and not is_requester_server_notices_user + ) or requester.shadow_banned: # Strip profile data, knowing that new profile data will be added to the # event's content in event_creation_handler.create_event() using the target's # global profile. From 59e18a1333526b922b318c4165ec09570e80bf5c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 30 Nov 2020 19:20:56 +0000 Subject: [PATCH 08/19] Simplify appservice login code (#8847) we don't need to support legacy login dictionaries here. --- changelog.d/8847.misc | 1 + synapse/rest/client/v1/login.py | 27 +++++++++++++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8847.misc diff --git a/changelog.d/8847.misc b/changelog.d/8847.misc new file mode 100644 index 0000000000..5028997b04 --- /dev/null +++ b/changelog.d/8847.misc @@ -0,0 +1 @@ +Simplify `uk.half-shot.msc2778.login.application_service` login handler. diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 94452fcbf5..074bdd66c9 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -154,13 +154,28 @@ class LoginRestServlet(RestServlet): async def _do_appservice_login( self, login_submission: JsonDict, appservice: ApplicationService ): - logger.info( - "Got appservice login request with identifier: %r", - login_submission.get("identifier"), - ) + identifier = login_submission.get("identifier") + logger.info("Got appservice login request with identifier: %r", identifier) - identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) - qualified_user_id = self._get_qualified_user_id(identifier) + if not isinstance(identifier, dict): + raise SynapseError( + 400, "Invalid identifier in login submission", Codes.INVALID_PARAM + ) + + # this login flow only supports identifiers of type "m.id.user". + if identifier.get("type") != "m.id.user": + raise SynapseError( + 400, "Unknown login identifier type", Codes.INVALID_PARAM + ) + + user = identifier.get("user") + if not isinstance(user, str): + raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM) + + if user.startswith("@"): + qualified_user_id = user + else: + qualified_user_id = UserID(user, self.hs.hostname).to_string() if not appservice.is_interested_in_user(qualified_user_id): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) From d1be293f0002555016a2edd630272dbe09355347 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 1 Dec 2020 10:34:52 +0000 Subject: [PATCH 09/19] Fix typo in password_auth_providers doc A word got removed accidentally in 83434df3812650f53c60e91fb23c2079db0fb5b8. --- docs/password_auth_providers.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index 7d98d9f255..d2cdb9b2f4 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -26,6 +26,7 @@ Password auth provider classes must provide the following methods: It should perform any appropriate sanity checks on the provided configuration, and return an object which is then passed into + `__init__`. This method should have the `@staticmethod` decoration. From 09ac0569fe910ce4599a07e385522380ce1af6e1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Dec 2020 11:04:57 +0000 Subject: [PATCH 10/19] Fix broken testcase (#8851) This test was broken by #8565. It doesn't need to set set `self.clock` here anyway - that is done by `setUp`. --- changelog.d/8851.misc | 1 + tests/rest/admin/test_media.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 changelog.d/8851.misc diff --git a/changelog.d/8851.misc b/changelog.d/8851.misc new file mode 100644 index 0000000000..7bef422618 --- /dev/null +++ b/changelog.d/8851.misc @@ -0,0 +1 @@ +Simplify the way the `HomeServer` object caches its internal attributes. diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 2a65ab33bd..dadf9db660 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -192,7 +192,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.handler = hs.get_device_handler() self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname - self.clock = hs.clock self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") From ddc43436838e19a7dd16860389bd76c74578dae7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Dec 2020 11:10:42 +0000 Subject: [PATCH 11/19] Add some tests for `password_auth_providers` (#8819) These things seemed to be completely untested, so I added a load of tests for them. --- changelog.d/8819.misc | 1 + mypy.ini | 1 + tests/handlers/test_password_providers.py | 486 ++++++++++++++++++++++ 3 files changed, 488 insertions(+) create mode 100644 changelog.d/8819.misc create mode 100644 tests/handlers/test_password_providers.py diff --git a/changelog.d/8819.misc b/changelog.d/8819.misc new file mode 100644 index 0000000000..a5793273a5 --- /dev/null +++ b/changelog.d/8819.misc @@ -0,0 +1 @@ +Add tests for `password_auth_provider`s. diff --git a/mypy.ini b/mypy.ini index a5503abe26..3c8d303064 100644 --- a/mypy.ini +++ b/mypy.ini @@ -80,6 +80,7 @@ files = synapse/util/metrics.py, tests/replication, tests/test_utils, + tests/handlers/test_password_providers.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_stream_change_cache.py diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py new file mode 100644 index 0000000000..edfab8a13a --- /dev/null +++ b/tests/handlers/test_password_providers.py @@ -0,0 +1,486 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the password_auth_provider interface""" + +from typing import Any, Type, Union + +from mock import Mock + +from twisted.internet import defer + +import synapse +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import devices +from synapse.types import JsonDict + +from tests import unittest +from tests.server import FakeChannel +from tests.unittest import override_config + +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + +# a mock instance which the dummy auth providers delegate to, so we can see what's going +# on +mock_password_provider = Mock() + + +class PasswordOnlyAuthProvider: + """A password_provider which only implements `check_password`.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def check_password(self, *args): + return mock_password_provider.check_password(*args) + + +class CustomAuthProvider: + """A password_provider which implements a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def get_supported_login_types(self): + return {"test.login_type": ["test_field"]} + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given password auth providers""" + return { + "password_providers": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + +class PasswordAuthProviderTests(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + devices.register_servlets, + ] + + def setUp(self): + # we use a global mock device, so make sure we are starting with a clean slate + mock_password_provider.reset_mock() + super().setUp() + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_password_only_auth_provider_login(self): + # login flows should only have m.login.password + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(True) + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@u:test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + mock_password_provider.reset_mock() + + # login with mxid should work too + channel = self._send_password_login("@u:bz", "p") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@u:bz", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") + mock_password_provider.reset_mock() + + # try a weird username / pass. Honestly it's unclear what we *expect* to happen + # in these cases, but at least we can guard against the API changing + # unexpectedly + channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with( + "@ USER🙂NAME :test", " pASS😢word " + ) + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth(self): + """UI Auth should delegate correctly to the password provider""" + + # create the user, otherwise access doesn't work + module_api = self.hs.get_module_api() + self.get_success(module_api.register_user("u")) + + # log in twice, to get two devices + mock_password_provider.check_password.return_value = defer.succeed(True) + tok1 = self.login("u", "p") + self.login("u", "p", device_id="dev2") + mock_password_provider.reset_mock() + + # have the auth provider deny the request to start with + mock_password_provider.check_password.return_value = defer.succeed(False) + + # make the initial request which returns a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # Make another request providing the UI auth flow. + channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + mock_password_provider.reset_mock() + + # Finally, check the request goes through when we allow it + mock_password_provider.check_password.return_value = defer.succeed(True) + channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") + self.assertEqual(channel.code, 200) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_local_user_fallback_login(self): + """rejected login should fall back to local db""" + self.register_user("localuser", "localpass") + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 403, channel.result) + + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@localuser:test", channel.json_body["user_id"]) + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth(self): + """rejected login should fall back to local db""" + self.register_user("localuser", "localpass") + + # have the auth provider deny the request + mock_password_provider.check_password.return_value = defer.succeed(False) + + # log in twice, to get two devices + tok1 = self.login("localuser", "localpass") + self.login("localuser", "localpass", device_id="dev2") + mock_password_provider.check_password.reset_mock() + + # first delete should give a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # Wrong password + channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx") + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "xxx" + ) + mock_password_provider.reset_mock() + + # Right password + channel = self._authed_delete_device( + tok1, "dev2", session, "localuser", "localpass" + ) + self.assertEqual(channel.code, 200) + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_no_local_user_fallback_login(self): + """localdb_enabled can block login with the local password + """ + self.register_user("localuser", "localpass") + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_no_local_user_fallback_ui_auth(self): + """localdb_enabled can block ui auth with the local password + """ + self.register_user("localuser", "localpass") + + # allow login via the auth provider + mock_password_provider.check_password.return_value = defer.succeed(True) + + # log in twice, to get two devices + tok1 = self.login("localuser", "p") + self.login("localuser", "p", device_id="dev2") + mock_password_provider.check_password.reset_mock() + + # first delete should give a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # now try deleting with the local password + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._authed_delete_device( + tok1, "dev2", session, "localuser", "localpass" + ) + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_auth_disabled(self): + """password auth doesn't work if it's disabled across the board""" + # login flows should be empty + flows = self._get_login_flows() + self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_password.assert_not_called() + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_login(self): + # login flows should have the custom flow and m.login.password, since we + # haven't disabled local password lookup. + # (password must come first, because reasons) + flows = self._get_login_flows() + self.assertEqual( + flows, + [{"type": "m.login.password"}, {"type": "test.login_type"}] + + ADDITIONAL_LOGIN_FLOWS, + ) + + # login with missing param should be rejected + channel = self._send_login("test.login_type", "u") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + channel = self._send_login("test.login_type", "u", test_field="y") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@user:bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + "u", "test.login_type", {"test_field": "y"} + ) + mock_password_provider.reset_mock() + + # try a weird username. Again, it's unclear what we *expect* to happen + # in these cases, but at least we can guard against the API changing + # unexpectedly + mock_password_provider.check_auth.return_value = defer.succeed( + "@ MALFORMED! :bz" + ) + channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + " USER🙂NAME ", "test.login_type", {"test_field": " abc "} + ) + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_ui_auth(self): + # register the user and log in twice, to get two devices + self.register_user("localuser", "localpass") + tok1 = self.login("localuser", "localpass") + self.login("localuser", "localpass", device_id="dev2") + + # make the initial request which returns a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"]) + session = channel.json_body["session"] + + # missing param + body = { + "auth": { + "type": "test.login_type", + "identifier": {"type": "m.id.user", "user": "localuser"}, + # FIXME "identifier" is ignored + # https://github.com/matrix-org/synapse/issues/5665 + "user": "localuser", + "session": session, + }, + } + + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 400) + # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should + # use it... + self.assertIn("Missing parameters", channel.json_body["error"]) + mock_password_provider.check_auth.assert_not_called() + mock_password_provider.reset_mock() + + # right params, but authing as the wrong user + mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + body["auth"]["test_field"] = "foo" + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "foo"} + ) + mock_password_provider.reset_mock() + + # and finally, succeed + mock_password_provider.check_auth.return_value = defer.succeed( + "@localuser:test" + ) + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 200) + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "foo"} + ) + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_callback(self): + callback = Mock(return_value=defer.succeed(None)) + + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", callback) + ) + channel = self._send_login("test.login_type", "u", test_field="y") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@user:bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + "u", "test.login_type", {"test_field": "y"} + ) + + # check the args to the callback + callback.assert_called_once() + call_args, call_kwargs = callback.call_args + # should be one positional arg + self.assertEqual(len(call_args), 1) + self.assertEqual(call_args[0]["user_id"], "@user:bz") + for p in ["user_id", "access_token", "device_id", "home_server"]: + self.assertIn(p, call_args[0]) + + @override_config( + {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} + ) + def test_custom_auth_password_disabled(self): + """Test login with a custom auth provider where password login is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(CustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback(self): + """Test login with a custom auth provider where the local db is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # password login shouldn't work and should be rejected with a 400 + # ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + + test_custom_auth_no_local_user_fallback.skip = "currently broken" + + def _get_login_flows(self) -> JsonDict: + _, channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["flows"] + + def _send_password_login(self, user: str, password: str) -> FakeChannel: + return self._send_login(type="m.login.password", user=user, password=password) + + def _send_login(self, type, user, **params) -> FakeChannel: + params.update({"user": user, "type": type}) + _, channel = self.make_request("POST", "/_matrix/client/r0/login", params) + return channel + + def _start_delete_device_session(self, access_token, device_id) -> str: + """Make an initial delete device request, and return the UI Auth session ID""" + channel = self._delete_device(access_token, device_id) + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + return channel.json_body["session"] + + def _authed_delete_device( + self, + access_token: str, + device_id: str, + session: str, + user_id: str, + password: str, + ) -> FakeChannel: + """Make a delete device request, authenticating with the given uid/password""" + return self._delete_device( + access_token, + device_id, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": user_id}, + # FIXME "identifier" is ignored + # https://github.com/matrix-org/synapse/issues/5665 + "user": user_id, + "password": password, + "session": session, + }, + }, + ) + + def _delete_device( + self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"", + ) -> FakeChannel: + """Delete an individual device.""" + _, channel = self.make_request( + "DELETE", "devices/" + device, body, access_token=access_token + ) + return channel From 89f79307306ed117d9dcfe46a31a3fe1a1a5ceae Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Dec 2020 13:04:03 +0000 Subject: [PATCH 12/19] Don't offer password login when it is disabled (#8835) Fix a minor bug where we would offer "m.login.password" login if a custom auth provider supported it, even if password login was disabled. --- changelog.d/8835.bugfix | 1 + synapse/handlers/auth.py | 10 +- tests/handlers/test_password_providers.py | 108 +++++++++++++++++++++- 3 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 changelog.d/8835.bugfix diff --git a/changelog.d/8835.bugfix b/changelog.d/8835.bugfix new file mode 100644 index 0000000000..446d04aa55 --- /dev/null +++ b/changelog.d/8835.bugfix @@ -0,0 +1 @@ +Fix minor long-standing bug in login, where we would offer the `password` login type if a custom auth provider supported it, even if password login was disabled. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 5163afd86c..588d3a60df 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -205,15 +205,23 @@ class AuthHandler(BaseHandler): # type in the list. (NB that the spec doesn't require us to do so and # clients which favour types that they don't understand over those that # they do are technically broken) + + # start out by assuming PASSWORD is enabled; we will remove it later if not. login_types = [] - if self._password_enabled: + if hs.config.password_localdb_enabled: login_types.append(LoginType.PASSWORD) + for provider in self.password_providers: if hasattr(provider, "get_supported_login_types"): for t in provider.get_supported_login_types().keys(): if t not in login_types: login_types.append(t) + + if not self._password_enabled: + login_types.remove(LoginType.PASSWORD) + self._supported_login_types = login_types + # Login types and UI Auth types have a heavy overlap, but are not # necessarily identical. Login types have SSO (and other login types) # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index edfab8a13a..dfbc4ee07e 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -70,6 +70,24 @@ class CustomAuthProvider: return mock_password_provider.check_auth(*args) +class PasswordCustomAuthProvider: + """A password_provider which implements password login via `check_auth`, as well + as a custom type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def get_supported_login_types(self): + return {"m.login.password": ["password"], "test.login_type": ["test_field"]} + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + def providers_config(*providers: Type[Any]) -> dict: """Returns a config dict that will enable the given password auth providers""" return { @@ -246,7 +264,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.reset_mock() # first delete should give a 401 - session = self._start_delete_device_session(tok1, "dev2") + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # there are no valid flows here! + self.assertEqual(channel.json_body["flows"], []) + session = channel.json_body["session"] mock_password_provider.check_password.assert_not_called() # now try deleting with the local password @@ -410,6 +432,88 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **providers_config(PasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login(self): + """log in with a custom auth provider which implements password, but password + login is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(PasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth(self): + """UI Auth with a custom auth provider which implements password, but password + login is disabled""" + # register the user and log in twice via the test login type to get two devices, + self.register_user("localuser", "localpass") + mock_password_provider.check_auth.return_value = defer.succeed( + "@localuser:test" + ) + channel = self._send_login("test.login_type", "localuser", test_field="") + self.assertEqual(channel.code, 200, channel.result) + tok1 = channel.json_body["access_token"] + + channel = self._send_login( + "test.login_type", "localuser", test_field="", device_id="dev2" + ) + self.assertEqual(channel.code, 200, channel.result) + + # make the initial request which returns a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. In particular, "password" should *not* + # be present. + self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"]) + session = channel.json_body["session"] + + mock_password_provider.reset_mock() + + # check that auth with password is rejected + body = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "localuser"}, + # FIXME "identifier" is ignored + # https://github.com/matrix-org/synapse/issues/5665 + "user": "localuser", + "password": "localpass", + "session": session, + }, + } + + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 400) + self.assertEqual( + "Password login has been disabled.", channel.json_body["error"] + ) + mock_password_provider.check_auth.assert_not_called() + mock_password_provider.reset_mock() + + # successful auth + body["auth"]["type"] = "test.login_type" + body["auth"]["test_field"] = "x" + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 200) + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "x"} + ) + @override_config( { **providers_config(CustomAuthProvider), @@ -428,8 +532,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) - test_custom_auth_no_local_user_fallback.skip = "currently broken" - def _get_login_flows(self) -> JsonDict: _, channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) From 3f0cba657ceb6e5286f52b06a74e969e4af9a146 Mon Sep 17 00:00:00 2001 From: Nicolas Chamo Date: Tue, 1 Dec 2020 10:24:56 -0300 Subject: [PATCH 13/19] Allow Date header through CORS (#8804) --- changelog.d/8804.feature | 1 + synapse/http/server.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8804.feature diff --git a/changelog.d/8804.feature b/changelog.d/8804.feature new file mode 100644 index 0000000000..a907c8106c --- /dev/null +++ b/changelog.d/8804.feature @@ -0,0 +1 @@ +Allow Date header through CORS. Contributed by Nicolas Chamo. diff --git a/synapse/http/server.py b/synapse/http/server.py index 46cdfad04a..6a4e429a6c 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -674,7 +674,7 @@ def set_cors_headers(request: Request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization", + b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date", ) From 9edff901d1eaca3c72ab4a0b31ff14d1472c6331 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Dec 2020 15:52:49 +0000 Subject: [PATCH 14/19] Add missing `ordering` to background updates (#8850) It's important that we make sure our background updates happen in a defined order, to avoid disasters like #6923. Add an ordering to all of the background updates that have landed since #7190. --- changelog.d/8850.misc | 1 + ...07add_method_to_thumbnail_constraint.sql.postgres | 12 ++++++------ .../databases/main/schema/delta/58/12room_stats.sql | 4 ++-- .../schema/delta/58/22users_have_local_media.sql | 4 ++-- .../schema/delta/58/23e2e_cross_signing_keys_idx.sql | 4 ++-- 5 files changed, 13 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8850.misc diff --git a/changelog.d/8850.misc b/changelog.d/8850.misc new file mode 100644 index 0000000000..4b54b8dd87 --- /dev/null +++ b/changelog.d/8850.misc @@ -0,0 +1 @@ +Add missing `ordering` to background database updates. diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres index b64926e9c9..3275ae2b20 100644 --- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres @@ -20,14 +20,14 @@ */ -- add new index that includes method to local media -INSERT INTO background_updates (update_name, progress_json) VALUES - ('local_media_repository_thumbnails_method_idx', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5807, 'local_media_repository_thumbnails_method_idx', '{}'); -- add new index that includes method to remote media -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); -- drop old index -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql index cade5dcca8..fd733adf13 100644 --- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql +++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql @@ -28,5 +28,5 @@ -- functionality as the old one. This effectively restarts the background job -- from the beginning, without running it twice in a row, supporting both -- upgrade usecases. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_stats_process_rooms_2', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5812, 'populate_stats_process_rooms_2', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql index a2842687f1..e1a35be831 100644 --- a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql +++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql @@ -1,2 +1,2 @@ -INSERT INTO background_updates (update_name, progress_json) VALUES - ('users_have_local_media', '{}'); \ No newline at end of file +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5822, 'users_have_local_media', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql index 61c558db77..75c3915a94 100644 --- a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql +++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql @@ -13,5 +13,5 @@ * limitations under the License. */ -INSERT INTO background_updates (update_name, progress_json) VALUES - ('e2e_cross_signing_keys_idx', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5823, 'e2e_cross_signing_keys_idx', '{}'); From 4d9496559d25ba36eaea45d73e67e79b9d936450 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Dec 2020 17:42:26 +0000 Subject: [PATCH 15/19] Support "identifier" dicts in UIA (#8848) The spec requires synapse to support `identifier` dicts for `m.login.password` user-interactive auth, which it did not (instead, it required an undocumented `user` parameter.) To fix this properly, we need to pull the code that interprets `identifier` into `AuthHandler.validate_login` so that it can be called from the UIA code. Fixes #5665. --- changelog.d/8848.bugfix | 1 + synapse/handlers/auth.py | 187 +++++++++++++++++++--- synapse/rest/client/v1/login.py | 107 +------------ tests/handlers/test_password_providers.py | 11 +- tests/rest/client/v2_alpha/test_auth.py | 33 +++- 5 files changed, 191 insertions(+), 148 deletions(-) create mode 100644 changelog.d/8848.bugfix diff --git a/changelog.d/8848.bugfix b/changelog.d/8848.bugfix new file mode 100644 index 0000000000..499e66f05b --- /dev/null +++ b/changelog.d/8848.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which caused Synapse to require unspecified parameters during user-interactive authentication. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 588d3a60df..8815f685b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -238,6 +238,13 @@ class AuthHandler(BaseHandler): burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) + # Ratelimitier for failed /login attempts + self._failed_login_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) + self._clock = self.hs.get_clock() # Expire old UI auth sessions after a period of time. @@ -650,14 +657,8 @@ class AuthHandler(BaseHandler): res = await checker.check_auth(authdict, clientip=clientip) return res - # build a v1-login-style dict out of the authdict and fall back to the - # v1 code - user_id = authdict.get("user") - - if user_id is None: - raise SynapseError(400, "", Codes.MISSING_PARAM) - - (canonical_id, callback) = await self.validate_login(user_id, authdict) + # fall back to the v1 login flow + canonical_id, _ = await self.validate_login(authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -832,15 +833,155 @@ class AuthHandler(BaseHandler): return self._supported_login_types async def validate_login( - self, username: str, login_submission: Dict[str, Any] + self, login_submission: Dict[str, Any], ratelimit: bool = False, ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API - Also used by the user-interactive auth flow to validate - m.login.password auth types. + Also used by the user-interactive auth flow to validate auth types which don't + have an explicit UIA handler, including m.password.auth. Args: - username: username supplied by the user + login_submission: the whole of the login submission + (including 'type' and other relevant fields) + ratelimit: whether to apply the failed_login_attempt ratelimiter + Returns: + A tuple of the canonical user id, and optional callback + to be called once the access token and device id are issued + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. + """ + login_type = login_submission.get("type") + + # ideally, we wouldn't be checking the identifier unless we know we have a login + # method which uses it (https://github.com/matrix-org/synapse/issues/8836) + # + # But the auth providers' check_auth interface requires a username, so in + # practice we can only support login methods which we can map to a username + # anyway. + + # special case to check for "password" for the check_password interface + # for the auth providers + password = login_submission.get("password") + if login_type == LoginType.PASSWORD: + if not self._password_enabled: + raise SynapseError(400, "Password login has been disabled.") + if not isinstance(password, str): + raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM) + + # map old-school login fields into new-school "identifier" fields. + identifier_dict = convert_client_dict_legacy_fields_to_identifier( + login_submission + ) + + # convert phone type identifiers to generic threepids + if identifier_dict["type"] == "m.id.phone": + identifier_dict = login_id_phone_to_thirdparty(identifier_dict) + + # convert threepid identifiers to user IDs + if identifier_dict["type"] == "m.id.thirdparty": + address = identifier_dict.get("address") + medium = identifier_dict.get("medium") + + if medium is None or address is None: + raise SynapseError(400, "Invalid thirdparty identifier") + + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) + + # We also apply account rate limiting using the 3PID as a key, as + # otherwise using 3PID bypasses the ratelimiting based on user ID. + if ratelimit: + self._failed_login_attempts_ratelimiter.ratelimit( + (medium, address), update=False + ) + + # Check for login providers that support 3pid login types + if login_type == LoginType.PASSWORD: + # we've already checked that there is a (valid) password field + assert isinstance(password, str) + ( + canonical_user_id, + callback_3pid, + ) = await self.check_password_provider_3pid(medium, address, password) + if canonical_user_id: + # Authentication through password provider and 3pid succeeded + return canonical_user_id, callback_3pid + + # No password providers were able to handle this 3pid + # Check local store + user_id = await self.hs.get_datastore().get_user_id_by_threepid( + medium, address + ) + if not user_id: + logger.warning( + "unknown 3pid identifier medium %s, address %r", medium, address + ) + # We mark that we've failed to log in here, as + # `check_password_provider_3pid` might have returned `None` due + # to an incorrect password, rather than the account not + # existing. + # + # If it returned None but the 3PID was bound then we won't hit + # this code path, which is fine as then the per-user ratelimit + # will kick in below. + if ratelimit: + self._failed_login_attempts_ratelimiter.can_do_action( + (medium, address) + ) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + identifier_dict = {"type": "m.id.user", "user": user_id} + + # by this point, the identifier should be an m.id.user: if it's anything + # else, we haven't understood it. + if identifier_dict["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + + username = identifier_dict.get("user") + if not username: + raise SynapseError(400, "User identifier is missing 'user' key") + + if username.startswith("@"): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + # Check if we've hit the failed ratelimit (but don't update it) + if ratelimit: + self._failed_login_attempts_ratelimiter.ratelimit( + qualified_user_id.lower(), update=False + ) + + try: + return await self._validate_userid_login(username, login_submission) + except LoginError: + # The user has failed to log in, so we need to update the rate + # limiter. Using `can_do_action` avoids us raising a ratelimit + # exception and masking the LoginError. The actual ratelimiting + # should have happened above. + if ratelimit: + self._failed_login_attempts_ratelimiter.can_do_action( + qualified_user_id.lower() + ) + raise + + async def _validate_userid_login( + self, username: str, login_submission: Dict[str, Any], + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: + """Helper for validate_login + + Handles login, once we've mapped 3pids onto userids + + Args: + username: the username, from the identifier dict login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: @@ -851,7 +992,6 @@ class AuthHandler(BaseHandler): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - if username.startswith("@"): qualified_user_id = username else: @@ -860,20 +1000,13 @@ class AuthHandler(BaseHandler): login_type = login_submission.get("type") known_login_type = False - # special case to check for "password" for the check_password interface - # for the auth providers - password = login_submission.get("password") - - if login_type == LoginType.PASSWORD: - if not self._password_enabled: - raise SynapseError(400, "Password login has been disabled.") - if not password: - raise SynapseError(400, "Missing parameter: password") - for provider in self.password_providers: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = await provider.check_password(qualified_user_id, password) + # we've already checked that there is a (valid) password field + is_valid = await provider.check_password( + qualified_user_id, login_submission["password"] + ) if is_valid: return qualified_user_id, None @@ -914,8 +1047,12 @@ class AuthHandler(BaseHandler): if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True + # we've already checked that there is a (valid) password field + password = login_submission["password"] + assert isinstance(password, str) + canonical_user_id = await self._check_local_password( - qualified_user_id, password # type: ignore + qualified_user_id, password ) if canonical_user_id: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 074bdd66c9..d7ae148214 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.handlers.auth import ( - convert_client_dict_legacy_fields_to_identifier, - login_id_phone_to_thirdparty, -) from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID -from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) @@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet): rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, ) - self._failed_attempts_ratelimiter = Ratelimiter( - clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - ) def on_GET(self, request: SynapseRequest): flows = [] @@ -140,17 +130,6 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - def _get_qualified_user_id(self, identifier): - if identifier["type"] != "m.id.user": - raise SynapseError(400, "Unknown login identifier type") - if "user" not in identifier: - raise SynapseError(400, "User identifier is missing 'user' key") - - if identifier["user"].startswith("@"): - return identifier["user"] - else: - return UserID(identifier["user"], self.hs.hostname).to_string() - async def _do_appservice_login( self, login_submission: JsonDict, appservice: ApplicationService ): @@ -201,91 +180,9 @@ class LoginRestServlet(RestServlet): login_submission.get("address"), login_submission.get("user"), ) - identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) - - # convert phone type identifiers to generic threepids - if identifier["type"] == "m.id.phone": - identifier = login_id_phone_to_thirdparty(identifier) - - # convert threepid identifiers to user IDs - if identifier["type"] == "m.id.thirdparty": - address = identifier.get("address") - medium = identifier.get("medium") - - if medium is None or address is None: - raise SynapseError(400, "Invalid thirdparty identifier") - - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - if medium == "email": - try: - address = canonicalise_email(address) - except ValueError as e: - raise SynapseError(400, str(e)) - - # We also apply account rate limiting using the 3PID as a key, as - # otherwise using 3PID bypasses the ratelimiting based on user ID. - self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) - - # Check for login providers that support 3pid login types - ( - canonical_user_id, - callback_3pid, - ) = await self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) - if canonical_user_id: - # Authentication through password provider and 3pid succeeded - - result = await self._complete_login( - canonical_user_id, login_submission, callback_3pid - ) - return result - - # No password providers were able to handle this 3pid - # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( - medium, address - ) - if not user_id: - logger.warning( - "unknown 3pid identifier medium %s, address %r", medium, address - ) - # We mark that we've failed to log in here, as - # `check_password_provider_3pid` might have returned `None` due - # to an incorrect password, rather than the account not - # existing. - # - # If it returned None but the 3PID was bound then we won't hit - # this code path, which is fine as then the per-user ratelimit - # will kick in below. - self._failed_attempts_ratelimiter.can_do_action((medium, address)) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - identifier = {"type": "m.id.user", "user": user_id} - - # by this point, the identifier should be an m.id.user: if it's anything - # else, we haven't understood it. - qualified_user_id = self._get_qualified_user_id(identifier) - - # Check if we've hit the failed ratelimit (but don't update it) - self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + canonical_user_id, callback = await self.auth_handler.validate_login( + login_submission, ratelimit=True ) - - try: - canonical_user_id, callback = await self.auth_handler.validate_login( - identifier["user"], login_submission - ) - except LoginError: - # The user has failed to log in, so we need to update the rate - # limiter. Using `can_do_action` avoids us raising a ratelimit - # exception and masking the LoginError. The actual ratelimiting - # should have happened above. - self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) - raise - result = await self._complete_login( canonical_user_id, login_submission, callback ) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index dfbc4ee07e..22b9a11dc0 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -358,9 +358,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "auth": { "type": "test.login_type", "identifier": {"type": "m.id.user", "user": "localuser"}, - # FIXME "identifier" is ignored - # https://github.com/matrix-org/synapse/issues/5665 - "user": "localuser", "session": session, }, } @@ -489,9 +486,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "localuser"}, - # FIXME "identifier" is ignored - # https://github.com/matrix-org/synapse/issues/5665 - "user": "localuser", "password": "localpass", "session": session, }, @@ -541,7 +535,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): return self._send_login(type="m.login.password", user=user, password=password) def _send_login(self, type, user, **params) -> FakeChannel: - params.update({"user": user, "type": type}) + params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) _, channel = self.make_request("POST", "/_matrix/client/r0/login", params) return channel @@ -569,9 +563,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": user_id}, - # FIXME "identifier" is ignored - # https://github.com/matrix-org/synapse/issues/5665 - "user": user_id, "password": password, "session": session, }, diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index f684c37db5..77246e478f 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): return succeed(True) -class DummyPasswordChecker(UserInteractiveAuthChecker): - def check_auth(self, authdict, clientip): - return succeed(authdict["identifier"]["user"]) - - class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ @@ -162,9 +157,6 @@ class UIAuthTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) - self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.user_tok = self.login("test", self.user_pass) @@ -234,6 +226,31 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) + def test_grandfathered_identifier(self): + """Check behaviour without "identifier" dict + + Synapse used to require clients to submit a "user" field for m.login.password + UIA - check that still works. + """ + + device_id = self.get_device_ids()[0] + channel = self.delete_device(device_id, 401) + session = channel.json_body["session"] + + # Make another request providing the UI auth flow. + self.delete_device( + device_id, + 200, + { + "auth": { + "type": "m.login.password", + "user": self.user, + "password": self.user_pass, + "session": session, + }, + }, + ) + def test_can_change_body(self): """ The client dict can be modified during the user interactive authentication session. From edb3d3f82716c2b5c903ddb4d0df155e06c5c9e9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 2 Dec 2020 10:38:18 +0000 Subject: [PATCH 16/19] Allow specifying room version in 'RestHelper.create_room_as' and add typing (#8854) This PR adds a `room_version` argument to the `RestHelper`'s `create_room_as` function for tests. I plan to use this for testing knocking, which currently uses an unstable room version. --- changelog.d/8854.misc | 1 + tests/rest/client/v1/utils.py | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8854.misc diff --git a/changelog.d/8854.misc b/changelog.d/8854.misc new file mode 100644 index 0000000000..5895df2d5c --- /dev/null +++ b/changelog.d/8854.misc @@ -0,0 +1 @@ +Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`. \ No newline at end of file diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index b58768675b..737c38c396 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -41,14 +41,37 @@ class RestHelper: auth_user_id = attr.ib() def create_room_as( - self, room_creator=None, is_public=True, tok=None, expect_code=200, - ): + self, + room_creator: str = None, + is_public: bool = True, + room_version: str = None, + tok: str = None, + expect_code: int = 200, + ) -> str: + """ + Create a room. + + Args: + room_creator: The user ID to create the room with. + is_public: If True, the `visibility` parameter will be set to the + default (public). Otherwise, the `visibility` parameter will be set + to "private". + room_version: The room version to create the room as. Defaults to Synapse's + default room version. + tok: The access token to use in the request. + expect_code: The expected HTTP response code. + + Returns: + The ID of the newly created room. + """ temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" content = {} if not is_public: content["visibility"] = "private" + if room_version: + content["room_version"] = room_version if tok: path = path + "?access_token=%s" % tok From d3ed93504bb6bb8ad138e356e3c74b6a7286299b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 2 Dec 2020 10:38:50 +0000 Subject: [PATCH 17/19] Create a `PasswordProvider` wrapper object (#8849) The idea here is to abstract out all the conditional code which tests which methods a given password provider has, to provide a consistent interface. --- changelog.d/8849.misc | 1 + synapse/handlers/auth.py | 203 ++++++++++++++++------ tests/handlers/test_password_providers.py | 5 +- 3 files changed, 152 insertions(+), 57 deletions(-) create mode 100644 changelog.d/8849.misc diff --git a/changelog.d/8849.misc b/changelog.d/8849.misc new file mode 100644 index 0000000000..3dd496ce61 --- /dev/null +++ b/changelog.d/8849.misc @@ -0,0 +1 @@ +Refactor `password_auth_provider` support code. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8815f685b9..c7dc07008a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +26,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Tuple, Union, @@ -181,17 +183,12 @@ class AuthHandler(BaseHandler): # better way to break the loop account_handler = ModuleApi(hs, self) - self.password_providers = [] - for module, config in hs.config.password_providers: - try: - self.password_providers.append( - module(config=config, account_handler=account_handler) - ) - except Exception as e: - logger.error("Error while initializing %r: %s", module, e) - raise + self.password_providers = [ + PasswordProvider.load(module, config, account_handler) + for module, config in hs.config.password_providers + ] - logger.info("Extra password_providers: %r", self.password_providers) + logger.info("Extra password_providers: %s", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() @@ -853,6 +850,8 @@ class AuthHandler(BaseHandler): LoginError if there was an authentication problem. """ login_type = login_submission.get("type") + if not isinstance(login_type, str): + raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM) # ideally, we wouldn't be checking the identifier unless we know we have a login # method which uses it (https://github.com/matrix-org/synapse/issues/8836) @@ -998,24 +997,12 @@ class AuthHandler(BaseHandler): qualified_user_id = UserID(username, self.hs.hostname).to_string() login_type = login_submission.get("type") + # we already checked that we have a valid login type + assert isinstance(login_type, str) + known_login_type = False for provider in self.password_providers: - if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: - known_login_type = True - # we've already checked that there is a (valid) password field - is_valid = await provider.check_password( - qualified_user_id, login_submission["password"] - ) - if is_valid: - return qualified_user_id, None - - if not hasattr(provider, "get_supported_login_types") or not hasattr( - provider, "check_auth" - ): - # this password provider doesn't understand custom login types - continue - supported_login_types = provider.get_supported_login_types() if login_type not in supported_login_types: # this password provider doesn't understand this login type @@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler): result = await provider.check_auth(username, login_type, login_dict) if result: - if isinstance(result, str): - result = (result, None) return result if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: @@ -1083,19 +1068,9 @@ class AuthHandler(BaseHandler): unsuccessful, `user_id` and `callback` are both `None`. """ for provider in self.password_providers: - if hasattr(provider, "check_3pid_auth"): - # This function is able to return a deferred that either - # resolves None, meaning authentication failure, or upon - # success, to a str (which is the user_id) or a tuple of - # (user_id, callback_func), where callback_func should be run - # after we've finished everything else - result = await provider.check_3pid_auth(medium, address, password) - if result: - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - result = (result, None) - return result + result = await provider.check_3pid_auth(medium, address, password) + if result: + return result return None, None @@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - if hasattr(provider, "on_logged_out"): - # This might return an awaitable, if it does block the log out - # until it completes. - result = provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, - access_token=access_token, - ) - if inspect.isawaitable(result): - await result + await provider.on_logged_out( + user_id=user_info.user_id, + device_id=user_info.device_id, + access_token=access_token, + ) # delete pushers associated with this access token if user_info.token_id is not None: @@ -1191,11 +1161,10 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - if hasattr(provider, "on_logged_out"): - for token, token_id, device_id in tokens_and_devices: - await provider.on_logged_out( - user_id=user_id, device_id=device_id, access_token=token - ) + for token, token_id, device_id in tokens_and_devices: + await provider.on_logged_out( + user_id=user_id, device_id=device_id, access_token=token + ) # delete pushers associated with the access tokens await self.hs.get_pusherpool().remove_pushers_by_access_token( @@ -1519,3 +1488,127 @@ class MacaroonGenerator: macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon + + +class PasswordProvider: + """Wrapper for a password auth provider module + + This class abstracts out all of the backwards-compatibility hacks for + password providers, to provide a consistent interface. + """ + + @classmethod + def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider": + try: + pp = module(config=config, account_handler=module_api) + except Exception as e: + logger.error("Error while initializing %r: %s", module, e) + raise + return cls(pp, module_api) + + def __init__(self, pp, module_api: ModuleApi): + self._pp = pp + self._module_api = module_api + + self._supported_login_types = {} + + # grandfather in check_password support + if hasattr(self._pp, "check_password"): + self._supported_login_types[LoginType.PASSWORD] = ("password",) + + g = getattr(self._pp, "get_supported_login_types", None) + if g: + self._supported_login_types.update(g()) + + def __str__(self): + return str(self._pp) + + def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: + """Get the login types supported by this password provider + + Returns a map from a login type identifier (such as m.login.password) to an + iterable giving the fields which must be provided by the user in the submission + to the /login API. + + This wrapper adds m.login.password to the list if the underlying password + provider supports the check_password() api. + """ + return self._supported_login_types + + async def check_auth( + self, username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + """Check if the user has presented valid login credentials + + This wrapper also calls check_password() if the underlying password provider + supports the check_password() api and the login type is m.login.password. + + Args: + username: user id presented by the client. Either an MXID or an unqualified + username. + + login_type: the login type being attempted - one of the types returned by + get_supported_login_types() + + login_dict: the dictionary of login secrets passed by the client. + + Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the + user, and `callback` is an optional callback which will be called with the + result from the /login call (including access_token, device_id, etc.) + """ + # first grandfather in a call to check_password + if login_type == LoginType.PASSWORD: + g = getattr(self._pp, "check_password", None) + if g: + qualified_user_id = self._module_api.get_qualified_user_id(username) + is_valid = await self._pp.check_password( + qualified_user_id, login_dict["password"] + ) + if is_valid: + return qualified_user_id, None + + g = getattr(self._pp, "check_auth", None) + if not g: + return None + result = await g(username, login_type, login_dict) + + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + return result, None + + return result + + async def check_3pid_auth( + self, medium: str, address: str, password: str + ) -> Optional[Tuple[str, Optional[Callable]]]: + g = getattr(self._pp, "check_3pid_auth", None) + if not g: + return None + + # This function is able to return a deferred that either + # resolves None, meaning authentication failure, or upon + # success, to a str (which is the user_id) or a tuple of + # (user_id, callback_func), where callback_func should be run + # after we've finished everything else + result = await g(medium, address, password) + + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + return result, None + + return result + + async def on_logged_out( + self, user_id: str, device_id: Optional[str], access_token: str + ) -> None: + g = getattr(self._pp, "on_logged_out", None) + if not g: + return + + # This might return an awaitable, if it does block the log out + # until it completes. + result = g(user_id=user_id, device_id=device_id, access_token=access_token,) + if inspect.isawaitable(result): + await result diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 22b9a11dc0..ceaf0902d2 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -266,8 +266,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # first delete should give a 401 channel = self._delete_device(tok1, "dev2") self.assertEqual(channel.code, 401) - # there are no valid flows here! - self.assertEqual(channel.json_body["flows"], []) + # m.login.password UIA is permitted because the auth provider allows it, + # even though the localdb does not. + self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}]) session = channel.json_body["session"] mock_password_provider.check_password.assert_not_called() From c21bdc813f5c21153cded05bcd0a57b5836f09fe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Dec 2020 07:09:21 -0500 Subject: [PATCH 18/19] Add basic SAML tests for mapping users. (#8800) --- .buildkite/scripts/test_old_deps.sh | 2 +- changelog.d/8800.misc | 1 + synapse/handlers/saml_handler.py | 2 +- tests/handlers/test_oidc.py | 34 +++---- tests/handlers/test_saml.py | 136 ++++++++++++++++++++++++++++ 5 files changed, 156 insertions(+), 19 deletions(-) create mode 100644 changelog.d/8800.misc create mode 100644 tests/handlers/test_saml.py diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh index cdb77b556c..9905c4bc4f 100755 --- a/.buildkite/scripts/test_old_deps.sh +++ b/.buildkite/scripts/test_old_deps.sh @@ -6,7 +6,7 @@ set -ex apt-get update -apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox +apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox export LANG="C.UTF-8" diff --git a/changelog.d/8800.misc b/changelog.d/8800.misc new file mode 100644 index 0000000000..57cca8fee5 --- /dev/null +++ b/changelog.d/8800.misc @@ -0,0 +1 @@ +Add additional error checking for OpenID Connect and SAML mapping providers. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 34db10ffe4..7ffad7d8af 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -265,7 +265,7 @@ class SamlHandler(BaseHandler): return UserAttributes( localpart=result.get("mxid_localpart"), display_name=result.get("displayname"), - emails=result.get("emails"), + emails=result.get("emails", []), ) with (await self._mapping_lock.queue(self._auth_provider_id)): diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e880d32be6..c9807a7b73 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -23,7 +23,7 @@ import pymacaroons from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone -from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider +from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider from synapse.handlers.sso import MappingException from synapse.types import UserID @@ -127,13 +127,8 @@ async def get_json(url): class OidcHandlerTestCase(HomeserverTestCase): - def make_homeserver(self, reactor, clock): - - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = "Synapse Test" - - config = self.default_config() + def default_config(self): + config = super().default_config() config["public_baseurl"] = BASE_URL oidc_config = { "enabled": True, @@ -149,19 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase): oidc_config.update(config.get("oidc_config", {})) config["oidc_config"] = oidc_config - hs = self.setup_test_homeserver( - http_client=self.http_client, - proxied_http_client=self.http_client, - config=config, - ) + return config - self.handler = OidcHandler(hs) + def make_homeserver(self, reactor, clock): + + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = "Synapse Test" + + hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + + self.handler = hs.get_oidc_handler() + sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) - self.handler._sso_handler.render_error = self.render_error + sso_handler.render_error = self.render_error # Reduce the number of attempts when generating MXIDs. - self.handler._sso_handler._MAP_USERNAME_RETRIES = 3 + sso_handler._MAP_USERNAME_RETRIES = 3 return hs @@ -832,7 +832,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. self.assertEqual(mxid, "@test_user1:test") - # Register all of the potential users for a particular username. + # Register all of the potential mxids for a particular OIDC username. self.get_success( store.register_user(user_id="@tester:test", password_hash=None) ) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py new file mode 100644 index 0000000000..79fd47036f --- /dev/null +++ b/tests/handlers/test_saml.py @@ -0,0 +1,136 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import attr + +from synapse.handlers.sso import MappingException + +from tests.unittest import HomeserverTestCase + +# These are a few constants that are used as config parameters in the tests. +BASE_URL = "https://synapse/" + + +@attr.s +class FakeAuthnResponse: + ava = attr.ib(type=dict) + + +class TestMappingProvider: + def __init__(self, config, module): + pass + + @staticmethod + def parse_config(config): + return + + @staticmethod + def get_saml_attributes(config): + return {"uid"}, {"displayName"} + + def get_remote_user_id(self, saml_response, client_redirect_url): + return saml_response.ava["uid"] + + def saml_response_to_user_attributes( + self, saml_response, failures, client_redirect_url + ): + localpart = saml_response.ava["username"] + (str(failures) if failures else "") + return {"mxid_localpart": localpart, "displayname": None} + + +class SamlHandlerTestCase(HomeserverTestCase): + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + saml_config = { + "sp_config": {"metadata": {}}, + # Disable grandfathering. + "grandfathered_mxid_source_attribute": None, + "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, + } + config["saml2_config"] = saml_config + + return config + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + self.handler = hs.get_saml_handler() + + # Reduce the number of attempts when generating MXIDs. + sso_handler = hs.get_sso_handler() + sso_handler._MAP_USERNAME_RETRIES = 3 + + return hs + + def test_map_saml_response_to_user(self): + """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) + # The redirect_url doesn't matter with the default user mapping provider. + redirect_url = "" + mxid = self.get_success( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + def test_map_saml_response_to_invalid_localpart(self): + """If the mapping provider generates an invalid localpart it should be rejected.""" + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) + redirect_url = "" + e = self.get_failure( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "localpart is invalid: föö") + + def test_map_saml_response_to_user_retries(self): + """The mapping provider can retry generating an MXID if the MXID is already in use.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) + redirect_url = "" + mxid = self.get_success( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ) + ) + # test_user is already taken, so test_user1 gets registered instead. + self.assertEqual(mxid, "@test_user1:test") + + # Register all of the potential mxids for a particular SAML username. + self.get_success( + store.register_user(user_id="@tester:test", password_hash=None) + ) + for i in range(1, 3): + self.get_success( + store.register_user(user_id="@tester%d:test" % i, password_hash=None) + ) + + # Now attempt to map to a username, this will fail since all potential usernames are taken. + saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) + e = self.get_failure( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual( + str(e.value), "Unable to generate a Matrix ID from the SSO response" + ) From 8388384a640d3381b5858d3fb1d2ea0a8c9c059c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Dec 2020 07:45:42 -0500 Subject: [PATCH 19/19] Fix a regression when grandfathering SAML users. (#8855) This was broken in #8801 when abstracting code shared with OIDC. After this change both SAML and OIDC have a concept of grandfathering users, but with different implementations. --- changelog.d/8855.feature | 1 + synapse/handlers/oidc_handler.py | 30 ++++++++++++++-- synapse/handlers/saml_handler.py | 9 ++--- synapse/handlers/sso.py | 60 ++++++++++---------------------- tests/handlers/test_oidc.py | 8 +++++ tests/handlers/test_saml.py | 34 +++++++++++++++++- 6 files changed, 94 insertions(+), 48 deletions(-) create mode 100644 changelog.d/8855.feature diff --git a/changelog.d/8855.feature b/changelog.d/8855.feature new file mode 100644 index 0000000000..77f7fe4e5d --- /dev/null +++ b/changelog.d/8855.feature @@ -0,0 +1 @@ +Add support for re-trying generation of a localpart for OpenID Connect mapping providers. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 78c4e94a9d..55c4377890 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.types import JsonDict, map_username_to_mxid_localpart +from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder if TYPE_CHECKING: @@ -898,13 +898,39 @@ class OidcHandler(BaseHandler): return UserAttributes(**attributes) + async def grandfather_existing_users() -> Optional[str]: + if self._allow_existing_users: + # If allowing existing users we want to generate a single localpart + # and attempt to match it. + attributes = await oidc_response_to_user_attributes(failures=0) + + user_id = UserID(attributes.localpart, self.server_name).to_string() + users = await self.store.get_users_by_id_case_insensitive(user_id) + if users: + # If an existing matrix ID is returned, then use it. + if len(users) == 1: + previously_registered_user_id = next(iter(users)) + elif user_id in users: + previously_registered_user_id = user_id + else: + # Do not attempt to continue generating Matrix IDs. + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, users + ) + ) + + return previously_registered_user_id + + return None + return await self._sso_handler.get_mxid_from_sso( self._auth_provider_id, remote_user_id, user_agent, ip_address, oidc_response_to_user_attributes, - self._allow_existing_users, + grandfather_existing_users, ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 7ffad7d8af..76d4169fe2 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -268,7 +268,7 @@ class SamlHandler(BaseHandler): emails=result.get("emails", []), ) - with (await self._mapping_lock.queue(self._auth_provider_id)): + async def grandfather_existing_users() -> Optional[str]: # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid if ( @@ -290,17 +290,18 @@ class SamlHandler(BaseHandler): if users: registered_user_id = list(users.keys())[0] logger.info("Grandfathering mapping to %s", registered_user_id) - await self.store.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id - ) return registered_user_id + return None + + with (await self._mapping_lock.queue(self._auth_provider_id)): return await self._sso_handler.get_mxid_from_sso( self._auth_provider_id, remote_user_id, user_agent, ip_address, saml_response_to_remapped_user_attributes, + grandfather_existing_users, ) def expire_sessions(self): diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d963082210..f42b90e1bc 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -116,7 +116,7 @@ class SsoHandler(BaseHandler): user_agent: str, ip_address: str, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], - allow_existing_users: bool = False, + grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]], ) -> str: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -125,6 +125,10 @@ class SsoHandler(BaseHandler): if it has that matrix ID is returned regardless of the current mapping logic. + If a callable is provided for grandfathering users, it is called and can + potentially return a matrix ID to use. If it does, the SSO ID is linked to + this matrix ID for subsequent calls. + The mapping function is called (potentially multiple times) to generate a localpart for the user. @@ -132,17 +136,6 @@ class SsoHandler(BaseHandler): given user-agent and IP address and the SSO ID is linked to this matrix ID for subsequent calls. - If allow_existing_users is true the mapping function is only called once - and results in: - - 1. The use of a previously registered matrix ID. In this case, the - SSO ID is linked to the matrix ID. (Note it is possible that - other SSO IDs are linked to the same matrix ID.) - 2. An unused localpart, in which case the user is registered (as - discussed above). - 3. An error if the generated localpart matches multiple pre-existing - matrix IDs. Generally this should not happen. - Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -152,8 +145,9 @@ class SsoHandler(BaseHandler): sso_to_matrix_id_mapper: A callable to generate the user attributes. The only parameter is an integer which represents the amount of times the returned mxid localpart mapping has failed. - allow_existing_users: True if the localpart returned from the - mapping provider can be linked to an existing matrix ID. + grandfather_existing_users: A callable which can return an previously + existing matrix ID. The SSO ID is then linked to the returned + matrix ID. Returns: The user ID associated with the SSO response. @@ -171,6 +165,16 @@ class SsoHandler(BaseHandler): if previously_registered_user_id: return previously_registered_user_id + # Check for grandfathering of users. + if grandfather_existing_users: + previously_registered_user_id = await grandfather_existing_users() + if previously_registered_user_id: + # Future logins should also match this user ID. + await self.store.record_user_external_id( + auth_provider_id, remote_user_id, previously_registered_user_id + ) + return previously_registered_user_id + # Otherwise, generate a new user. for i in range(self._MAP_USERNAME_RETRIES): try: @@ -194,33 +198,7 @@ class SsoHandler(BaseHandler): # Check if this mxid already exists user_id = UserID(attributes.localpart, self.server_name).to_string() - users = await self.store.get_users_by_id_case_insensitive(user_id) - # Note, if allow_existing_users is true then the loop is guaranteed - # to end on the first iteration: either by matching an existing user, - # raising an error, or registering a new user. See the docstring for - # more in-depth an explanation. - if users and allow_existing_users: - # If an existing matrix ID is returned, then use it. - if len(users) == 1: - previously_registered_user_id = next(iter(users)) - elif user_id in users: - previously_registered_user_id = user_id - else: - # Do not attempt to continue generating Matrix IDs. - raise MappingException( - "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( - user_id, users - ) - ) - - # Future logins should also match this user ID. - await self.store.record_user_external_id( - auth_provider_id, remote_user_id, previously_registered_user_id - ) - - return previously_registered_user_id - - elif not users: + if not await self.store.get_users_by_id_case_insensitive(user_id): # This mxid is free break else: diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index c9807a7b73..d485af52fd 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -731,6 +731,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.assertEqual(mxid, "@test_user:test") + # Subsequent calls should map to the same mxid. + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, # in this case we'll just use the same username. A more realistic example diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 79fd47036f..e1e13a5faf 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -16,7 +16,7 @@ import attr from synapse.handlers.sso import MappingException -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. BASE_URL = "https://synapse/" @@ -59,6 +59,10 @@ class SamlHandlerTestCase(HomeserverTestCase): "grandfathered_mxid_source_attribute": None, "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, } + + # Update this config with what's in the default config so that + # override_config works as expected. + saml_config.update(config.get("saml2_config", {})) config["saml2_config"] = saml_config return config @@ -86,6 +90,34 @@ class SamlHandlerTestCase(HomeserverTestCase): ) self.assertEqual(mxid, "@test_user:test") + @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) + def test_map_saml_response_to_existing_user(self): + """Existing users can log in with SAML account.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + + # Map a user via SSO. + saml_response = FakeAuthnResponse( + {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} + ) + redirect_url = "" + mxid = self.get_success( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + # Subsequent calls should map to the same mxid. + mxid = self.get_success( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + def test_map_saml_response_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})