Describe which rate limiter was hit in logs (#16135)

pull/16181/merge
David Robertson 2023-08-30 00:39:39 +01:00 committed by GitHub
parent e9235d92f2
commit 62a1a9be52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 235 additions and 121 deletions

1
changelog.d/16135.misc Normal file
View File

@ -0,0 +1 @@
Describe which rate limiter was hit in logs.

View File

@ -211,6 +211,11 @@ class SynapseError(CodeMessageException):
def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields)
@property
def debug_context(self) -> Optional[str]:
"""Override this to add debugging context that shouldn't be sent to clients."""
return None
class InvalidAPICallError(SynapseError):
"""You called an existing API endpoint, but fed that endpoint
@ -508,8 +513,8 @@ class LimitExceededError(SynapseError):
def __init__(
self,
limiter_name: str,
code: int = 429,
msg: str = "Too Many Requests",
retry_after_ms: Optional[int] = None,
errcode: str = Codes.LIMIT_EXCEEDED,
):
@ -518,12 +523,17 @@ class LimitExceededError(SynapseError):
if self.include_retry_after_header and retry_after_ms is not None
else None
)
super().__init__(code, msg, errcode, headers=headers)
super().__init__(code, "Too Many Requests", errcode, headers=headers)
self.retry_after_ms = retry_after_ms
self.limiter_name = limiter_name
def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@property
def debug_context(self) -> Optional[str]:
return self.limiter_name
class RoomKeysVersionError(SynapseError):
"""A client has tried to upload to a non-current version of the room_keys store"""

View File

@ -61,12 +61,16 @@ class Ratelimiter:
"""
def __init__(
self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
self,
store: DataStore,
clock: Clock,
cfg: RatelimitSettings,
):
self.clock = clock
self.rate_hz = rate_hz
self.burst_count = burst_count
self.rate_hz = cfg.per_second
self.burst_count = cfg.burst_count
self.store = store
self._limiter_name = cfg.key
# An ordered dictionary representing the token buckets tracked by this rate
# limiter. Each entry maps a key of arbitrary type to a tuple representing:
@ -305,7 +309,8 @@ class Ratelimiter:
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
limiter_name=self._limiter_name,
retry_after_ms=int(1000 * (time_allowed - time_now_s)),
)
@ -322,7 +327,9 @@ class RequestRatelimiter:
# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
store=self.store,
clock=self.clock,
cfg=RatelimitSettings(key=rc_message.key, per_second=0, burst_count=0),
)
self._rc_message = rc_message
@ -332,8 +339,7 @@ class RequestRatelimiter:
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=rc_admin_redaction.per_second,
burst_count=rc_admin_redaction.burst_count,
cfg=rc_admin_redaction,
)
else:
self.admin_redaction_ratelimiter = None

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, cast
import attr
@ -21,16 +21,47 @@ from synapse.types import JsonDict
from ._base import Config
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RatelimitSettings:
def __init__(
self,
config: Dict[str, float],
key: str
per_second: float
burst_count: int
@classmethod
def parse(
cls,
config: Dict[str, Any],
key: str,
defaults: Optional[Dict[str, float]] = None,
):
) -> "RatelimitSettings":
"""Parse config[key] as a new-style rate limiter config.
The key may refer to a nested dictionary using a full stop (.) to separate
each nested key. For example, use the key "a.b.c" to parse the following:
a:
b:
c:
per_second: 10
burst_count: 200
If this lookup fails, we'll fallback to the defaults.
"""
defaults = defaults or {"per_second": 0.17, "burst_count": 3.0}
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
rl_config = config
for part in key.split("."):
rl_config = rl_config.get(part, {})
# By this point we should have hit the rate limiter parameters.
# We don't actually check this though!
rl_config = cast(Dict[str, float], rl_config)
return cls(
key=key,
per_second=rl_config.get("per_second", defaults["per_second"]),
burst_count=int(rl_config.get("burst_count", defaults["burst_count"])),
)
@attr.s(auto_attribs=True)
@ -49,15 +80,14 @@ class RatelimitConfig(Config):
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
if "rc_message" in config:
self.rc_message = RatelimitSettings(
config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0}
self.rc_message = RatelimitSettings.parse(
config, "rc_message", defaults={"per_second": 0.2, "burst_count": 10.0}
)
else:
self.rc_message = RatelimitSettings(
{
"per_second": config.get("rc_messages_per_second", 0.2),
"burst_count": config.get("rc_message_burst_count", 10.0),
}
key="rc_messages",
per_second=config.get("rc_messages_per_second", 0.2),
burst_count=config.get("rc_message_burst_count", 10.0),
)
# Load the new-style federation config, if it exists. Otherwise, fall
@ -79,51 +109,59 @@ class RatelimitConfig(Config):
}
)
self.rc_registration = RatelimitSettings(config.get("rc_registration", {}))
self.rc_registration = RatelimitSettings.parse(config, "rc_registration", {})
self.rc_registration_token_validity = RatelimitSettings(
config.get("rc_registration_token_validity", {}),
self.rc_registration_token_validity = RatelimitSettings.parse(
config,
"rc_registration_token_validity",
defaults={"per_second": 0.1, "burst_count": 5},
)
# It is reasonable to login with a bunch of devices at once (i.e. when
# setting up an account), but it is *not* valid to continually be
# logging into new devices.
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RatelimitSettings(
rc_login_config.get("address", {}),
self.rc_login_address = RatelimitSettings.parse(
config,
"rc_login.address",
defaults={"per_second": 0.003, "burst_count": 5},
)
self.rc_login_account = RatelimitSettings(
rc_login_config.get("account", {}),
self.rc_login_account = RatelimitSettings.parse(
config,
"rc_login.account",
defaults={"per_second": 0.003, "burst_count": 5},
)
self.rc_login_failed_attempts = RatelimitSettings(
rc_login_config.get("failed_attempts", {})
self.rc_login_failed_attempts = RatelimitSettings.parse(
config,
"rc_login.failed_attempts",
{},
)
self.federation_rr_transactions_per_room_per_second = config.get(
"federation_rr_transactions_per_room_per_second", 50
)
rc_admin_redaction = config.get("rc_admin_redaction")
self.rc_admin_redaction = None
if rc_admin_redaction:
self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction)
if "rc_admin_redaction" in config:
self.rc_admin_redaction = RatelimitSettings.parse(
config, "rc_admin_redaction", {}
)
self.rc_joins_local = RatelimitSettings(
config.get("rc_joins", {}).get("local", {}),
self.rc_joins_local = RatelimitSettings.parse(
config,
"rc_joins.local",
defaults={"per_second": 0.1, "burst_count": 10},
)
self.rc_joins_remote = RatelimitSettings(
config.get("rc_joins", {}).get("remote", {}),
self.rc_joins_remote = RatelimitSettings.parse(
config,
"rc_joins.remote",
defaults={"per_second": 0.01, "burst_count": 10},
)
# Track the rate of joins to a given room. If there are too many, temporarily
# prevent local joins and remote joins via this server.
self.rc_joins_per_room = RatelimitSettings(
config.get("rc_joins_per_room", {}),
self.rc_joins_per_room = RatelimitSettings.parse(
config,
"rc_joins_per_room",
defaults={"per_second": 1, "burst_count": 10},
)
@ -132,31 +170,37 @@ class RatelimitConfig(Config):
# * For requests received over federation this is keyed by the origin.
#
# Note that this isn't exposed in the configuration as it is obscure.
self.rc_key_requests = RatelimitSettings(
config.get("rc_key_requests", {}),
self.rc_key_requests = RatelimitSettings.parse(
config,
"rc_key_requests",
defaults={"per_second": 20, "burst_count": 100},
)
self.rc_3pid_validation = RatelimitSettings(
config.get("rc_3pid_validation") or {},
self.rc_3pid_validation = RatelimitSettings.parse(
config,
"rc_3pid_validation",
defaults={"per_second": 0.003, "burst_count": 5},
)
self.rc_invites_per_room = RatelimitSettings(
config.get("rc_invites", {}).get("per_room", {}),
self.rc_invites_per_room = RatelimitSettings.parse(
config,
"rc_invites.per_room",
defaults={"per_second": 0.3, "burst_count": 10},
)
self.rc_invites_per_user = RatelimitSettings(
config.get("rc_invites", {}).get("per_user", {}),
self.rc_invites_per_user = RatelimitSettings.parse(
config,
"rc_invites.per_user",
defaults={"per_second": 0.003, "burst_count": 5},
)
self.rc_invites_per_issuer = RatelimitSettings(
config.get("rc_invites", {}).get("per_issuer", {}),
self.rc_invites_per_issuer = RatelimitSettings.parse(
config,
"rc_invites.per_issuer",
defaults={"per_second": 0.3, "burst_count": 10},
)
self.rc_third_party_invite = RatelimitSettings(
config.get("rc_third_party_invite", {}),
self.rc_third_party_invite = RatelimitSettings.parse(
config,
"rc_third_party_invite",
defaults={"per_second": 0.0025, "burst_count": 5},
)

View File

@ -218,19 +218,17 @@ class AuthHandler:
self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
cfg=self.hs.config.ratelimiting.rc_login_failed_attempts,
)
# The number of seconds to keep a UI auth session active.
self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout
# Ratelimitier for failed /login attempts
# Ratelimiter for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
cfg=self.hs.config.ratelimiting.rc_login_failed_attempts,
)
self._clock = self.hs.get_clock()

View File

@ -90,8 +90,7 @@ class DeviceMessageHandler:
self._ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_key_requests.per_second,
burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
cfg=hs.config.ratelimiting.rc_key_requests,
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:

View File

@ -66,14 +66,12 @@ class IdentityHandler:
self._3pid_validation_ratelimiter_ip = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
cfg=hs.config.ratelimiting.rc_3pid_validation,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
cfg=hs.config.ratelimiting.rc_3pid_validation,
)
async def ratelimit_request_token_requests(

View File

@ -112,8 +112,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._join_rate_limiter_local = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
cfg=hs.config.ratelimiting.rc_joins_local,
)
# Tracks joins from local users to rooms this server isn't a member of.
# I.e. joins this server makes by requesting /make_join /send_join from
@ -121,8 +120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._join_rate_limiter_remote = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
cfg=hs.config.ratelimiting.rc_joins_remote,
)
# TODO: find a better place to keep this Ratelimiter.
# It needs to be
@ -135,8 +133,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._join_rate_per_room_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
cfg=hs.config.ratelimiting.rc_joins_per_room,
)
# Ratelimiter for invites, keyed by room (across all issuers, all
@ -144,8 +141,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._invites_per_room_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
cfg=hs.config.ratelimiting.rc_invites_per_room,
)
# Ratelimiter for invites, keyed by recipient (across all rooms, all
@ -153,8 +149,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._invites_per_recipient_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
cfg=hs.config.ratelimiting.rc_invites_per_user,
)
# Ratelimiter for invites, keyed by issuer (across all rooms, all
@ -162,15 +157,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._invites_per_issuer_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count,
cfg=hs.config.ratelimiting.rc_invites_per_issuer,
)
self._third_party_invite_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second,
burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count,
cfg=hs.config.ratelimiting.rc_third_party_invite,
)
self.request_ratelimiter = hs.get_request_ratelimiter()

View File

@ -35,6 +35,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.ratelimiting import RatelimitSettings
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StrCollection
from synapse.util.caches.response_cache import ResponseCache
@ -94,7 +95,9 @@ class RoomSummaryHandler:
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
self._ratelimiter = Ratelimiter(
store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
store=self._store,
clock=hs.get_clock(),
cfg=RatelimitSettings("<room summary>", per_second=5, burst_count=10),
)
# If a user tries to fetch the same page multiple times in quick succession,

View File

@ -115,7 +115,13 @@ def return_json_error(
if exc.headers is not None:
for header, value in exc.headers.items():
request.setHeader(header, value)
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
error_ctx = exc.debug_context
if error_ctx:
logger.info(
"%s SynapseError: %s - %s (%s)", request, error_code, exc.msg, error_ctx
)
else:
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
elif f.check(CancelledError):
error_code = HTTP_STATUS_REQUEST_CANCELLED
error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN}

View File

@ -120,14 +120,12 @@ class LoginRestServlet(RestServlet):
self._address_ratelimiter = Ratelimiter(
store=self._main_store,
clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
cfg=self.hs.config.ratelimiting.rc_login_address,
)
self._account_ratelimiter = Ratelimiter(
store=self._main_store,
clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
cfg=self.hs.config.ratelimiting.rc_login_account,
)
# ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.

View File

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.ratelimiting import RatelimitSettings
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@ -66,15 +67,18 @@ class LoginTokenRequestServlet(RestServlet):
self.token_timeout = hs.config.auth.login_via_existing_token_timeout
self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth
# Ratelimit aggressively to a maxmimum of 1 request per minute.
# Ratelimit aggressively to a maximum of 1 request per minute.
#
# This endpoint can be used to spawn additional sessions and could be
# abused by a malicious client to create many sessions.
self._ratelimiter = Ratelimiter(
store=self._main_store,
clock=hs.get_clock(),
rate_hz=1 / 60,
burst_count=1,
cfg=RatelimitSettings(
key="<login token request>",
per_second=1 / 60,
burst_count=1,
),
)
@interactive_auth_handler

View File

@ -376,8 +376,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
self.ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
cfg=hs.config.ratelimiting.rc_registration_token_validity,
)
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:

View File

@ -408,8 +408,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return Ratelimiter(
store=self.get_datastores().main,
clock=self.get_clock(),
rate_hz=self.config.ratelimiting.rc_registration.per_second,
burst_count=self.config.ratelimiting.rc_registration.burst_count,
cfg=self.config.ratelimiting.rc_registration,
)
@cache_in_self

View File

@ -291,7 +291,8 @@ class _PerHostRatelimiter:
if self.metrics_name:
rate_limit_reject_counter.labels(self.metrics_name).inc()
raise LimitExceededError(
retry_after_ms=int(self.window_size / self.sleep_limit)
limiter_name="rc_federation",
retry_after_ms=int(self.window_size / self.sleep_limit),
)
self.request_times.append(time_now)

View File

@ -1,6 +1,5 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -13,24 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from synapse.api.errors import LimitExceededError
from tests import unittest
class ErrorsTestCase(unittest.TestCase):
class LimitExceededErrorTestCase(unittest.TestCase):
def test_key_appears_in_context_but_not_error_dict(self) -> None:
err = LimitExceededError("needle")
serialised = json.dumps(err.error_dict(None))
self.assertIn("needle", err.debug_context)
self.assertNotIn("needle", serialised)
# Create a sub-class to avoid mutating the class-level property.
class LimitExceededErrorHeaders(LimitExceededError):
include_retry_after_header = True
def test_limit_exceeded_header(self) -> None:
err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=100)
err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100)
assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "1")
def test_limit_exceeded_rounding(self) -> None:
err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=3001)
err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001)
assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "4")

View File

@ -1,5 +1,6 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService
from synapse.config.ratelimiting import RatelimitSettings
from synapse.types import create_requester
from tests import unittest
@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=1,
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", _time_now_s=0)
@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=1,
cfg=RatelimitSettings(
key="",
per_second=0.1,
burst_count=1,
),
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0)
@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=1,
cfg=RatelimitSettings(
key="",
per_second=0.1,
burst_count=1,
),
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0)
@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=1,
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
# Shouldn't raise
@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=1,
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
# First attempt should be allowed
@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=1,
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
# First attempt should be allowed
@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=1,
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
self.get_success_or_raise(
limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
)
)
limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1)
limiter = Ratelimiter(
store=store,
clock=self.clock,
cfg=RatelimitSettings("", per_second=0.1, burst_count=1),
)
# Shouldn't raise
for _ in range(20):
@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=3,
cfg=RatelimitSettings(
key="",
per_second=0.1,
burst_count=3,
),
)
# Test that 4 actions aren't allowed with a maximum burst of 3.
allowed, time_allowed = self.get_success_or_raise(
@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=3,
cfg=RatelimitSettings("", per_second=0.1, burst_count=3),
)
def consume_at(time: float) -> bool:
@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=3,
cfg=RatelimitSettings(
"",
per_second=0.1,
burst_count=3,
),
)
# Observe two actions, leaving room in the bucket for one more.
@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=3,
cfg=RatelimitSettings(
"",
per_second=0.1,
burst_count=3,
),
)
# Observe three actions, filling up the bucket.
@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
rate_hz=0.1,
burst_count=3,
cfg=RatelimitSettings(
"",
per_second=0.1,
burst_count=3,
),
)
# Observe four actions, exceeding the bucket.

View File

@ -12,11 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import RatelimitSettings
from tests.unittest import TestCase
from tests.utils import default_config
class ParseRatelimitSettingsTestcase(TestCase):
def test_depth_1(self) -> None:
cfg = {
"a": {
"per_second": 5,
"burst_count": 10,
}
}
parsed = RatelimitSettings.parse(cfg, "a")
self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
def test_depth_2(self) -> None:
cfg = {
"a": {
"b": {
"per_second": 5,
"burst_count": 10,
},
}
}
parsed = RatelimitSettings.parse(cfg, "a.b")
self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10))
def test_missing(self) -> None:
parsed = RatelimitSettings.parse(
{}, "a", defaults={"per_second": 5, "burst_count": 10}
)
self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
class RatelimitConfigTestCase(TestCase):
def test_parse_rc_federation(self) -> None:
config_dict = default_config("test")