Add per user ratelimiting overrides

pull/2208/head
Erik Johnston 2017-05-10 11:05:43 +01:00
parent ca238bc023
commit b990b2fce5
6 changed files with 93 additions and 19 deletions

View File

@ -53,7 +53,20 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
def ratelimit(self, requester):
@defer.inlineCallbacks
def ratelimit(self, requester, update=True):
"""Ratelimits requests.
Args:
requester (Requester)
update (bool): Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
Raises:
LimitExceededError if the request should be ratelimited
"""
time_now = self.clock.time()
user_id = requester.user.to_string()
@ -67,10 +80,25 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited():
return
# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
if override:
# If overriden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
messages_per_second = self.hs.config.rc_messages_per_second
burst_count = self.hs.config.rc_message_burst_count
allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count,
msg_rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
if not allowed:
raise LimitExceededError(

View File

@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
@ -254,17 +254,7 @@ class MessageHandler(BaseHandler):
# We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually
# send the event.
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
event.sender, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
yield self.ratelimit(requester, update=False)
user = UserID.from_string(event.sender)
@ -499,7 +489,7 @@ class MessageHandler(BaseHandler):
# We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(requester)
yield self.ratelimit(requester)
try:
yield self.auth.check_from_context(event, context)

View File

@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user):
return
self.ratelimit(requester)
yield self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(
user.to_string(),

View File

@ -75,7 +75,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
self.ratelimit(requester)
yield self.ratelimit(requester)
if "room_alias_name" in config:
for wchar in string.whitespace:

View File

@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from ._base import SQLBaseStore
from .engines import PostgresEngine, Sqlite3Engine
@ -33,6 +33,11 @@ OpsLevel = collections.namedtuple(
("ban_level", "kick_level", "redact_level",)
)
RatelimitOverride = collections.namedtuple(
"RatelimitOverride",
("messages_per_second", "burst_count",)
)
class RoomStore(SQLBaseStore):
@ -473,3 +478,32 @@ class RoomStore(SQLBaseStore):
return self.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
Args:
user_id (str)
Returns:
RatelimitOverride if there is an override, else None. If the contents
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
row = yield self._simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
allow_none=True,
desc="get_ratelimit_for_user",
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
else:
defer.returnValue(None)

View File

@ -0,0 +1,22 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE ratelimit_override (
user_id TEXT NOT NULL,
messages_per_second BIGINT,
burst_count BIGINT
);
CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id);