Some more porting to HomeserverTestCase and remove old RESTHelper (#4913)

pull/4923/head
Amber Brown 2019-03-22 02:10:21 +11:00 committed by GitHub
parent 7bef97dfb7
commit a68e00fca8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 124 additions and 230 deletions

1
changelog.d/4913.misc Normal file
View File

@ -0,0 +1 @@
Refactor some more tests to use HomeserverTestCase.

View File

@ -22,8 +22,6 @@ from synapse.api.errors import ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester from synapse.types import RoomAlias, UserID, create_requester
from tests.utils import default_config, setup_test_homeserver
from .. import unittest from .. import unittest
@ -32,26 +30,23 @@ class RegistrationHandlers(object):
self.registration_handler = RegistrationHandler(hs) self.registration_handler = RegistrationHandler(hs)
class RegistrationTestCase(unittest.TestCase): class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """ """ Tests the RegistrationHandler. """
@defer.inlineCallbacks def make_homeserver(self, reactor, clock):
def setUp(self): hs_config = self.default_config("test")
self.mock_distributor = Mock()
self.mock_distributor.declare("registered_user")
self.mock_captcha_client = Mock()
hs_config = default_config("test")
# some of the tests rely on us having a user consent version # some of the tests rely on us having a user consent version
hs_config.user_consent_version = "test_consent_version" hs_config.user_consent_version = "test_consent_version"
hs_config.max_mau_value = 50 hs_config.max_mau_value = 50
self.hs = yield setup_test_homeserver( hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
self.addCleanup, return hs
config=hs_config,
expire_access_token=True, def prepare(self, reactor, clock, hs):
) self.mock_distributor = Mock()
self.mock_distributor.declare("registered_user")
self.mock_captcha_client = Mock()
self.macaroon_generator = Mock( self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret') generate_access_token=Mock(return_value='secret')
) )
@ -63,136 +58,133 @@ class RegistrationTestCase(unittest.TestCase):
self.requester = create_requester("@requester:test") self.requester = create_requester("@requester:test")
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):
frank = UserID.from_string("@frank:test") frank = UserID.from_string("@frank:test")
user_id = frank.to_string() user_id = frank.to_string()
requester = create_requester(user_id) requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = self.get_success(
requester, frank.localpart, "Frankie" self.handler.get_or_create_user(requester, frank.localpart, "Frankie")
) )
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None) self.assertTrue(result_token is not None)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks
def test_if_user_exists(self): def test_if_user_exists(self):
store = self.hs.get_datastore() store = self.hs.get_datastore()
frank = UserID.from_string("@frank:test") frank = UserID.from_string("@frank:test")
yield store.register( self.get_success(
user_id=frank.to_string(), store.register(
token="jkv;g498752-43gj['eamb!-5", user_id=frank.to_string(),
password_hash=None, token="jkv;g498752-43gj['eamb!-5",
password_hash=None,
)
) )
local_part = frank.localpart local_part = frank.localpart
user_id = frank.to_string() user_id = frank.to_string()
requester = create_requester(user_id) requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = self.get_success(
requester, local_part, None self.handler.get_or_create_user(requester, local_part, None)
) )
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None) self.assertTrue(result_token is not None)
@defer.inlineCallbacks
def test_mau_limits_when_disabled(self): def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user(self.requester, 'a', "display_name") self.get_success(
self.handler.get_or_create_user(self.requester, 'a', "display_name")
)
@defer.inlineCallbacks
def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock( self.store.count_monthly_users = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value - 1) return_value=defer.succeed(self.hs.config.max_mau_value - 1)
) )
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user(self.requester, 'c', "User") self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User"))
@defer.inlineCallbacks
def test_get_or_create_user_mau_blocked(self): def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(ResourceLimitError): self.get_failure(
yield self.handler.get_or_create_user(self.requester, 'b', "display_name") self.handler.get_or_create_user(self.requester, 'b', "display_name"),
ResourceLimitError,
)
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(ResourceLimitError): self.get_failure(
yield self.handler.get_or_create_user(self.requester, 'b', "display_name") self.handler.get_or_create_user(self.requester, 'b', "display_name"),
ResourceLimitError,
)
@defer.inlineCallbacks
def test_register_mau_blocked(self): def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(ResourceLimitError): self.get_failure(
yield self.handler.register(localpart="local_part") self.handler.register(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(ResourceLimitError): self.get_failure(
yield self.handler.register(localpart="local_part") self.handler.register(localpart="local_part"), ResourceLimitError
)
@defer.inlineCallbacks
def test_auto_create_auto_join_rooms(self): def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
res = yield self.handler.register(localpart='jeff') res = self.get_success(self.handler.register(localpart='jeff'))
rooms = yield self.store.get_rooms_for_user(res[0]) rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str) room_alias = RoomAlias.from_string(room_alias_str)
room_id = yield directory_handler.get_association(room_alias) room_id = self.get_success(directory_handler.get_association(room_alias))
self.assertTrue(room_id['room_id'] in rooms) self.assertTrue(room_id['room_id'] in rooms)
self.assertEqual(len(rooms), 1) self.assertEqual(len(rooms), 1)
@defer.inlineCallbacks
def test_auto_create_auto_join_rooms_with_no_rooms(self): def test_auto_create_auto_join_rooms_with_no_rooms(self):
self.hs.config.auto_join_rooms = [] self.hs.config.auto_join_rooms = []
frank = UserID.from_string("@frank:test") frank = UserID.from_string("@frank:test")
res = yield self.handler.register(frank.localpart) res = self.get_success(self.handler.register(frank.localpart))
self.assertEqual(res[0], frank.to_string()) self.assertEqual(res[0], frank.to_string())
rooms = yield self.store.get_rooms_for_user(res[0]) rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@defer.inlineCallbacks
def test_auto_create_auto_join_where_room_is_another_domain(self): def test_auto_create_auto_join_where_room_is_another_domain(self):
self.hs.config.auto_join_rooms = ["#room:another"] self.hs.config.auto_join_rooms = ["#room:another"]
frank = UserID.from_string("@frank:test") frank = UserID.from_string("@frank:test")
res = yield self.handler.register(frank.localpart) res = self.get_success(self.handler.register(frank.localpart))
self.assertEqual(res[0], frank.to_string()) self.assertEqual(res[0], frank.to_string())
rooms = yield self.store.get_rooms_for_user(res[0]) rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@defer.inlineCallbacks
def test_auto_create_auto_join_where_auto_create_is_false(self): def test_auto_create_auto_join_where_auto_create_is_false(self):
self.hs.config.autocreate_auto_join_rooms = False self.hs.config.autocreate_auto_join_rooms = False
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
res = yield self.handler.register(localpart='jeff') res = self.get_success(self.handler.register(localpart='jeff'))
rooms = yield self.store.get_rooms_for_user(res[0]) rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@defer.inlineCallbacks
def test_auto_create_auto_join_rooms_when_support_user_exists(self): def test_auto_create_auto_join_rooms_when_support_user_exists(self):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_support_user = Mock(return_value=True) self.store.is_support_user = Mock(return_value=True)
res = yield self.handler.register(localpart='support') res = self.get_success(self.handler.register(localpart='support'))
rooms = yield self.store.get_rooms_for_user(res[0]) rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str) room_alias = RoomAlias.from_string(room_alias_str)
with self.assertRaises(SynapseError): self.get_failure(directory_handler.get_association(room_alias), SynapseError)
yield directory_handler.get_association(room_alias)
@defer.inlineCallbacks
def test_auto_create_auto_join_where_no_consent(self): def test_auto_create_auto_join_where_no_consent(self):
"""Test to ensure that the first user is not auto-joined to a room if """Test to ensure that the first user is not auto-joined to a room if
they have not given general consent. they have not given general consent.
@ -208,27 +200,27 @@ class RegistrationTestCase(unittest.TestCase):
# (Messing with the internals of event_creation_handler is fragile # (Messing with the internals of event_creation_handler is fragile
# but can't see a better way to do this. One option could be to subclass # but can't see a better way to do this. One option could be to subclass
# the test with custom config.) # the test with custom config.)
event_creation_handler._block_events_without_consent_error = ("Error") event_creation_handler._block_events_without_consent_error = "Error"
event_creation_handler._consent_uri_builder = Mock() event_creation_handler._consent_uri_builder = Mock()
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
# When:- # When:-
# * the user is registered and post consent actions are called # * the user is registered and post consent actions are called
res = yield self.handler.register(localpart='jeff') res = self.get_success(self.handler.register(localpart='jeff'))
yield self.handler.post_consent_actions(res[0]) self.get_success(self.handler.post_consent_actions(res[0]))
# Then:- # Then:-
# * Ensure that they have not been joined to the room # * Ensure that they have not been joined to the room
rooms = yield self.store.get_rooms_for_user(res[0]) rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@defer.inlineCallbacks
def test_register_support_user(self): def test_register_support_user(self):
res = yield self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) res = self.get_success(
self.handler.register(localpart='user', user_type=UserTypes.SUPPORT)
)
self.assertTrue(self.store.is_support_user(res[0])) self.assertTrue(self.store.is_support_user(res[0]))
@defer.inlineCallbacks
def test_register_not_support_user(self): def test_register_not_support_user(self):
res = yield self.handler.register(localpart='user') res = self.get_success(self.handler.register(localpart='user'))
self.assertFalse(self.store.is_support_user(res[0])) self.assertFalse(self.store.is_support_user(res[0]))

View File

@ -18,136 +18,11 @@ import time
import attr import attr
from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from tests import unittest
from tests.server import make_request, render from tests.server import make_request, render
class RestTestCase(unittest.TestCase):
"""Contains extra helper functions to quickly and clearly perform a given
REST action, which isn't the focus of the test.
This subclass assumes there are mock_resource and auth_user_id attributes.
"""
def __init__(self, *args, **kwargs):
super(RestTestCase, self).__init__(*args, **kwargs)
self.mock_resource = None
self.auth_user_id = None
@defer.inlineCallbacks
def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/createRoom"
content = "{}"
if not is_public:
content = '{"visibility":"private"}'
if tok:
path = path + "?access_token=%s" % tok
(code, response) = yield self.mock_resource.trigger("POST", path, content)
self.assertEquals(200, code, msg=str(response))
self.auth_user_id = temp_id
defer.returnValue(response["room_id"])
@defer.inlineCallbacks
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
yield self.change_membership(
room=room,
src=src,
targ=targ,
tok=tok,
membership=Membership.INVITE,
expect_code=expect_code,
)
@defer.inlineCallbacks
def join(self, room=None, user=None, expect_code=200, tok=None):
yield self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
membership=Membership.JOIN,
expect_code=expect_code,
)
@defer.inlineCallbacks
def leave(self, room=None, user=None, expect_code=200, tok=None):
yield self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
membership=Membership.LEAVE,
expect_code=expect_code,
)
@defer.inlineCallbacks
def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
temp_id = self.auth_user_id
self.auth_user_id = src
path = "/rooms/%s/state/m.room.member/%s" % (room, targ)
if tok:
path = path + "?access_token=%s" % tok
data = {"membership": membership}
(code, response) = yield self.mock_resource.trigger(
"PUT", path, json.dumps(data)
)
self.assertEquals(
expect_code,
code,
msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response),
)
self.auth_user_id = temp_id
@defer.inlineCallbacks
def register(self, user_id):
(code, response) = yield self.mock_resource.trigger(
"POST",
"/register",
json.dumps(
{"user": user_id, "password": "test", "type": "m.login.password"}
),
)
self.assertEquals(200, code, msg=response)
defer.returnValue(response)
@defer.inlineCallbacks
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
if body is None:
body = "body_text_here"
path = "/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = '{"msgtype":"m.text","body":"%s"}' % body
if tok:
path = path + "?access_token=%s" % tok
(code, response) = yield self.mock_resource.trigger("PUT", path, content)
self.assertEquals(expect_code, code, msg=str(response))
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
Args:
required (dict): The keys and value which MUST be in 'actual'.
actual (dict): The test result. Extra keys will not be checked.
"""
for key in required:
self.assertEquals(
required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
)
@attr.s @attr.s
class RestHelper(object): class RestHelper(object):
"""Contains extra helper functions to quickly and clearly perform a given """Contains extra helper functions to quickly and clearly perform a given

View File

@ -1,3 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock from mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -9,16 +24,18 @@ from synapse.server_notices.resource_limits_server_notices import (
) )
from tests import unittest from tests import unittest
from tests.utils import default_config, setup_test_homeserver
class TestResourceLimitsServerNotices(unittest.TestCase): class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def setUp(self): def make_homeserver(self, reactor, clock):
hs_config = default_config(name="test") hs_config = self.default_config("test")
hs_config.server_notices_mxid = "@server:test" hs_config.server_notices_mxid = "@server:test"
self.hs = yield setup_test_homeserver(self.addCleanup, config=hs_config) hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
return hs
def prepare(self, reactor, clock, hs):
self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_sender = self.hs.get_server_notices_sender()
# relying on [1] is far from ideal, but the only case where # relying on [1] is far from ideal, but the only case where
@ -53,23 +70,21 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
self._rlsn._store.get_tags_for_room = Mock(return_value={}) self._rlsn._store.get_tags_for_room = Mock(return_value={})
self.hs.config.admin_contact = "mailto:user@test.com" self.hs.config.admin_contact = "mailto:user@test.com"
@defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_flag_off(self): def test_maybe_send_server_notice_to_user_flag_off(self):
"""Tests cases where the flags indicate nothing to do""" """Tests cases where the flags indicate nothing to do"""
# test hs disabled case # test hs disabled case
self.hs.config.hs_disabled = True self.hs.config.hs_disabled = True
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
# Test when mau limiting disabled # Test when mau limiting disabled
self.hs.config.hs_disabled = False self.hs.config.hs_disabled = False
self.hs.limit_usage_by_mau = False self.hs.limit_usage_by_mau = False
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed""" """Test when user has blocked notice, but should have it removed"""
@ -81,13 +96,14 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
return_value=defer.succeed({"123": mock_event}) return_value=defer.succeed({"123": mock_event})
) )
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event # Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once() self._send_notice.assert_called_once()
@defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
"""Test when user has blocked notice, but notice ought to be there (NOOP)""" """
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, 'foo') side_effect=ResourceLimitError(403, 'foo')
) )
@ -98,52 +114,49 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
self._rlsn._store.get_events = Mock( self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event}) return_value=defer.succeed({"123": mock_event})
) )
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_add_blocked_notice(self): def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
"""Test when user does not have blocked notice, but should have one""" """
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, 'foo') side_effect=ResourceLimitError(403, 'foo')
) )
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check contents, but 2 calls == set blocking event # Would be better to check contents, but 2 calls == set blocking event
self.assertTrue(self._send_notice.call_count == 2) self.assertTrue(self._send_notice.call_count == 2)
@defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
"""Test when user does not have blocked notice, nor should they (NOOP)""" """
Test when user does not have blocked notice, nor should they (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock()
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@defer.inlineCallbacks
def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
"""
"""Test when user is not part of the MAU cohort - this should not ever Test when user is not part of the MAU cohort - this should not ever
happen - but ... happen - but ...
""" """
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None) return_value=defer.succeed(None)
) )
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager() self.server_notices_manager = self.hs.get_server_notices_manager()
@ -168,26 +181,27 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
self.hs.config.admin_contact = "mailto:user@test.com" self.hs.config.admin_contact = "mailto:user@test.com"
@defer.inlineCallbacks
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000) self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(return_value=1000) self.store.user_last_seen_monthly_active = Mock(return_value=1000)
# Call the function multiple times to ensure we only send the notice once # Call the function multiple times to ensure we only send the notice once
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Now lets get the last load of messages in the service notice room and # Now lets get the last load of messages in the service notice room and
# check that there is only one server notice # check that there is only one server notice
room_id = yield self.server_notices_manager.get_notice_room_for_user( room_id = self.get_success(
self.user_id self.server_notices_manager.get_notice_room_for_user(self.user_id)
) )
token = yield self.event_source.get_current_token() token = self.get_success(self.event_source.get_current_token())
events, _ = yield self.store.get_recent_events_for_room( events, _ = self.get_success(
room_id, limit=100, end_token=token.room_key self.store.get_recent_events_for_room(
room_id, limit=100, end_token=token.room_key
)
) )
count = 0 count = 0

View File

@ -314,6 +314,9 @@ class HomeserverTestCase(TestCase):
""" """
kwargs = dict(kwargs) kwargs = dict(kwargs)
kwargs.update(self._hs_args) kwargs.update(self._hs_args)
if "config" not in kwargs:
config = self.default_config()
kwargs["config"] = config
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore() stor = hs.get_datastore()
@ -336,6 +339,15 @@ class HomeserverTestCase(TestCase):
self.pump(by=by) self.pump(by=by)
return self.successResultOf(d) return self.successResultOf(d)
def get_failure(self, d, exc):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
if not isinstance(d, Deferred):
return d
self.pump()
return self.failureResultOf(d, exc)
def register_user(self, username, password, admin=False): def register_user(self, username, password, admin=False):
""" """
Register a user. Requires the Admin API be registered. Register a user. Requires the Admin API be registered.