Merge pull request #3633 from matrix-org/neilj/mau_tracker

API for monthly_active_users table
pull/3669/head
Neil Johnson 2018-08-08 12:44:15 +00:00 committed by GitHub
commit 990fe9fc23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 414 additions and 163 deletions

1
changelog.d/3633.feature Normal file
View File

@ -0,0 +1 @@
Add ability to limit number of monthly active users on the server

View File

@ -213,7 +213,7 @@ class Auth(object):
default=[b""] default=[b""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id=user.to_string(), user_id=user.to_string(),
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
@ -773,3 +773,15 @@ class Auth(object):
raise AuthError( raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
) )
@defer.inlineCallbacks
def check_auth_blocking(self):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
)

View File

@ -518,12 +518,15 @@ def run(hs):
# If you increase the loop period, the accuracy of user_daily_visits # If you increase the loop period, the accuracy of user_daily_visits
# table will decrease # table will decrease
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
clock.looping_call(
hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
)
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_monthly_active_users(): def generate_monthly_active_users():
count = 0 count = 0
if hs.config.limit_usage_by_mau: if hs.config.limit_usage_by_mau:
count = yield hs.get_datastore().count_monthly_users() count = yield hs.get_datastore().get_monthly_active_count()
current_mau_gauge.set(float(count)) current_mau_gauge.set(float(count))
max_mau_value_gauge.set(float(hs.config.max_mau_value)) max_mau_value_gauge.set(float(hs.config.max_mau_value))

View File

@ -69,12 +69,12 @@ class ServerConfig(Config):
# Options to control access by tracking MAU # Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
if self.limit_usage_by_mau: if self.limit_usage_by_mau:
self.max_mau_value = config.get( self.max_mau_value = config.get(
"max_mau_value", 0, "max_mau_value", 0,
) )
else:
self.max_mau_value = 0
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None self.federation_domain_whitelist = None
federation_domain_whitelist = config.get( federation_domain_whitelist = config.get(

View File

@ -520,7 +520,7 @@ class AuthHandler(BaseHandler):
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -734,7 +734,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
user_id = None user_id = None
try: try:
@ -907,19 +907,6 @@ class AuthHandler(BaseHandler):
else: else:
return defer.succeed(False) return defer.succeed(False)
@defer.inlineCallbacks
def _check_mau_limits(self):
"""
Ensure that if mau blocking is enabled that invalid users cannot
log in.
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
)
@attr.s @attr.s
class MacaroonGenerator(object): class MacaroonGenerator(object):

View File

@ -540,9 +540,7 @@ class RegistrationHandler(BaseHandler):
Do not accept registrations if monthly active user limits exceeded Do not accept registrations if monthly active user limits exceeded
and limiting is enabled and limiting is enabled
""" """
if self.hs.config.limit_usage_by_mau is True: try:
current_mau = yield self.store.count_monthly_users() yield self.auth.check_auth_blocking()
if current_mau >= self.hs.config.max_mau_value: except AuthError as e:
raise RegistrationError( raise RegistrationError(e.code, str(e), e.errcode)
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
)

View File

@ -39,6 +39,7 @@ from .filtering import FilteringStore
from .group_server import GroupServerStore from .group_server import GroupServerStore
from .keys import KeyStore from .keys import KeyStore
from .media_repository import MediaRepositoryStore from .media_repository import MediaRepositoryStore
from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore from .profile import ProfileStore
@ -87,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore,
UserDirectoryStore, UserDirectoryStore,
GroupServerStore, GroupServerStore,
UserErasureStore, UserErasureStore,
MonthlyActiveUsersStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -94,7 +96,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self.db_conn = db_conn
self._stream_id_gen = StreamIdGenerator( self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")] extra_tables=[("local_invites", "stream_id")]
@ -267,31 +268,6 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
def count_monthly_users(self):
"""Counts the number of users who used this homeserver in the last 30 days
This method should be refactored with count_daily_users - the only
reason not to is waiting on definition of mau
Returns:
Defered[int]
"""
def _count_monthly_users(txn):
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
txn.execute(sql, (thirty_days_ago,))
count, = txn.fetchone()
return count
return self.runInteraction("count_monthly_users", _count_monthly_users)
def count_r30_users(self): def count_r30_users(self):
""" """
Counts the number of 30 day retained users, defined as:- Counts the number of 30 day retained users, defined as:-

View File

@ -35,6 +35,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore): class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", name="client_ip_last_seen",
keylen=4, keylen=4,
@ -74,6 +75,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch "before", "shutdown", self._update_client_ips_batch
) )
@defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None): now=None):
if not now: if not now:
@ -84,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
yield self.populate_monthly_active_users(user_id)
# Rate-limited inserts # Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return return

View File

@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
# Number of msec of granularity to store the monthly_active_user timestamp
# This means it is not necessary to update the table on every request
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersStore(SQLBaseStore):
def __init__(self, dbconn, hs):
super(MonthlyActiveUsersStore, self).__init__(None, hs)
self._clock = hs.get_clock()
self.hs = hs
def reap_monthly_active_users(self):
"""
Cleans out monthly active user table to ensure that no stale
entries exist.
Returns:
Deferred[]
"""
def _reap_users(txn):
thirty_days_ago = (
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
)
# Purge stale users
sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
txn.execute(sql, (thirty_days_ago,))
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
# Note it is not possible to write this query using OFFSET due to
# incompatibilities in how sqlite and postgres support the feature.
# sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
# While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
ORDER BY timestamp DESC
LIMIT ?
)
"""
txn.execute(sql, (self.hs.config.max_mau_value,))
yield self.runInteraction("reap_monthly_active_users", _reap_users)
# It seems poor to invalidate the whole cache, Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
# I would need to SELECT and the DELETE which without locking
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
self._user_last_seen_monthly_active.invalidate_all()
self.get_monthly_active_count.invalidate_all()
@cached(num_args=0)
def get_monthly_active_count(self):
"""Generates current count of monthly active users
Returns:
Defered[int]: Number of current monthly active users
"""
def _count_users(txn):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
count, = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
def upsert_monthly_active_user(self, user_id):
"""
Updates or inserts monthly active user member
Arguments:
user_id (str): user to add/update
Deferred[bool]: True if a new entry was created, False if an
existing one was updated.
"""
is_insert = self._simple_upsert(
desc="upsert_monthly_active_user",
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
values={
"timestamp": int(self._clock.time_msec()),
},
lock=False,
)
if is_insert:
self._user_last_seen_monthly_active.invalidate((user_id,))
self.get_monthly_active_count.invalidate(())
@cached(num_args=1)
def _user_last_seen_monthly_active(self, user_id):
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
int : timestamp since last seen, None if never seen
"""
return(self._simple_select_one_onecol(
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
retcol="timestamp",
allow_none=True,
desc="_user_last_seen_monthly_active",
))
@defer.inlineCallbacks
def populate_monthly_active_users(self, user_id):
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
Args:
user_id(str): the user_id to query
"""
if self.hs.config.limit_usage_by_mau:
last_seen_timestamp = yield self._user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec()
# We want to reduce to the total number of db writes, and are happy
# to trade accuracy of timestamp in order to lighten load. This means
# We always insert new users (where MAU threshold has not been reached),
# but only update if we have not previously seen the user for
# LAST_SEEN_GRANULARITY ms
if last_seen_timestamp is None:
count = yield self.get_monthly_active_count()
if count < self.hs.config.max_mau_value:
yield self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
yield self.upsert_monthly_active_user(user_id)

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 50 SCHEMA_VERSION = 51
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -0,0 +1,27 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- a table of monthly active users, for use where blocking based on mau limits
CREATE TABLE monthly_active_users (
user_id TEXT NOT NULL,
-- Last time we saw the user. Not guaranteed to be accurate due to rate limiting
-- on updates, Granularity of updates governed by
-- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY
-- Measured in ms since epoch.
timestamp BIGINT NOT NULL
);
CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id);
CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp);

View File

@ -444,3 +444,28 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("Guest access token used for regular user", cm.exception.msg) self.assertEqual("Guest access token used for regular user", cm.exception.msg)
self.store.get_user_by_id.assert_called_with(USER_ID) self.store.get_user_by_id.assert_called_with(USER_ID)
@defer.inlineCallbacks
def test_blocking_mau(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
lots_of_users = 100
small_number_of_users = 1
# Ensure no error thrown
yield self.auth.check_auth_blocking()
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
with self.assertRaises(AuthError):
yield self.auth.check_auth_blocking()
# Ensure does not throw an error
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
yield self.auth.check_auth_blocking()

View File

@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_exceeded(self): def test_mau_limits_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_not_exceeded(self): def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.small_number_of_users)
) )
# Ensure does not raise exception # Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.small_number_of_users)
) )
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(

View File

@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase):
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
self.store = self.hs.get_datastore()
self.hs.config.max_mau_value = 50
self.lots_of_users = 100
self.small_number_of_users = 1
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):
@ -80,51 +84,43 @@ class RegistrationTestCase(unittest.TestCase):
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cannot_register_when_mau_limits_exceeded(self): def test_mau_limits_when_disabled(self):
local_part = "someone"
display_name = "someone"
requester = create_requester("@as:test")
store = self.hs.get_datastore()
self.hs.config.limit_usage_by_mau = False self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
lots_of_users = 100
small_number_users = 1
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user(requester, 'a', display_name) yield self.handler.get_or_create_user("requester", 'a', "display_name")
@defer.inlineCallbacks
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
with self.assertRaises(RegistrationError): return_value=defer.succeed(self.small_number_of_users)
yield self.handler.get_or_create_user(requester, 'b', display_name) )
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
self._macaroon_mock_generator("another_secret")
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") yield self.handler.get_or_create_user("@user:server", 'c', "User")
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
@defer.inlineCallbacks
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users)
)
with self.assertRaises(RegistrationError): with self.assertRaises(RegistrationError):
yield self.handler.register(localpart=local_part) yield self.handler.get_or_create_user("requester", 'b', "display_name")
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
@defer.inlineCallbacks
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users)
)
with self.assertRaises(RegistrationError): with self.assertRaises(RegistrationError):
yield self.handler.register_saml2(local_part) yield self.handler.register(localpart="local_part")
def _macaroon_mock_generator(self, secret): @defer.inlineCallbacks
""" def test_register_saml2_mau_blocked(self):
Reset macaroon generator in the case where the test creates multiple users self.hs.config.limit_usage_by_mau = True
""" self.store.get_monthly_active_count = Mock(
macaroon_generator = Mock( return_value=defer.succeed(self.lots_of_users)
generate_access_token=Mock(return_value=secret)) )
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator) with self.assertRaises(RegistrationError):
self.hs.handlers = RegistrationHandlers(self.hs) yield self.handler.register_saml2(localpart="local_part")
self.handler = self.hs.get_handlers().registration_handler

View File

@ -1,65 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.utils
class InitTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(InitTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
hs.config.max_mau_value = 50
hs.config.limit_usage_by_mau = True
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_count_monthly_users(self):
count = yield self.store.count_monthly_users()
self.assertEqual(0, count)
yield self._insert_user_ips("@user:server1")
yield self._insert_user_ips("@user:server2")
count = yield self.store.count_monthly_users()
self.assertEqual(2, count)
@defer.inlineCallbacks
def _insert_user_ips(self, user):
"""
Helper function to populate user_ips without using batch insertion infra
args:
user (str): specify username i.e. @user:server.com
"""
yield self.store._simple_upsert(
table="user_ips",
keyvalues={
"user_id": user,
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": "device_id",
},
values={
"last_seen": self.clock.time_msec(),
}
)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -27,9 +28,9 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield tests.utils.setup_test_homeserver() self.hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = hs.get_clock() self.clock = self.hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_insert_new_client_ip(self): def test_insert_new_client_ip(self):
@ -54,3 +55,62 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
}, },
r r
) )
@defer.inlineCallbacks
def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
user_id = "@user:server"
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active)
@defer.inlineCallbacks
def test_adding_monthly_active_user_when_full(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
lots_of_users = 100
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active)
@defer.inlineCallbacks
def test_adding_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertTrue(active)
@defer.inlineCallbacks
def test_updating_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id",
)
active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertTrue(active)

View File

@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.unittest
import tests.utils
from tests.utils import setup_test_homeserver
class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver()
self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_can_insert_and_count_mau(self):
count = yield self.store.get_monthly_active_count()
self.assertEqual(0, count)
yield self.store.upsert_monthly_active_user("@user:server")
count = yield self.store.get_monthly_active_count()
self.assertEqual(1, count)
@defer.inlineCallbacks
def test__user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
result = yield self.store._user_last_seen_monthly_active(user_id1)
self.assertFalse(result == 0)
yield self.store.upsert_monthly_active_user(user_id1)
yield self.store.upsert_monthly_active_user(user_id2)
result = yield self.store._user_last_seen_monthly_active(user_id1)
self.assertTrue(result > 0)
result = yield self.store._user_last_seen_monthly_active(user_id3)
self.assertFalse(result == 0)
@defer.inlineCallbacks
def test_reap_monthly_active_users(self):
self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
yield self.store.upsert_monthly_active_user("@user%d:server" % i)
count = yield self.store.get_monthly_active_count()
self.assertTrue(count, initial_users)
yield self.store.reap_monthly_active_users()
count = yield self.store.get_monthly_active_count()
self.assertTrue(count, initial_users - self.hs.config.max_mau_value)

View File

@ -73,6 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
config.block_events_without_consent_error = None config.block_events_without_consent_error = None
config.media_storage_providers = [] config.media_storage_providers = []
config.auto_join_rooms = [] config.auto_join_rooms = []
config.limit_usage_by_mau = False
# disable user directory updates, because they get done in the # disable user directory updates, because they get done in the
# background, which upsets the test runner. # background, which upsets the test runner.