Merge pull request #4654 from matrix-org/hawkowl/registration-worker

Registration worker
pull/4692/head
Erik Johnston 2019-02-15 17:51:34 +00:00 committed by GitHub
commit 5bd2e2c31d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 97 additions and 6 deletions

View File

@ -39,8 +39,12 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
from synapse.rest.client.v2_alpha._base import client_v2_patterns from synapse.rest.client.v2_alpha._base import client_v2_patterns
from synapse.rest.client.v2_alpha.register import (
register_servlets as register_registration_servlets,
)
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.registration import RegistrationStore
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -141,6 +145,7 @@ class FrontendProxySlavedStore(
SlavedClientIpStore, SlavedClientIpStore,
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
RegistrationStore,
BaseSlavedStore, BaseSlavedStore,
): ):
pass pass
@ -161,6 +166,7 @@ class FrontendProxyServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
KeyUploadServlet(self).register(resource) KeyUploadServlet(self).register(resource)
register_registration_servlets(self, resource)
# If presence is disabled, use the stub servlet that does # If presence is disabled, use the stub servlet that does
# not allow sending presence # not allow sending presence

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
import synapse.metrics
from six import iteritems from six import iteritems

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.replication.http import federation, membership, send_event from synapse.replication.http import federation, membership, registration, send_event
REPLICATION_PREFIX = "/_synapse/replication" REPLICATION_PREFIX = "/_synapse/replication"
@ -28,3 +28,4 @@ class ReplicationRestResource(JsonResource):
send_event.register_servlets(hs, self) send_event.register_servlets(hs, self)
membership.register_servlets(hs, self) membership.register_servlets(hs, self)
federation.register_servlets(hs, self) federation.register_servlets(hs, self)
registration.register_servlets(hs, self)

View File

@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class RegistrationUserCacheInvalidationServlet(ReplicationEndpoint):
"""
Invalidate the caches that a registration usually invalidates.
Request format:
POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
{}
"""
NAME = "reg_invalidate_user_caches"
PATH_ARGS = ("user_id",)
def __init__(self, hs):
super(RegistrationUserCacheInvalidationServlet, self).__init__(hs)
self.store = hs.get_datastore()
@staticmethod
def _serialize_payload(user_id, args):
"""
Args:
user_id (str)
"""
return {}
@defer.inlineCallbacks
def _handle_request(self, request, user_id):
def invalidate(txn):
self.store._invalidate_cache_and_stream(
txn, self.store.get_user_by_id, (user_id,)
)
txn.call_after(self.store.is_guest.invalidate, (user_id,))
yield self.store.runInteraction("user_invalidate_caches", invalidate)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
RegistrationUserCacheInvalidationServlet(hs).register(http_server)

View File

@ -33,6 +33,9 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.replication.http.registration import (
RegistrationUserCacheInvalidationServlet,
)
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
@ -193,6 +196,10 @@ class RegisterRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._invalidate_caches_client = (
RegistrationUserCacheInvalidationServlet.make_client(hs)
)
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -266,6 +273,9 @@ class RegisterRestServlet(RestServlet):
# == Shared Secret Registration == (e.g. create new user scripts) # == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body: if 'mac' in body:
if self.hs.config.worker_app:
raise SynapseError(403, "Not available at this endpoint")
# FIXME: Should we really be determining if this is shared secret # FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key? # auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration( result = yield self._do_shared_secret_registration(
@ -456,6 +466,9 @@ class RegisterRestServlet(RestServlet):
) )
yield self.registration_handler.post_consent_actions(registered_user_id) yield self.registration_handler.post_consent_actions(registered_user_id)
if self.hs.config.worker_app:
yield self._invalidate_caches_client(registered_user_id)
defer.returnValue((200, return_dict)) defer.returnValue((200, return_dict))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
@ -466,6 +479,10 @@ class RegisterRestServlet(RestServlet):
user_id = yield self.registration_handler.appservice_register( user_id = yield self.registration_handler.appservice_register(
username, as_token username, as_token
) )
if self.hs.config.worker_app:
yield self._invalidate_caches_client(user_id)
defer.returnValue((yield self._create_registration_details(user_id, body))) defer.returnValue((yield self._create_registration_details(user_id, body)))
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -146,6 +146,7 @@ class RegistrationStore(RegistrationWorkerStore,
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs) super(RegistrationStore, self).__init__(db_conn, hs)
self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.register_background_index_update( self.register_background_index_update(
@ -321,10 +322,12 @@ class RegistrationStore(RegistrationWorkerStore,
(user_id_obj.localpart, create_profile_with_displayname) (user_id_obj.localpart, create_profile_with_displayname)
) )
self._invalidate_cache_and_stream( # Don't invalidate here, it will be done through replication to the worker.
txn, self.get_user_by_id, (user_id,) if not self.hs.config.worker_app:
) self._invalidate_cache_and_stream(
txn.call_after(self.is_guest.invalidate, (user_id,)) txn, self.get_user_by_id, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
def get_users_by_id_case_insensitive(self, user_id): def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively. """Gets users that match user_id case insensitively.