Merge branch 'release-v1.34.0' into matrix-org-hotfixes

erikj/disable_catchup_to_hq
Brendan Abolivier 2021-05-12 16:41:04 +01:00
commit 019ed44b84
17 changed files with 276 additions and 120 deletions

1
changelog.d/5588.misc Normal file
View File

@ -0,0 +1 @@
Reduce the length of Synapse's access tokens.

1
changelog.d/9951.feature Normal file
View File

@ -0,0 +1 @@
Improve performance of sending events for worker-based deployments using Redis.

1
changelog.d/9968.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in v1.27.0 preventing users and appservices exempt from ratelimiting from creating rooms with many invitees.

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python2 #!/usr/bin/env python
import sys import sys

View File

@ -57,6 +57,7 @@ class Ratelimiter:
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[int] = None, _time_now_s: Optional[int] = None,
) -> Tuple[bool, float]: ) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action? """Can the entity (e.g. user or IP address) perform the action?
@ -76,6 +77,9 @@ class Ratelimiter:
burst_count: How many actions that can be performed before being limited. burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set. Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action update: Whether to count this check as performing the action
n_actions: The number of times the user wants to do this action. If the user
cannot do all of the actions, the user's action count is not incremented
at all.
_time_now_s: The current time. Optional, defaults to the current time according _time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests. to self.clock. Only used by tests.
@ -124,17 +128,20 @@ class Ratelimiter:
time_delta = time_now_s - time_start time_delta = time_now_s - time_start
performed_count = action_count - time_delta * rate_hz performed_count = action_count - time_delta * rate_hz
if performed_count < 0: if performed_count < 0:
# Allow, reset back to count 1 performed_count = 0
allowed = True
time_start = time_now_s time_start = time_now_s
action_count = 1.0
elif performed_count > burst_count - 1.0: # This check would be easier read as performed_count + n_actions > burst_count,
# but performed_count might be a very precise float (with lots of numbers
# following the point) in which case Python might round it up when adding it to
# n_actions. Writing it this way ensures it doesn't happen.
if performed_count > burst_count - n_actions:
# Deny, we have exceeded our burst count # Deny, we have exceeded our burst count
allowed = False allowed = False
else: else:
# We haven't reached our limit yet # We haven't reached our limit yet
allowed = True allowed = True
action_count += 1.0 action_count = performed_count + n_actions
if update: if update:
self.actions[key] = (action_count, time_start, rate_hz) self.actions[key] = (action_count, time_start, rate_hz)
@ -182,6 +189,7 @@ class Ratelimiter:
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[int] = None, _time_now_s: Optional[int] = None,
): ):
"""Checks if an action can be performed. If not, raises a LimitExceededError """Checks if an action can be performed. If not, raises a LimitExceededError
@ -201,6 +209,9 @@ class Ratelimiter:
burst_count: How many actions that can be performed before being limited. burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set. Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action update: Whether to count this check as performing the action
n_actions: The number of times the user wants to do this action. If the user
cannot do all of the actions, the user's action count is not incremented
at all.
_time_now_s: The current time. Optional, defaults to the current time according _time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests. to self.clock. Only used by tests.
@ -216,6 +227,7 @@ class Ratelimiter:
rate_hz=rate_hz, rate_hz=rate_hz,
burst_count=burst_count, burst_count=burst_count,
update=update, update=update,
n_actions=n_actions,
_time_now_s=time_now_s, _time_now_s=time_now_s,
) )

View File

@ -17,6 +17,7 @@ import logging
import time import time
import unicodedata import unicodedata
import urllib.parse import urllib.parse
from binascii import crc32
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -34,6 +35,7 @@ from typing import (
import attr import attr
import bcrypt import bcrypt
import pymacaroons import pymacaroons
import unpaddedbase64
from twisted.web.server import Request from twisted.web.server import Request
@ -66,6 +68,7 @@ from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import base62_encode
from synapse.util.threepids import canonicalise_email from synapse.util.threepids import canonicalise_email
if TYPE_CHECKING: if TYPE_CHECKING:
@ -808,10 +811,12 @@ class AuthHandler(BaseHandler):
logger.info( logger.info(
"Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
) )
target_user_id_obj = UserID.from_string(puppets_user_id)
else: else:
logger.info( logger.info(
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
) )
target_user_id_obj = UserID.from_string(user_id)
if ( if (
not is_appservice_ghost not is_appservice_ghost
@ -819,7 +824,7 @@ class AuthHandler(BaseHandler):
): ):
await self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id) access_token = self.generate_access_token(target_user_id_obj)
await self.store.add_access_token_to_user( await self.store.add_access_token_to_user(
user_id=user_id, user_id=user_id,
token=access_token, token=access_token,
@ -1192,6 +1197,19 @@ class AuthHandler(BaseHandler):
return None return None
return user_id return user_id
def generate_access_token(self, for_user: UserID) -> str:
"""Generates an opaque string, for use as an access token"""
# we use the following format for access tokens:
# syt_<base64 local part>_<random string>_<base62 crc check>
b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
random_string = stringutils.random_string(20)
base = f"syt_{b64local}_{random_string}"
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}"
async def validate_short_term_login_token( async def validate_short_term_login_token(
self, login_token: str self, login_token: str
) -> LoginTokenAttributes: ) -> LoginTokenAttributes:
@ -1585,10 +1603,7 @@ class MacaroonGenerator:
hs = attr.ib() hs = attr.ib()
def generate_access_token( def generate_guest_access_token(self, user_id: str) -> str:
self, user_id: str, extra_caveats: Optional[List[str]] = None
) -> str:
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different # Include a nonce, to make sure that each login gets a different
@ -1596,8 +1611,7 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat( macaroon.add_first_party_caveat(
"nonce = %s" % (stringutils.random_string_with_symbols(16),) "nonce = %s" % (stringutils.random_string_with_symbols(16),)
) )
for caveat in extra_caveats: macaroon.add_first_party_caveat("guest = true")
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize() return macaroon.serialize()
def generate_short_term_login_token( def generate_short_term_login_token(

View File

@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth from synapse import event_auth
@ -43,14 +44,14 @@ from synapse.events import EventBase
from synapse.events.builder import EventBuilder from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder, json_encoder from synapse.util import json_decoder, json_encoder, log_failure
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -979,9 +980,43 @@ class EventCreationHandler:
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise
await self.action_generator.handle_push_actions_for_event(event, context) # We now persist the event (and update the cache in parallel, since we
# don't want to block on it).
result = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self._persist_event,
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
extra_users=extra_users,
),
run_in_background(
self.cache_joined_hosts_for_event, event, context
).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
await self.cache_joined_hosts_for_event(event, context) return result[0]
async def _persist_event(
self,
requester: Requester,
event: EventBase,
context: EventContext,
ratelimit: bool = True,
extra_users: Optional[List[UserID]] = None,
) -> EventBase:
"""Actually persists the event. Should only be called by
`handle_new_client_event`, and see its docstring for documentation of
the arguments.
"""
await self.action_generator.handle_push_actions_for_event(event, context)
try: try:
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.

View File

@ -722,9 +722,7 @@ class RegistrationHandler(BaseHandler):
) )
if is_guest: if is_guest:
assert valid_until_ms is None assert valid_until_ms is None
access_token = self.macaroon_gen.generate_access_token( access_token = self.macaroon_gen.generate_guest_access_token(user_id)
user_id, ["guest = true"]
)
else: else:
access_token = await self._auth_handler.get_access_token_for_user_id( access_token = await self._auth_handler.get_access_token_for_user_id(
user_id, user_id,

View File

@ -32,7 +32,14 @@ from synapse.api.constants import (
RoomCreationPreset, RoomCreationPreset,
RoomEncryptionAlgorithms, RoomEncryptionAlgorithms,
) )
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.api.errors import (
AuthError,
Codes,
LimitExceededError,
NotFoundError,
StoreError,
SynapseError,
)
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
@ -126,10 +133,6 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules() self.third_party_event_rules = hs.get_third_party_event_rules()
self._invite_burst_count = (
hs.config.ratelimiting.rc_invites_per_room.burst_count
)
async def upgrade_room( async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion self, requester: Requester, old_room_id: str, new_version: RoomVersion
) -> str: ) -> str:
@ -676,8 +679,18 @@ class RoomCreationHandler(BaseHandler):
invite_3pid_list = [] invite_3pid_list = []
invite_list = [] invite_list = []
if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: if invite_list or invite_3pid_list:
raise SynapseError(400, "Cannot invite so many users at once") try:
# If there are invites in the request, see if the ratelimiting settings
# allow that number of invites to be sent from the current user.
await self.room_member_handler.ratelimit_multiple_invites(
requester,
room_id=None,
n_invites=len(invite_list) + len(invite_3pid_list),
update=False,
)
except LimitExceededError:
raise SynapseError(400, "Cannot invite so many users at once")
await self.event_creation_handler.assert_accepted_privacy_policy(requester) await self.event_creation_handler.assert_accepted_privacy_policy(requester)

View File

@ -164,6 +164,31 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async def forget(self, user: UserID, room_id: str) -> None: async def forget(self, user: UserID, room_id: str) -> None:
raise NotImplementedError() raise NotImplementedError()
async def ratelimit_multiple_invites(
self,
requester: Optional[Requester],
room_id: Optional[str],
n_invites: int,
update: bool = True,
):
"""Ratelimit more than one invite sent by the given requester in the given room.
Args:
requester: The requester sending the invites.
room_id: The room the invites are being sent in.
n_invites: The amount of invites to ratelimit for.
update: Whether to update the ratelimiter's cache.
Raises:
LimitExceededError: The requester can't send that many invites in the room.
"""
await self._invites_per_room_limiter.ratelimit(
requester,
room_id,
update=update,
n_actions=n_invites,
)
async def ratelimit_invite( async def ratelimit_invite(
self, self,
requester: Optional[Requester], requester: Optional[Requester],

View File

@ -220,3 +220,23 @@ def strtobool(val: str) -> bool:
return False return False
else: else:
raise ValueError("invalid truth value %r" % (val,)) raise ValueError("invalid truth value %r" % (val,))
_BASE62 = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
def base62_encode(num: int, minwidth: int = 1) -> str:
"""Encode a number using base62
Args:
num: number to be encoded
minwidth: width to pad to, if the number is small
"""
res = ""
while num:
num, rem = divmod(num, 62)
res = _BASE62[rem] + res
# pad to minimum width
pad = "0" * (minwidth - len(res))
return pad + res

View File

@ -21,13 +21,11 @@ from synapse.api.constants import UserTypes
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
InvalidClientCredentialsError,
InvalidClientTokenError, InvalidClientTokenError,
MissingClientTokenError, MissingClientTokenError,
ResourceLimitError, ResourceLimitError,
) )
from synapse.storage.databases.main.registration import TokenLookupResult from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest from tests import unittest
from tests.test_utils import simple_async_mock from tests.test_utils import simple_async_mock
@ -253,67 +251,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertTrue(user_info.is_guest) self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id) self.store.get_user_by_id.assert_called_with(user_id)
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
self.store.add_access_token_to_user = simple_async_mock(None)
self.store.get_device = simple_async_mock(None)
token = self.get_success(
self.hs.get_auth_handler().get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
)
self.store.add_access_token_to_user.assert_called_with(
user_id=USER_ID,
token=token,
device_id="DEVICE",
valid_until_ms=None,
puppets_user_id=None,
)
async def get_user(tok):
if token != tok:
return None
return TokenLookupResult(
user_id=USER_ID,
is_guest=False,
token_id=1234,
device_id="DEVICE",
)
self.store.get_user_by_access_token = get_user
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
# check the token works
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest)
# add an is_guest caveat
mac = pymacaroons.Macaroon.deserialize(token)
mac.add_first_party_caveat("guest = true")
guest_tok = mac.serialize()
# the token should *not* work now
request = Mock(args={})
request.args[b"access_token"] = [guest_tok.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
cm = self.get_failure(
self.auth.get_user_by_req(request, allow_guest=True),
InvalidClientCredentialsError,
)
self.assertEqual(401, cm.value.code)
self.assertEqual("Guest access token used for regular user", cm.value.msg)
self.store.get_user_by_id.assert_called_with(USER_ID)
def test_blocking_mau(self): def test_blocking_mau(self):
self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50

View File

@ -230,3 +230,60 @@ class TestRatelimiter(unittest.HomeserverTestCase):
# Shouldn't raise # Shouldn't raise
for _ in range(20): for _ in range(20):
self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
def test_multiple_actions(self):
limiter = Ratelimiter(
store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3
)
# Test that 4 actions aren't allowed with a maximum burst of 3.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", n_actions=4, _time_now_s=0)
)
self.assertFalse(allowed)
# Test that 3 actions are allowed with a maximum burst of 3.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
# Test that, after doing these 3 actions, we can't do any more action without
# waiting.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
# Test that after waiting we can do only 1 action.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(
None,
key="test_id",
update=False,
n_actions=1,
_time_now_s=10,
)
)
self.assertTrue(allowed)
# The time allowed is the current time because we could still repeat the action
# once.
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10)
)
self.assertFalse(allowed)
# The time allowed doesn't change despite allowed being False because, while we
# don't allow 2 actions, we could still do 1.
self.assertEquals(10.0, time_allowed)
# Test that after waiting a bit more we can do 2 actions.
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=20)
)
self.assertTrue(allowed)
# The time allowed is the current time because we could still repeat the action
# once.
self.assertEquals(20.0, time_allowed)

View File

@ -16,12 +16,17 @@ from unittest.mock import Mock
import pymacaroons import pymacaroons
from synapse.api.errors import AuthError, ResourceLimitError from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase): class AuthTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
]
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.macaroon_generator = hs.get_macaroon_generator() self.macaroon_generator = hs.get_macaroon_generator()
@ -35,16 +40,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.small_number_of_users = 1 self.small_number_of_users = 1
self.large_number_of_users = 100 self.large_number_of_users = 100
def test_token_is_a_macaroon(self): self.user1 = self.register_user("a_user", "pass")
token = self.macaroon_generator.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks
if "some_user" not in macaroon.inspect():
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self): def test_macaroon_caveats(self):
token = self.macaroon_generator.generate_access_token("a_user") token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat): def verify_gen(caveat):
@ -59,19 +58,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
def verify_nonce(caveat): def verify_nonce(caveat):
return caveat.startswith("nonce =") return caveat.startswith("nonce =")
def verify_guest(caveat):
return caveat == "guest = true"
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_general(verify_gen) v.satisfy_general(verify_gen)
v.satisfy_general(verify_user) v.satisfy_general(verify_user)
v.satisfy_general(verify_type) v.satisfy_general(verify_type)
v.satisfy_general(verify_nonce) v.satisfy_general(verify_nonce)
v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", "", 5000 self.user1, "", 5000
) )
res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual("a_user", res.user_id) self.assertEqual(self.user1, res.user_id)
self.assertEqual("", res.auth_provider_id) self.assertEqual("", res.auth_provider_id)
# when we advance the clock, the token should be rejected # when we advance the clock, the token should be rejected
@ -83,22 +86,22 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_short_term_login_token_gives_auth_provider(self): def test_short_term_login_token_gives_auth_provider(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", auth_provider_id="my_idp" self.user1, auth_provider_id="my_idp"
) )
res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual("a_user", res.user_id) self.assertEqual(self.user1, res.user_id)
self.assertEqual("my_idp", res.auth_provider_id) self.assertEqual("my_idp", res.auth_provider_id)
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", "", 5000 self.user1, "", 5000
) )
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
res = self.get_success( res = self.get_success(
self.auth_handler.validate_short_term_login_token(macaroon.serialize()) self.auth_handler.validate_short_term_login_token(macaroon.serialize())
) )
self.assertEqual("a_user", res.user_id) self.assertEqual(self.user1, res.user_id)
# add another "user_id" caveat, which might allow us to override the # add another "user_id" caveat, which might allow us to override the
# user_id. # user_id.
@ -114,7 +117,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# Ensure does not throw exception # Ensure does not throw exception
self.get_success( self.get_success(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None self.user1, device_id=None, valid_until_ms=None
) )
) )
@ -132,7 +135,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.get_failure( self.get_failure(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None self.user1, device_id=None, valid_until_ms=None
), ),
ResourceLimitError, ResourceLimitError,
) )
@ -160,7 +163,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# If not in monthly active cohort # If not in monthly active cohort
self.get_failure( self.get_failure(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None self.user1, device_id=None, valid_until_ms=None
), ),
ResourceLimitError, ResourceLimitError,
) )
@ -177,7 +180,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
self.get_success( self.get_success(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None self.user1, device_id=None, valid_until_ms=None
) )
) )
self.get_success( self.get_success(
@ -195,7 +198,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# Ensure does not raise exception # Ensure does not raise exception
self.get_success( self.get_success(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None self.user1, device_id=None, valid_until_ms=None
) )
) )
@ -210,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
def _get_macaroon(self): def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"user_a", "", 5000 self.user1, "", 5000
) )
return pymacaroons.Macaroon.deserialize(token) return pymacaroons.Macaroon.deserialize(token)

View File

@ -48,10 +48,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.mock_distributor = Mock() self.mock_distributor = Mock()
self.mock_distributor.declare("registered_user") self.mock_distributor.declare("registered_user")
self.mock_captcha_client = Mock() self.mock_captcha_client = Mock()
self.macaroon_generator = Mock(
generate_access_token=Mock(return_value="secret")
)
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.handler = self.hs.get_registration_handler() self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.lots_of_users = 100 self.lots_of_users = 100
@ -67,8 +63,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_or_create_user(requester, frank.localpart, "Frankie") self.get_or_create_user(requester, frank.localpart, "Frankie")
) )
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None) self.assertIsInstance(result_token, str)
self.assertEquals(result_token, "secret") self.assertGreater(len(result_token), 20)
def test_if_user_exists(self): def test_if_user_exists(self):
store = self.hs.get_datastore() store = self.hs.get_datastore()
@ -500,7 +496,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="user")) user_id = self.get_success(self.handler.register_user(localpart="user"))
# Get an access token. # Get an access token.
token = self.macaroon_generator.generate_access_token(user_id) token = "testtok"
self.get_success( self.get_success(
self.store.add_access_token_to_user( self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None user_id=user_id, token=token, device_id=None, valid_until_ms=None
@ -577,7 +573,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self.macaroon_generator.generate_access_token(user_id) token = self.hs.get_auth_handler().generate_access_token(user)
if need_register: if need_register:
await self.handler.register_with_store( await self.handler.register_with_store(

View File

@ -463,6 +463,43 @@ class RoomsCreateTestCase(RoomBase):
) )
self.assertEquals(400, channel.code) self.assertEquals(400, channel.code)
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
def test_post_room_invitees_ratelimit(self):
"""Test that invites sent when creating a room are ratelimited by a RateLimiter,
which ratelimits them correctly, including by not limiting when the requester is
exempt from ratelimiting.
"""
# Build the request's content. We use local MXIDs because invites over federation
# are more difficult to mock.
content = json.dumps(
{
"invite": [
"@alice1:red",
"@alice2:red",
"@alice3:red",
"@alice4:red",
]
}
).encode("utf8")
# Test that the invites are correctly ratelimited.
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(400, channel.code)
self.assertEqual(
"Cannot invite so many users at once",
channel.json_body["error"],
)
# Add the current user to the ratelimit overrides, allowing them no ratelimiting.
self.get_success(
self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0)
)
# Test that the invites aren't ratelimited anymore.
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
class RoomTopicTestCase(RoomBase): class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """ """ Tests /rooms/$room_id/topic REST events. """

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util.stringutils import assert_valid_client_secret from synapse.util.stringutils import assert_valid_client_secret, base62_encode
from .. import unittest from .. import unittest
@ -45,3 +45,9 @@ class StringUtilsTestCase(unittest.TestCase):
for client_secret in bad: for client_secret in bad:
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
assert_valid_client_secret(client_secret) assert_valid_client_secret(client_secret)
def test_base62_encode(self):
self.assertEqual("0", base62_encode(0))
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))
self.assertEqual("001c", base62_encode(100, minwidth=4))