Merge pull request #2621 from matrix-org/rav/refactor_accesstoken_delete

Move access token deletion into auth handler
pull/2624/head
David Baker 2017-11-01 16:26:06 +00:00 committed by GitHub
commit c9b9ef575b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 27 deletions

View File

@ -605,13 +605,58 @@ class AuthHandler(BaseHandler):
if e.code == 404: if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e raise e
yield self.store.user_delete_access_tokens( yield self.delete_access_tokens_for_user(
user_id, except_access_token_id user_id, except_token_id=except_access_token_id,
) )
yield self.hs.get_pusherpool().remove_pushers_by_user( yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id user_id, except_access_token_id
) )
@defer.inlineCallbacks
def deactivate_account(self, user_id):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
Returns:
Deferred
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
yield self.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
def delete_access_token(self, access_token):
"""Invalidate a single access token
Args:
access_token (str): access token to be deleted
Returns:
Deferred
"""
return self.store.delete_access_token(access_token)
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
device_id=None):
"""Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_id (str|None): access_token ID which should *not* be
deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
Deferred
"""
return self.store.user_delete_access_tokens(
user_id, except_token_id=except_token_id, device_id=device_id,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case. # 'Canonicalise' email addresses down to lower case.

View File

@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler):
self.hs = hs self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
@ -159,7 +160,7 @@ class DeviceHandler(BaseHandler):
else: else:
raise raise
yield self.store.user_delete_access_tokens( yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id, user_id, device_id=device_id,
) )
@ -193,7 +194,7 @@ class DeviceHandler(BaseHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not # Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path. # considered as part of a critical path.
for device_id in device_ids: for device_id in device_ids:
yield self.store.user_delete_access_tokens( yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id, user_id, device_id=device_id,
) )
yield self.store.delete_e2e_keys_by_device( yield self.store.delete_e2e_keys_by_device(

View File

@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_localpart=user.localpart, create_profile_with_localpart=user.localpart,
) )
else: else:
yield self.store.user_delete_access_tokens(user_id=user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token) yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None: if displayname is not None:

View File

@ -137,7 +137,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self._auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__(hs) super(DeactivateAccountRestServlet, self).__init__(hs)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
# FIXME: Theoretically there is a race here wherein user resets password yield self._auth_handler.deactivate_account(target_user_id)
# using threepid.
yield self.store.user_delete_access_tokens(target_user_id)
yield self.store.user_delete_threepids(target_user_id)
yield self.store.user_set_password_hash(target_user_id, None)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -30,7 +30,7 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs) super(LogoutRestServlet, self).__init__(hs)
self.store = hs.get_datastore() self._auth_handler = hs.get_auth_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -38,7 +38,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
access_token = get_access_token_from_request(request) access_token = get_access_token_from_request(request)
yield self.store.delete_access_token(access_token) yield self._auth_handler.delete_access_token(access_token)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -47,8 +47,8 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs) super(LogoutAllRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -57,7 +57,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.store.user_delete_access_tokens(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -162,7 +162,6 @@ class DeactivateAccountRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__() super(DeactivateAccountRestServlet, self).__init__()
@ -180,7 +179,9 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to dectivate their own users # allow ASes to dectivate their own users
if requester and requester.app_service: if requester and requester.app_service:
yield self._deactivate_account(requester.user.to_string()) yield self.auth_handler.deactivate_account(
requester.user.to_string()
)
defer.returnValue((200, {})) defer.returnValue((200, {}))
authed, result, params, _ = yield self.auth_handler.check_auth([ authed, result, params, _ = yield self.auth_handler.check_auth([
@ -205,17 +206,9 @@ class DeactivateAccountRestServlet(RestServlet):
logger.error("Auth succeeded but no known type!", result.keys()) logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN) raise SynapseError(500, "", Codes.UNKNOWN)
yield self._deactivate_account(user_id) yield self.auth_handler.deactivate_account(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks
def _deactivate_account(self, user_id):
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
yield self.store.user_delete_access_tokens(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
class EmailThreepidRequestTokenRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")