Split out profile handler to fix tests

pull/2429/head
Erik Johnston 2017-08-25 14:34:56 +01:00
parent 27ebc5c8f2
commit bf81f3cf2c
11 changed files with 35 additions and 29 deletions

View File

@ -20,7 +20,6 @@ from .room import (
from .room_member import RoomMemberHandler from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .federation import FederationHandler from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .admin import AdminHandler from .admin import AdminHandler
from .identity import IdentityHandler from .identity import IdentityHandler
@ -52,7 +51,6 @@ class Handlers(object):
self.room_creation_handler = RoomCreationHandler(hs) self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs) self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)

View File

@ -56,8 +56,7 @@ class GroupsLocalHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.attestations = hs.get_groups_attestation_signing() self.attestations = hs.get_groups_attestation_signing()
handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
self.profile_handler = handlers.profile_handler
# Ensure attestations get renewed # Ensure attestations get renewed
hs.get_groups_attestation_renewer() hs.get_groups_attestation_renewer()

View File

@ -47,6 +47,7 @@ class MessageHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator() self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
@ -210,7 +211,7 @@ class MessageHandler(BaseHandler):
if membership in {Membership.JOIN, Membership.INVITE}: if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
profile = self.hs.get_handlers().profile_handler profile = self.profile_handler
content = builder.content content = builder.content
try: try:

View File

@ -22,18 +22,21 @@ from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
from ._base import BaseHandler from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ProfileHandler(BaseHandler): class ProfileHandler(object):
PROFILE_UPDATE_MS = 60 * 1000 PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
super(ProfileHandler, self).__init__(hs) self.hs = hs
self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.ratelimiter = hs.get_ratelimiter()
# AWFUL hack to get at BaseHandler.ratelimit
self.base_handler = BaseHandler(hs)
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_query_handler( self.federation.register_query_handler(
@ -194,7 +197,7 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
return return
yield self.ratelimit(requester) yield self.base_handler.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user( room_ids = yield self.store.get_rooms_for_user(
user.to_string(), user.to_string(),

View File

@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.profile_handler = hs.get_profile_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None self._next_generated_user_id = None
@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler):
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler yield self.profile_handler.set_displayname(
yield profile_handler.set_displayname(
user, requester, displayname, by_admin=True, user, requester, displayname, by_admin=True,
) )

View File

@ -45,6 +45,8 @@ class RoomMemberHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs) super(RoomMemberHandler, self).__init__(hs)
self.profile_handler = hs.get_profile_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -255,7 +257,7 @@ class RoomMemberHandler(BaseHandler):
content["membership"] = Membership.JOIN content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler profile = self.profile_handler
if not content_specified: if not content_specified:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = yield profile.get_avatar_url(target)

View File

@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs) super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.profile_handler.get_displayname(
user, user,
) )
@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.profile_handler.set_displayname(
user, requester, new_name, is_admin) user, requester, new_name, is_admin)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs) super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url( avatar_url = yield self.profile_handler.get_avatar_url(
user, user,
) )
@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
except: except:
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url( yield self.profile_handler.set_avatar_url(
user, requester, new_name, is_admin) user, requester, new_name, is_admin)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs) super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.profile_handler.get_displayname(
user, user,
) )
avatar_url = yield self.handlers.profile_handler.get_avatar_url( avatar_url = yield self.profile_handler.get_avatar_url(
user, user,
) )

View File

@ -51,6 +51,7 @@ from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler from synapse.handlers.user_directory import UserDirectoyHandler
from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler
from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.groups_server import GroupsServerHandler
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@ -114,6 +115,7 @@ class HomeServer(object):
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
'device_message_handler', 'device_message_handler',
'profile_handler',
'notifier', 'notifier',
'distributor', 'distributor',
'client_resource', 'client_resource',
@ -258,6 +260,9 @@ class HomeServer(object):
def build_initial_sync_handler(self): def build_initial_sync_handler(self):
return InitialSyncHandler(self) return InitialSyncHandler(self)
def build_profile_handler(self):
return ProfileHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)

View File

@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
hs.handlers = ProfileHandlers(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test") self.frank = UserID.from_string("@1234ABCD:test")
@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile(self.frank.localpart) yield self.store.create_profile(self.frank.localpart)
self.handler = hs.get_handlers().profile_handler self.handler = hs.get_profile_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):

View File

@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver( self.hs = yield setup_test_homeserver(
handlers=None, handlers=None,
http_client=None, http_client=None,
expire_access_token=True) expire_access_token=True,
profile_handler=Mock(),
)
self.macaroon_generator = Mock( self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')) generate_access_token=Mock(return_value='secret'))
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):

View File

@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase):
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
federation=Mock(), federation=Mock(),
replication_layer=Mock(), replication_layer=Mock(),
profile_handler=self.mock_handler
) )
def _get_user_by_req(request=None, allow_guest=False): def _get_user_by_req(request=None, allow_guest=False):
@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase):
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req
hs.get_handlers().profile_handler = self.mock_handler
profile.register_servlets(hs, self.mock_resource) profile.register_servlets(hs, self.mock_resource)
@defer.inlineCallbacks @defer.inlineCallbacks