Merge pull request #3670 from matrix-org/neilj/mau_sync_block

Block ability to read via sync if mau limit exceeded
pull/3692/head
Neil Johnson 2018-08-14 15:21:31 +00:00 committed by GitHub
commit 414d54b61a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 112 additions and 20 deletions

1
changelog.d/3670.feature Normal file
View File

@ -0,0 +1 @@
Where server is disabled, block ability for locked out users to read new messages

View File

@ -775,15 +775,25 @@ class Auth(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth_blocking(self): def check_auth_blocking(self, user_id=None):
"""Checks if the user should be rejected for some external reason, """Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
MAU cohort
""" """
if self.hs.config.hs_disabled: if self.hs.config.hs_disabled:
raise AuthError( raise AuthError(
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED
) )
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
# If the user is already part of the MAU cohort
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise AuthError( raise AuthError(

View File

@ -191,6 +191,7 @@ class SyncHandler(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs, "sync") self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.auth = hs.get_auth()
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id) # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache( self.lazy_loaded_members_cache = ExpiringCache(
@ -198,19 +199,27 @@ class SyncHandler(object):
max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
) )
@defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False): full_state=False):
"""Get the sync for a client if we have new data for it now. Otherwise """Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result. return an empty sync result.
Returns: Returns:
A Deferred SyncResult. Deferred[SyncResult]
""" """
return self.response_cache.wrap( # If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
yield self.auth.check_auth_blocking(user_id)
res = yield self.response_cache.wrap(
sync_config.request_key, sync_config.request_key,
self._wait_for_sync_for_user, self._wait_for_sync_for_user,
sync_config, since_token, timeout, full_state, sync_config, since_token, timeout, full_state,
) )
defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout, def _wait_for_sync_for_user(self, sync_config, since_token, timeout,

View File

@ -125,7 +125,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# is racy. # is racy.
# Have resolved to invalidate the whole cache for now and do # Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant # something about it if and when the perf becomes significant
self._user_last_seen_monthly_active.invalidate_all() self.user_last_seen_monthly_active.invalidate_all()
self.get_monthly_active_count.invalidate_all() self.get_monthly_active_count.invalidate_all()
@cached(num_args=0) @cached(num_args=0)
@ -164,11 +164,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
lock=False, lock=False,
) )
if is_insert: if is_insert:
self._user_last_seen_monthly_active.invalidate((user_id,)) self.user_last_seen_monthly_active.invalidate((user_id,))
self.get_monthly_active_count.invalidate(()) self.get_monthly_active_count.invalidate(())
@cached(num_args=1) @cached(num_args=1)
def _user_last_seen_monthly_active(self, user_id): def user_last_seen_monthly_active(self, user_id):
""" """
Checks if a given user is part of the monthly active user group Checks if a given user is part of the monthly active user group
Arguments: Arguments:
@ -185,7 +185,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
}, },
retcol="timestamp", retcol="timestamp",
allow_none=True, allow_none=True,
desc="_user_last_seen_monthly_active", desc="user_last_seen_monthly_active",
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -197,7 +197,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
user_id(str): the user_id to query user_id(str): the user_id to query
""" """
if self.hs.config.limit_usage_by_mau: if self.hs.config.limit_usage_by_mau:
last_seen_timestamp = yield self._user_last_seen_monthly_active(user_id) last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
# We want to reduce to the total number of db writes, and are happy # We want to reduce to the total number of db writes, and are happy

View File

@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler
from synapse.types import UserID
import tests.unittest
import tests.utils
from tests.utils import setup_test_homeserver
class SyncTestCase(tests.unittest.TestCase):
""" Tests Sync Handler. """
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver()
self.sync_handler = SyncHandler(self.hs)
self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
sync_config = self._generate_sync_config(user_id1)
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors
yield self.store.upsert_monthly_active_user(user_id1)
yield self.sync_handler.wait_for_sync_for_user(sync_config)
# Test that global lock works
self.hs.config.hs_disabled = True
with self.assertRaises(AuthError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
with self.assertRaises(AuthError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
filter_collection=DEFAULT_FILTER_COLLECTION,
is_guest=False,
request_key="request_key",
device_id="device_id",
)

View File

@ -63,7 +63,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" user_id, "access_token", "ip", "user_agent", "device_id"
) )
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store.user_last_seen_monthly_active(user_id)
self.assertFalse(active) self.assertFalse(active)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -79,7 +79,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" user_id, "access_token", "ip", "user_agent", "device_id"
) )
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store.user_last_seen_monthly_active(user_id)
self.assertFalse(active) self.assertFalse(active)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -87,13 +87,13 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50 self.hs.config.max_mau_value = 50
user_id = "@user:server" user_id = "@user:server"
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store.user_last_seen_monthly_active(user_id)
self.assertFalse(active) self.assertFalse(active)
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" user_id, "access_token", "ip", "user_agent", "device_id"
) )
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store.user_last_seen_monthly_active(user_id)
self.assertTrue(active) self.assertTrue(active)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -102,7 +102,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.hs.config.max_mau_value = 50 self.hs.config.max_mau_value = 50
user_id = "@user:server" user_id = "@user:server"
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store.user_last_seen_monthly_active(user_id)
self.assertFalse(active) self.assertFalse(active)
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
@ -111,5 +111,5 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" user_id, "access_token", "ip", "user_agent", "device_id"
) )
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store.user_last_seen_monthly_active(user_id)
self.assertTrue(active) self.assertTrue(active)

View File

@ -60,9 +60,9 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
# Test user is marked as active # Test user is marked as active
timestamp = yield self.store._user_last_seen_monthly_active(user1) timestamp = yield self.store.user_last_seen_monthly_active(user1)
self.assertTrue(timestamp) self.assertTrue(timestamp)
timestamp = yield self.store._user_last_seen_monthly_active(user2) timestamp = yield self.store.user_last_seen_monthly_active(user2)
self.assertTrue(timestamp) self.assertTrue(timestamp)
# Test that users are never removed from the db. # Test that users are never removed from the db.
@ -86,17 +86,18 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
self.assertEqual(1, count) self.assertEqual(1, count)
@defer.inlineCallbacks @defer.inlineCallbacks
def test__user_last_seen_monthly_active(self): def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server" user_id1 = "@user1:server"
user_id2 = "@user2:server" user_id2 = "@user2:server"
user_id3 = "@user3:server" user_id3 = "@user3:server"
result = yield self.store._user_last_seen_monthly_active(user_id1)
result = yield self.store.user_last_seen_monthly_active(user_id1)
self.assertFalse(result == 0) self.assertFalse(result == 0)
yield self.store.upsert_monthly_active_user(user_id1) yield self.store.upsert_monthly_active_user(user_id1)
yield self.store.upsert_monthly_active_user(user_id2) yield self.store.upsert_monthly_active_user(user_id2)
result = yield self.store._user_last_seen_monthly_active(user_id1) result = yield self.store.user_last_seen_monthly_active(user_id1)
self.assertTrue(result > 0) self.assertTrue(result > 0)
result = yield self.store._user_last_seen_monthly_active(user_id3) result = yield self.store.user_last_seen_monthly_active(user_id3)
self.assertFalse(result == 0) self.assertFalse(result == 0)
@defer.inlineCallbacks @defer.inlineCallbacks