Make get_user_by_access_token return a proper type

pull/8616/head
Erik Johnston 2020-10-21 11:26:08 +01:00
parent a9f90fa73a
commit bc422e1203
11 changed files with 81 additions and 92 deletions

View File

@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging import opentracing as opentracing from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -210,10 +211,10 @@ class Auth:
user_info = await self.get_user_by_access_token( user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired access_token, rights, allow_expired=allow_expired
) )
user = user_info["user"] user = UserID.from_string(user_info.user_id)
token_id = user_info["token_id"] token_id = user_info.token_id
is_guest = user_info["is_guest"] is_guest = user_info.is_guest
shadow_banned = user_info["shadow_banned"] shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired. # Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired: if self._account_validity.enabled and not allow_expired:
@ -223,11 +224,9 @@ class Auth:
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
) )
# device_id may not be present if get_user_by_access_token has been device_id = user_info.device_id
# stubbed out.
device_id = user_info.get("device_id")
if user and access_token and ip_addr: if access_token and ip_addr:
await self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user.to_string(), user_id=user.to_string(),
access_token=access_token, access_token=access_token,
@ -286,7 +285,7 @@ class Auth:
async def get_user_by_access_token( async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False, self, token: str, rights: str = "access", allow_expired: bool = False,
) -> dict: ) -> TokenLookupResult:
""" Validate access token and get user_id from it """ Validate access token and get user_id from it
Args: Args:
@ -295,13 +294,7 @@ class Auth:
allow this allow this
allow_expired: If False, raises an InvalidClientTokenError allow_expired: If False, raises an InvalidClientTokenError
if the token is expired if the token is expired
Returns:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises: Raises:
InvalidClientTokenError if a user by that token exists, but the token is InvalidClientTokenError if a user by that token exists, but the token is
expired expired
@ -311,9 +304,9 @@ class Auth:
if rights == "access": if rights == "access":
# first look in the database # first look in the database
r = await self._look_up_user_by_access_token(token) r = await self.store.get_user_by_access_token(token)
if r: if r:
valid_until_ms = r["valid_until_ms"] valid_until_ms = r.valid_until_ms
if ( if (
not allow_expired not allow_expired
and valid_until_ms is not None and valid_until_ms is not None
@ -330,7 +323,6 @@ class Auth:
# otherwise it needs to be a valid macaroon # otherwise it needs to be a valid macaroon
try: try:
user_id, guest = self._parse_and_validate_macaroon(token, rights) user_id, guest = self._parse_and_validate_macaroon(token, rights)
user = UserID.from_string(user_id)
if rights == "access": if rights == "access":
if not guest: if not guest:
@ -356,23 +348,17 @@ class Auth:
raise InvalidClientTokenError( raise InvalidClientTokenError(
"Guest access token used for regular user" "Guest access token used for regular user"
) )
ret = {
"user": user, ret = TokenLookupResult(
"is_guest": True, user_id=user_id,
"shadow_banned": False, is_guest=True,
"token_id": None,
# all guests get the same device id # all guests get the same device id
"device_id": GUEST_DEVICE_ID, device_id=GUEST_DEVICE_ID,
} )
elif rights == "delete_pusher": elif rights == "delete_pusher":
# We don't store these tokens in the database # We don't store these tokens in the database
ret = {
"user": user, ret = TokenLookupResult(user_id=user_id, is_guest=False)
"is_guest": False,
"shadow_banned": False,
"token_id": None,
"device_id": None,
}
else: else:
raise RuntimeError("Unknown rights setting %s", rights) raise RuntimeError("Unknown rights setting %s", rights)
return ret return ret
@ -481,24 +467,6 @@ class Auth:
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
return user_info
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request) token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)

View File

@ -984,17 +984,17 @@ class AuthHandler(BaseHandler):
# This might return an awaitable, if it does block the log out # This might return an awaitable, if it does block the log out
# until it completes. # until it completes.
result = provider.on_logged_out( result = provider.on_logged_out(
user_id=str(user_info["user"]), user_id=user_info.user_id,
device_id=user_info["device_id"], device_id=user_info.device_id,
access_token=access_token, access_token=access_token,
) )
if inspect.isawaitable(result): if inspect.isawaitable(result):
await result await result
# delete pushers associated with this access token # delete pushers associated with this access token
if user_info["token_id"] is not None: if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token( await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],) user_info.user_id, (user_info.token_id,)
) )
async def delete_access_tokens_for_user( async def delete_access_tokens_for_user(

View File

@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
) )
user_data = await self.auth.get_user_by_access_token(guest_access_token) user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart: if (
not user_data.is_guest
or UserID.from_string(user_data.user_id).localpart != localpart
):
raise AuthError( raise AuthError(
403, 403,
"Cannot register taken user ID without valid guest " "Cannot register taken user ID without valid guest "
@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an # up when the access token is saved, but that's quite an
# invasive change I'd rather do separately. # invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token) user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
await self.pusher_pool.add_pusher( await self.pusher_pool.add_pusher(
user_id=user_id, user_id=user_id,

View File

@ -18,6 +18,8 @@ import logging
import re import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import attr
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -38,6 +40,27 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(frozen=True, slots=True)
class TokenLookupResult:
"""Result of looking up an access token.
Attributes:
user_id: The user that this token authenticates as
is_guest
shadow_banned
token_id: The ID of the access token looked up
device_id: The device associated with the token, if any.
valid_until_ms: The timestamp the token expires, if any.
"""
user_id = attr.ib(type=str)
is_guest = attr.ib(type=bool, default=False)
shadow_banned = attr.ib(type=bool, default=False)
token_id = attr.ib(type=Optional[int], default=None)
device_id = attr.ib(type=Optional[str], default=None)
valid_until_ms = attr.ib(type=Optional[int], default=None)
class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@ -102,7 +125,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return is_trial return is_trial
@cached() @cached()
async def get_user_by_access_token(self, token: str) -> Optional[dict]: async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token. """Get a user from the given access token.
Args: Args:
@ -331,9 +354,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """ sql = """
SELECT users.name, SELECT users.name as user_id,
users.is_guest, users.is_guest,
users.shadow_banned, users.shadow_banned,
access_tokens.id as token_id, access_tokens.id as token_id,
@ -347,7 +370,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, (token,)) txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if rows: if rows:
return rows[0] return TokenLookupResult(**rows[0])
return None return None

View File

@ -29,6 +29,7 @@ from synapse.api.errors import (
MissingClientTokenError, MissingClientTokenError,
ResourceLimitError, ResourceLimitError,
) )
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info) return_value=defer.succeed(user_info)
) )
@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"} user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info) return_value=defer.succeed(user_info)
) )
@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self): def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed( return_value=defer.succeed(
{"name": "@baldrick:matrix.org", "device_id": "device"} TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
) )
) )
@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred( user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize()) self.auth.get_user_by_access_token(macaroon.serialize())
) )
user = user_info["user"] self.assertEqual(user_id, user_info.user_id)
self.assertEqual(UserID.from_string(user_id), user)
# TODO: device_id should come from the macaroon, but currently comes # TODO: device_id should come from the macaroon, but currently comes
# from the db. # from the db.
self.assertEqual(user_info["device_id"], "device") self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self):
@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred( user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized) self.auth.get_user_by_access_token(serialized)
) )
user = user_info["user"] self.assertEqual(user_id, user_info.user_id)
is_guest = user_info["is_guest"] self.assertTrue(user_info.is_guest)
self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest)
self.store.get_user_by_id.assert_called_with(user_id) self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
if token != tok: if token != tok:
return defer.succeed(None) return defer.succeed(None)
return defer.succeed( return defer.succeed(
{ TokenLookupResult(
"name": USER_ID, user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
"is_guest": False, )
"token_id": 1234,
"device_id": "DEVICE",
}
) )
self.store.get_user_by_access_token = get_user self.store.get_user_by_access_token = get_user

View File

@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that our device ID has changed # make sure that our device ID has changed
user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
self.assertEqual(user_info["device_id"], retrieved_device_id) self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the device has the display name that was set from the login # make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))

View File

@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.info = self.get_success( self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,) self.hs.get_datastore().get_user_by_access_token(self.access_token,)
) )
self.token_id = self.info["token_id"] self.token_id = self.info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id) self.requester = create_requester(self.user_id, access_token_id=self.token_id)

View File

@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token) self.hs.get_datastore().get_user_by_access_token(self.access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.pusher = self.get_success( self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success( user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_dict["token_id"] token_id = user_dict.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1]) self.store.get_user_by_access_token(self.tokens[1])
) )
self.assertDictContainsSubset( self.assertEqual(result.user_id, self.user_id)
{"name": self.user_id, "device_id": self.device_id}, result self.assertEqual(result.device_id, self.device_id)
) self.assertIsNotNone(result.token_id)
self.assertTrue("token_id" in result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred( user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0]) self.store.get_user_by_access_token(self.tokens[0])
) )
self.assertEqual(self.user_id, user["name"]) self.assertEqual(self.user_id, user.user_id)
# now delete the rest # now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))