Add login spam checker API (#15838)

mv/msc3944
Erik Johnston 2023-06-26 15:12:20 +01:00 committed by GitHub
parent 52d8131e87
commit 25c55a9d22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 285 additions and 6 deletions

View File

@ -0,0 +1 @@
Add spam checker module API for logins.

View File

@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
callback that does not return `False` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.
### `check_login_for_spam`
_First introduced in Synapse v1.87.0_
```python
async def check_login_for_spam(
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
```
Called when a user logs in.
The arguments passed to this callback are:
* `user_id`: The user ID the user is logging in with
* `device_id`: The device ID the user is re-logging into.
* `initial_display_name`: The device display name, if any.
* `request_info`: A collection of tuples, which first item is a user agent, and which
second item is an IP address. These user agents and IP addresses are the ones that were
used during the login process.
* `auth_provider_id`: The identifier of the SSO authentication provider, if any.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
be used. If this happens, Synapse will not call any of the subsequent implementations of
this callback.
*Note:* This will not be called when a user registers.
## Example
The example below is a module that implements the spam checker callback

View File

@ -521,6 +521,11 @@ class SynapseRequest(Request):
else:
return self.getClientAddress().host
def request_info(self) -> "RequestInfo":
h = self.getHeader(b"User-Agent")
user_agent = h.decode("ascii", "replace") if h else None
return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())
class XForwardedForRequest(SynapseRequest):
"""Request object which honours proxy headers
@ -661,3 +666,9 @@ class SynapseSite(Site):
def log(self, request: SynapseRequest) -> None:
pass
@attr.s(auto_attribs=True, frozen=True, slots=True)
class RequestInfo:
user_agent: Optional[str]
ip: str

View File

@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_LOGIN_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
@ -302,6 +303,7 @@ class ModuleApi:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Registers callbacks for spam checking capabilities.
@ -319,6 +321,7 @@ class ModuleApi:
check_username_for_spam=check_username_for_spam,
check_registration_for_spam=check_registration_for_spam,
check_media_file_for_spam=check_media_file_for_spam,
check_login_for_spam=check_login_for_spam,
)
def register_account_validity_callbacks(

View File

@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
]
],
]
CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
[
str,
Optional[str],
Optional[str],
Collection[Tuple[Optional[str], str]],
Optional[str],
],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
]
],
]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []
def register_callbacks(
self,
@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
if check_login_for_spam is not None:
self._check_login_for_spam_callbacks.append(check_login_for_spam)
@trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM
async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if we should allow the given registration request.
Args:
user_id: The request user ID
request_info: List of tuples of user agent and IP that
were used during the registration process.
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
"cas". If any. Note this does not include users registered
via a password provider.
Returns:
Enum for how the request should be handled
"""
for callback in self._check_login_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(
callback(
user_id,
device_id,
initial_display_name,
request_info,
auth_provider_id,
)
)
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is self.NOT_SPAM:
continue
elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res
else:
logger.warning(
"Module returned invalid value, rejecting login as spam"
)
return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM

View File

@ -50,7 +50,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.http.site import RequestInfo, SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
self._spam_checker = hs.get_module_api_callbacks().spam_checker
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet):
self._refresh_tokens_enabled and client_requested_refresh_token
)
request_info = request.request_info()
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
requester = await self.auth.get_user_by_req(request)
@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet):
login_submission,
appservice,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif (
self.jwt_enabled
@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet):
# The user represented by an appservice's configured sender_localpart
# is not actually created in Synapse.
should_check_deactivated=qualified_user_id != appservice.sender,
request_info=request_info,
)
async def _do_other_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Handle non-token/saml/jwt logins
@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet):
login_submission,
callback,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
return result
@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet):
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
should_check_deactivated: bool = True,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet):
This exists purely for appservice's configured sender_localpart
which doesn't have an associated user in the database.
request_info: The user agent/IP address of the user.
Returns:
Dictionary of account information after successful login.
@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet):
)
initial_display_name = login_submission.get("initial_device_display_name")
spam_check = await self._spam_checker.check_login_for_spam(
user_id,
device_id=device_id,
initial_display_name=initial_display_name,
request_info=[(request_info.user_agent, request_info.ip)],
auth_provider_id=auth_provider_id,
)
if spam_check != self._spam_checker.NOT_SPAM:
logger.info("Blocking login due to spam checker")
raise SynapseError(
403,
msg="Login was blocked by the server",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
(
device_id,
access_token,
@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet):
return result
async def _do_token_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle token login.
@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet):
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=res.auth_provider_session_id,
request_info=request_info,
)
async def _do_jwt_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle the custom JWT login.
@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet):
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)

View File

@ -13,11 +13,12 @@
# limitations under the License.
import time
import urllib.parse
from typing import Any, Dict, List, Optional
from typing import Any, Collection, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib.parse import urlencode
import pymacaroons
from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@ -26,11 +27,12 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
]
class TestSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_login_for_spam=self.check_login_for_spam,
)
@staticmethod
def parse_config(config: JsonDict) -> None:
return None
async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[
Literal["NOT_SPAM"],
Tuple["synapse.module_api.errors.Codes", JsonDict],
]:
return "NOT_SPAM"
class DenyAllSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_login_for_spam=self.check_login_for_spam,
)
@staticmethod
def parse_config(config: JsonDict) -> None:
return None
async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[
Literal["NOT_SPAM"],
Tuple["synapse.module_api.errors.Codes", JsonDict],
]:
# Return an odd set of values to ensure that they get correctly passed
# to the client.
return Codes.LIMIT_EXCEEDED, {"extra": "value"}
class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
],
)
@override_config(
{
"modules": [
{
"module": TestSpamChecker.__module__
+ "."
+ TestSpamChecker.__qualname__
}
]
}
)
def test_spam_checker_allow(self) -> None:
"""Check that that adding a spam checker doesn't break login."""
self.register_user("kermit", "monkey")
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
channel = self.make_request(
"POST",
"/_matrix/client/r0/login",
body,
)
self.assertEqual(channel.code, 200, channel.result)
@override_config(
{
"modules": [
{
"module": DenyAllSpamChecker.__module__
+ "."
+ DenyAllSpamChecker.__qualname__
}
]
}
)
def test_spam_checker_deny(self) -> None:
"""Check that login"""
self.register_user("kermit", "monkey")
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
channel = self.make_request(
"POST",
"/_matrix/client/r0/login",
body,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertDictContainsSubset(
{"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
)
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):