add cache remover endpoint and wire it up

pull/4654/head
Amber Brown 2019-02-16 04:34:23 +11:00
parent d97c3a6ce6
commit f5bafd70f4
3 changed files with 20 additions and 5 deletions

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.replication.http import federation, membership, send_event from synapse.replication.http import federation, membership, send_event, registration
REPLICATION_PREFIX = "/_synapse/replication" REPLICATION_PREFIX = "/_synapse/replication"
@ -28,3 +28,4 @@ class ReplicationRestResource(JsonResource):
send_event.register_servlets(hs, self) send_event.register_servlets(hs, self)
membership.register_servlets(hs, self) membership.register_servlets(hs, self)
federation.register_servlets(hs, self) federation.register_servlets(hs, self)
registration.register_servlets(hs, self)

View File

@ -24,6 +24,7 @@ from twisted.internet import defer
import synapse import synapse
import synapse.types import synapse.types
from synapse.replication.http.registration import RegistrationUserCacheInvalidationServlet
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
@ -193,6 +194,10 @@ class RegisterRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._invalidate_caches_client = (
RegistrationUserCacheInvalidationServlet.make_client(hs)
)
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -266,6 +271,9 @@ class RegisterRestServlet(RestServlet):
# == Shared Secret Registration == (e.g. create new user scripts) # == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body: if 'mac' in body:
if self.hs.config.worker_app:
raise SynapseError(403, "Not available at this endpoint")
# FIXME: Should we really be determining if this is shared secret # FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key? # auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration( result = yield self._do_shared_secret_registration(
@ -456,6 +464,9 @@ class RegisterRestServlet(RestServlet):
) )
yield self.registration_handler.post_consent_actions(registered_user_id) yield self.registration_handler.post_consent_actions(registered_user_id)
if self.hs.config.worker_app:
self._invalidate_caches_client(registered_user_id)
defer.returnValue((200, return_dict)) defer.returnValue((200, return_dict))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):

View File

@ -146,6 +146,7 @@ class RegistrationStore(RegistrationWorkerStore,
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs) super(RegistrationStore, self).__init__(db_conn, hs)
self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.register_background_index_update( self.register_background_index_update(
@ -321,10 +322,12 @@ class RegistrationStore(RegistrationWorkerStore,
(user_id_obj.localpart, create_profile_with_displayname) (user_id_obj.localpart, create_profile_with_displayname)
) )
self._invalidate_cache_and_stream( # Don't invalidate here, it will be done through replication to the worker.
txn, self.get_user_by_id, (user_id,) if not self.hs.config.worker_app:
) self._invalidate_cache_and_stream(
txn.call_after(self.is_guest.invalidate, (user_id,)) txn, self.get_user_by_id, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
def get_users_by_id_case_insensitive(self, user_id): def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively. """Gets users that match user_id case insensitively.