Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

matrix-org-hotfixes
Erik Johnston 2023-06-27 10:31:31 +01:00
commit 21cb804023
14 changed files with 409 additions and 41 deletions

4
Cargo.lock generated
View File

@ -340,9 +340,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.97"
version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdf3bf93142acad5821c99197022e170842cdbc1c30482b98750c688c640842a"
checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3"
dependencies = [
"itoa",
"ryu",

1
changelog.d/15817.bugfix Normal file
View File

@ -0,0 +1 @@
Fix sqlite `user_filters` upgrade introduced in v1.86.0.

View File

@ -0,0 +1 @@
Add spam checker module API for logins.

View File

@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
callback that does not return `False` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.
### `check_login_for_spam`
_First introduced in Synapse v1.87.0_
```python
async def check_login_for_spam(
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
```
Called when a user logs in.
The arguments passed to this callback are:
* `user_id`: The user ID the user is logging in with
* `device_id`: The device ID the user is re-logging into.
* `initial_display_name`: The device display name, if any.
* `request_info`: A collection of tuples, which first item is a user agent, and which
second item is an IP address. These user agents and IP addresses are the ones that were
used during the login process.
* `auth_provider_id`: The identifier of the SSO authentication provider, if any.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
be used. If this happens, Synapse will not call any of the subsequent implementations of
this callback.
*Note:* This will not be called when a user registers.
## Example
The example below is a module that implements the spam checker callback

58
poetry.lock generated
View File

@ -2245,28 +2245,28 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]]
name = "ruff"
version = "0.0.272"
version = "0.0.275"
description = "An extremely fast Python linter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.0.272-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:ae9b57546e118660175d45d264b87e9b4c19405c75b587b6e4d21e6a17bf4fdf"},
{file = "ruff-0.0.272-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:1609b864a8d7ee75a8c07578bdea0a7db75a144404e75ef3162e0042bfdc100d"},
{file = "ruff-0.0.272-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee76b4f05fcfff37bd6ac209d1370520d509ea70b5a637bdf0a04d0c99e13dff"},
{file = "ruff-0.0.272-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48eccf225615e106341a641f826b15224b8a4240b84269ead62f0afd6d7e2d95"},
{file = "ruff-0.0.272-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:677284430ac539bb23421a2b431b4ebc588097ef3ef918d0e0a8d8ed31fea216"},
{file = "ruff-0.0.272-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:9c4bfb75456a8e1efe14c52fcefb89cfb8f2a0d31ed8d804b82c6cf2dc29c42c"},
{file = "ruff-0.0.272-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86bc788245361a8148ff98667da938a01e1606b28a45e50ac977b09d3ad2c538"},
{file = "ruff-0.0.272-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:27b2ea68d2aa69fff1b20b67636b1e3e22a6a39e476c880da1282c3e4bf6ee5a"},
{file = "ruff-0.0.272-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd2bbe337a3f84958f796c77820d55ac2db1e6753f39d1d1baed44e07f13f96d"},
{file = "ruff-0.0.272-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d5a208f8ef0e51d4746930589f54f9f92f84bb69a7d15b1de34ce80a7681bc00"},
{file = "ruff-0.0.272-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:905ff8f3d6206ad56fcd70674453527b9011c8b0dc73ead27618426feff6908e"},
{file = "ruff-0.0.272-py3-none-musllinux_1_2_i686.whl", hash = "sha256:19643d448f76b1eb8a764719072e9c885968971bfba872e14e7257e08bc2f2b7"},
{file = "ruff-0.0.272-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:691d72a00a99707a4e0b2846690961157aef7b17b6b884f6b4420a9f25cd39b5"},
{file = "ruff-0.0.272-py3-none-win32.whl", hash = "sha256:dc406e5d756d932da95f3af082814d2467943631a587339ee65e5a4f4fbe83eb"},
{file = "ruff-0.0.272-py3-none-win_amd64.whl", hash = "sha256:a37ec80e238ead2969b746d7d1b6b0d31aa799498e9ba4281ab505b93e1f4b28"},
{file = "ruff-0.0.272-py3-none-win_arm64.whl", hash = "sha256:06b8ee4eb8711ab119db51028dd9f5384b44728c23586424fd6e241a5b9c4a3b"},
{file = "ruff-0.0.272.tar.gz", hash = "sha256:273a01dc8c3c4fd4c2af7ea7a67c8d39bb09bce466e640dd170034da75d14cab"},
{file = "ruff-0.0.275-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:5e6554a072e7ce81eb6f0bec1cebd3dcb0e358652c0f4900d7d630d61691e914"},
{file = "ruff-0.0.275-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:1cc599022fe5ffb143a965b8d659eb64161ab8ab4433d208777eab018a1aab67"},
{file = "ruff-0.0.275-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5206fc1cd8c1c1deadd2e6360c0dbcd690f1c845da588ca9d32e4a764a402c60"},
{file = "ruff-0.0.275-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0c4e6468da26f77b90cae35319d310999f471a8c352998e9b39937a23750149e"},
{file = "ruff-0.0.275-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0dbdea02942131dbc15dd45f431d152224f15e1dd1859fcd0c0487b658f60f1a"},
{file = "ruff-0.0.275-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:22efd9f41af27ef8fb9779462c46c35c89134d33e326c889971e10b2eaf50c63"},
{file = "ruff-0.0.275-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c09662112cfa22d7467a19252a546291fd0eae4f423e52b75a7a2000a1894db"},
{file = "ruff-0.0.275-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80043726662144876a381efaab88841c88e8df8baa69559f96b22d4fa216bef1"},
{file = "ruff-0.0.275-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5859ee543b01b7eb67835dfd505faa8bb7cc1550f0295c92c1401b45b42be399"},
{file = "ruff-0.0.275-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c8ace4d40a57b5ea3c16555f25a6b16bc5d8b2779ae1912ce2633543d4e9b1da"},
{file = "ruff-0.0.275-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8347fc16aa185aae275906c4ac5b770e00c896b6a0acd5ba521f158801911998"},
{file = "ruff-0.0.275-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ec43658c64bfda44fd84bbea9da8c7a3b34f65448192d1c4dd63e9f4e7abfdd4"},
{file = "ruff-0.0.275-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:508b13f7ca37274cceaba4fb3ea5da6ca192356323d92acf39462337c33ad14e"},
{file = "ruff-0.0.275-py3-none-win32.whl", hash = "sha256:6afb1c4422f24f361e877937e2a44b3f8176774a476f5e33845ebfe887dd5ec2"},
{file = "ruff-0.0.275-py3-none-win_amd64.whl", hash = "sha256:d9b264d78621bf7b698b6755d4913ab52c19bd28bee1a16001f954d64c1a1220"},
{file = "ruff-0.0.275-py3-none-win_arm64.whl", hash = "sha256:a19ce3bea71023eee5f0f089dde4a4272d088d5ac0b675867e074983238ccc65"},
{file = "ruff-0.0.275.tar.gz", hash = "sha256:a63a0b645da699ae5c758fce19188e901b3033ec54d862d93fcd042addf7f38d"},
]
[[package]]
@ -2711,21 +2711,21 @@ files = [
[[package]]
name = "towncrier"
version = "22.12.0"
version = "23.6.0"
description = "Building newsfiles for your project."
optional = false
python-versions = ">=3.7"
files = [
{file = "towncrier-22.12.0-py3-none-any.whl", hash = "sha256:9767a899a4d6856950f3598acd9e8f08da2663c49fdcda5ea0f9e6ba2afc8eea"},
{file = "towncrier-22.12.0.tar.gz", hash = "sha256:9c49d7e75f646a9aea02ae904c0bc1639c8fd14a01292d2b123b8d307564034d"},
{file = "towncrier-23.6.0-py3-none-any.whl", hash = "sha256:da552f29192b3c2b04d630133f194c98e9f14f0558669d427708e203fea4d0a5"},
{file = "towncrier-23.6.0.tar.gz", hash = "sha256:fc29bd5ab4727c8dacfbe636f7fb5dc53b99805b62da1c96b214836159ff70c1"},
]
[package.dependencies]
click = "*"
click-default-group = "*"
importlib-resources = {version = ">=5", markers = "python_version < \"3.10\""}
incremental = "*"
jinja2 = "*"
setuptools = "*"
tomli = {version = "*", markers = "python_version < \"3.11\""}
[package.extras]
@ -2931,13 +2931,13 @@ files = [
[[package]]
name = "types-opentracing"
version = "2.4.10.4"
version = "2.4.10.5"
description = "Typing stubs for opentracing"
optional = false
python-versions = "*"
files = [
{file = "types-opentracing-2.4.10.4.tar.gz", hash = "sha256:347040c9da4ada7d3c795659912c95d98c5651e242e8eaa0344815fee5bb97e2"},
{file = "types_opentracing-2.4.10.4-py3-none-any.whl", hash = "sha256:73c9b958eea3df6c4906ebf3865608a562dd9981c1bbc75a373a583c613bed56"},
{file = "types-opentracing-2.4.10.5.tar.gz", hash = "sha256:852d13ab1324832835d50c00cfd58b9267f0e79ec3189e5664c2a90c26880fd4"},
{file = "types_opentracing-2.4.10.5-py3-none-any.whl", hash = "sha256:8f12ab4dce3e298a8e6655da9a6d52171e7a275357eae4cec22a1663d94023a7"},
]
[[package]]
@ -3003,13 +3003,13 @@ types-urllib3 = "*"
[[package]]
name = "types-setuptools"
version = "67.8.0.0"
version = "68.0.0.0"
description = "Typing stubs for setuptools"
optional = false
python-versions = "*"
files = [
{file = "types-setuptools-67.8.0.0.tar.gz", hash = "sha256:95c9ed61871d6c0e258433373a4e1753c0a7c3627a46f4d4058c7b5a08ab844f"},
{file = "types_setuptools-67.8.0.0-py3-none-any.whl", hash = "sha256:6df73340d96b238a4188b7b7668814b37e8018168aef1eef94a3b1872e3f60ff"},
{file = "types-setuptools-68.0.0.0.tar.gz", hash = "sha256:fc958b4123b155ffc069a66d3af5fe6c1f9d0600c35c0c8444b2ab4147112641"},
{file = "types_setuptools-68.0.0.0-py3-none-any.whl", hash = "sha256:cc00e09ba8f535362cbe1ea8b8407d15d14b59c57f4190cceaf61a9e57616446"},
]
[[package]]
@ -3294,4 +3294,4 @@ user-search = ["pyicu"]
[metadata]
lock-version = "2.0"
python-versions = "^3.7.1"
content-hash = "090924370b17fd265407b5a3f9cbc00997308f575b455399b39a48e3ca1a5a8e"
content-hash = "7f31754a1009d7b6c9a1bd7221a0b243ffd510f362c28f0da417aaac16757a87"

View File

@ -311,7 +311,7 @@ all = [
# We pin black so that our tests don't start failing on new releases.
isort = ">=5.10.1"
black = ">=22.3.0"
ruff = "0.0.272"
ruff = "0.0.275"
# Typechecking
lxml-stubs = ">=0.4.0"

View File

@ -521,6 +521,11 @@ class SynapseRequest(Request):
else:
return self.getClientAddress().host
def request_info(self) -> "RequestInfo":
h = self.getHeader(b"User-Agent")
user_agent = h.decode("ascii", "replace") if h else None
return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())
class XForwardedForRequest(SynapseRequest):
"""Request object which honours proxy headers
@ -661,3 +666,9 @@ class SynapseSite(Site):
def log(self, request: SynapseRequest) -> None:
pass
@attr.s(auto_attribs=True, frozen=True, slots=True)
class RequestInfo:
user_agent: Optional[str]
ip: str

View File

@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_LOGIN_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
@ -302,6 +303,7 @@ class ModuleApi:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Registers callbacks for spam checking capabilities.
@ -319,6 +321,7 @@ class ModuleApi:
check_username_for_spam=check_username_for_spam,
check_registration_for_spam=check_registration_for_spam,
check_media_file_for_spam=check_media_file_for_spam,
check_login_for_spam=check_login_for_spam,
)
def register_account_validity_callbacks(

View File

@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
]
],
]
CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
[
str,
Optional[str],
Optional[str],
Collection[Tuple[Optional[str], str]],
Optional[str],
],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
]
],
]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []
def register_callbacks(
self,
@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
if check_login_for_spam is not None:
self._check_login_for_spam_callbacks.append(check_login_for_spam)
@trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM
async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if we should allow the given registration request.
Args:
user_id: The request user ID
request_info: List of tuples of user agent and IP that
were used during the registration process.
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
"cas". If any. Note this does not include users registered
via a password provider.
Returns:
Enum for how the request should be handled
"""
for callback in self._check_login_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(
callback(
user_id,
device_id,
initial_display_name,
request_info,
auth_provider_id,
)
)
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is self.NOT_SPAM:
continue
elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res
else:
logger.warning(
"Module returned invalid value, rejecting login as spam"
)
return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM

View File

@ -50,7 +50,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.http.site import RequestInfo, SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
self._spam_checker = hs.get_module_api_callbacks().spam_checker
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet):
self._refresh_tokens_enabled and client_requested_refresh_token
)
request_info = request.request_info()
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
requester = await self.auth.get_user_by_req(request)
@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet):
login_submission,
appservice,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif (
self.jwt_enabled
@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet):
# The user represented by an appservice's configured sender_localpart
# is not actually created in Synapse.
should_check_deactivated=qualified_user_id != appservice.sender,
request_info=request_info,
)
async def _do_other_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Handle non-token/saml/jwt logins
@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet):
login_submission,
callback,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
return result
@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet):
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
should_check_deactivated: bool = True,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet):
This exists purely for appservice's configured sender_localpart
which doesn't have an associated user in the database.
request_info: The user agent/IP address of the user.
Returns:
Dictionary of account information after successful login.
@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet):
)
initial_display_name = login_submission.get("initial_device_display_name")
spam_check = await self._spam_checker.check_login_for_spam(
user_id,
device_id=device_id,
initial_display_name=initial_display_name,
request_info=[(request_info.user_agent, request_info.ip)],
auth_provider_id=auth_provider_id,
)
if spam_check != self._spam_checker.NOT_SPAM:
logger.info("Blocking login due to spam checker")
raise SynapseError(
403,
msg="Login was blocked by the server",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
(
device_id,
access_token,
@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet):
return result
async def _do_token_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle token login.
@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet):
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=res.auth_provider_session_id,
request_info=request_info,
)
async def _do_jwt_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle the custom JWT login.
@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet):
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)

View File

@ -61,9 +61,7 @@ def run_upgrade(
full_user_id text NOT NULL,
user_id text NOT NULL,
filter_id bigint NOT NULL,
filter_json bytea NOT NULL,
UNIQUE (full_user_id),
UNIQUE (user_id)
filter_json bytea NOT NULL
)
"""
cur.execute(create_sql)

View File

@ -0,0 +1,65 @@
# Copyright 2023 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine
def run_update(
cur: LoggingTransaction,
database_engine: BaseDatabaseEngine,
config: HomeServerConfig,
) -> None:
"""
Fix to drop unused indexes caused by incorrectly adding UNIQUE constraint to
columns `user_id` and `full_user_id` of table `user_filters` in previous migration.
"""
if isinstance(database_engine, Sqlite3Engine):
cur.execute("DROP TABLE IF EXISTS temp_user_filters")
create_sql = """
CREATE TABLE temp_user_filters (
full_user_id text NOT NULL,
user_id text NOT NULL,
filter_id bigint NOT NULL,
filter_json bytea NOT NULL
)
"""
cur.execute(create_sql)
copy_sql = """
INSERT INTO temp_user_filters (
user_id,
filter_id,
filter_json,
full_user_id)
SELECT user_id, filter_id, filter_json, full_user_id FROM user_filters
"""
cur.execute(copy_sql)
drop_sql = """
DROP TABLE user_filters
"""
cur.execute(drop_sql)
rename_sql = """
ALTER TABLE temp_user_filters RENAME to user_filters
"""
cur.execute(rename_sql)
index_sql = """
CREATE UNIQUE INDEX IF NOT EXISTS user_filters_unique ON
user_filters (user_id, filter_id)
"""
cur.execute(index_sql)

View File

@ -0,0 +1,25 @@
# Copyright 2023 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine
def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
if isinstance(database_engine, Sqlite3Engine):
idx_sql = """
CREATE UNIQUE INDEX IF NOT EXISTS user_filters_full_user_id_unique ON
user_filters (full_user_id, filter_id)
"""
cur.execute(idx_sql)

View File

@ -13,11 +13,12 @@
# limitations under the License.
import time
import urllib.parse
from typing import Any, Dict, List, Optional
from typing import Any, Collection, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib.parse import urlencode
import pymacaroons
from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@ -26,11 +27,12 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
]
class TestSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_login_for_spam=self.check_login_for_spam,
)
@staticmethod
def parse_config(config: JsonDict) -> None:
return None
async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[
Literal["NOT_SPAM"],
Tuple["synapse.module_api.errors.Codes", JsonDict],
]:
return "NOT_SPAM"
class DenyAllSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_login_for_spam=self.check_login_for_spam,
)
@staticmethod
def parse_config(config: JsonDict) -> None:
return None
async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[
Literal["NOT_SPAM"],
Tuple["synapse.module_api.errors.Codes", JsonDict],
]:
# Return an odd set of values to ensure that they get correctly passed
# to the client.
return Codes.LIMIT_EXCEEDED, {"extra": "value"}
class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
],
)
@override_config(
{
"modules": [
{
"module": TestSpamChecker.__module__
+ "."
+ TestSpamChecker.__qualname__
}
]
}
)
def test_spam_checker_allow(self) -> None:
"""Check that that adding a spam checker doesn't break login."""
self.register_user("kermit", "monkey")
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
channel = self.make_request(
"POST",
"/_matrix/client/r0/login",
body,
)
self.assertEqual(channel.code, 200, channel.result)
@override_config(
{
"modules": [
{
"module": DenyAllSpamChecker.__module__
+ "."
+ DenyAllSpamChecker.__qualname__
}
]
}
)
def test_spam_checker_deny(self) -> None:
"""Check that login"""
self.register_user("kermit", "monkey")
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
channel = self.make_request(
"POST",
"/_matrix/client/r0/login",
body,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertDictContainsSubset(
{"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
)
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):