Make get_user_by_access_token return a proper type
parent
a9f90fa73a
commit
bc422e1203
|
@ -33,6 +33,7 @@ from synapse.api.errors import (
|
|||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging import opentracing as opentracing
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.metrics import Measure
|
||||
|
@ -210,10 +211,10 @@ class Auth:
|
|||
user_info = await self.get_user_by_access_token(
|
||||
access_token, rights, allow_expired=allow_expired
|
||||
)
|
||||
user = user_info["user"]
|
||||
token_id = user_info["token_id"]
|
||||
is_guest = user_info["is_guest"]
|
||||
shadow_banned = user_info["shadow_banned"]
|
||||
user = UserID.from_string(user_info.user_id)
|
||||
token_id = user_info.token_id
|
||||
is_guest = user_info.is_guest
|
||||
shadow_banned = user_info.shadow_banned
|
||||
|
||||
# Deny the request if the user account has expired.
|
||||
if self._account_validity.enabled and not allow_expired:
|
||||
|
@ -223,11 +224,9 @@ class Auth:
|
|||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||
)
|
||||
|
||||
# device_id may not be present if get_user_by_access_token has been
|
||||
# stubbed out.
|
||||
device_id = user_info.get("device_id")
|
||||
device_id = user_info.device_id
|
||||
|
||||
if user and access_token and ip_addr:
|
||||
if access_token and ip_addr:
|
||||
await self.store.insert_client_ip(
|
||||
user_id=user.to_string(),
|
||||
access_token=access_token,
|
||||
|
@ -286,7 +285,7 @@ class Auth:
|
|||
|
||||
async def get_user_by_access_token(
|
||||
self, token: str, rights: str = "access", allow_expired: bool = False,
|
||||
) -> dict:
|
||||
) -> TokenLookupResult:
|
||||
""" Validate access token and get user_id from it
|
||||
|
||||
Args:
|
||||
|
@ -295,13 +294,7 @@ class Auth:
|
|||
allow this
|
||||
allow_expired: If False, raises an InvalidClientTokenError
|
||||
if the token is expired
|
||||
Returns:
|
||||
dict that includes:
|
||||
`user` (UserID)
|
||||
`is_guest` (bool)
|
||||
`shadow_banned` (bool)
|
||||
`token_id` (int|None): access token id. May be None if guest
|
||||
`device_id` (str|None): device corresponding to access token
|
||||
|
||||
Raises:
|
||||
InvalidClientTokenError if a user by that token exists, but the token is
|
||||
expired
|
||||
|
@ -311,9 +304,9 @@ class Auth:
|
|||
|
||||
if rights == "access":
|
||||
# first look in the database
|
||||
r = await self._look_up_user_by_access_token(token)
|
||||
r = await self.store.get_user_by_access_token(token)
|
||||
if r:
|
||||
valid_until_ms = r["valid_until_ms"]
|
||||
valid_until_ms = r.valid_until_ms
|
||||
if (
|
||||
not allow_expired
|
||||
and valid_until_ms is not None
|
||||
|
@ -330,7 +323,6 @@ class Auth:
|
|||
# otherwise it needs to be a valid macaroon
|
||||
try:
|
||||
user_id, guest = self._parse_and_validate_macaroon(token, rights)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if rights == "access":
|
||||
if not guest:
|
||||
|
@ -356,23 +348,17 @@ class Auth:
|
|||
raise InvalidClientTokenError(
|
||||
"Guest access token used for regular user"
|
||||
)
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": True,
|
||||
"shadow_banned": False,
|
||||
"token_id": None,
|
||||
|
||||
ret = TokenLookupResult(
|
||||
user_id=user_id,
|
||||
is_guest=True,
|
||||
# all guests get the same device id
|
||||
"device_id": GUEST_DEVICE_ID,
|
||||
}
|
||||
device_id=GUEST_DEVICE_ID,
|
||||
)
|
||||
elif rights == "delete_pusher":
|
||||
# We don't store these tokens in the database
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": False,
|
||||
"shadow_banned": False,
|
||||
"token_id": None,
|
||||
"device_id": None,
|
||||
}
|
||||
|
||||
ret = TokenLookupResult(user_id=user_id, is_guest=False)
|
||||
else:
|
||||
raise RuntimeError("Unknown rights setting %s", rights)
|
||||
return ret
|
||||
|
@ -481,24 +467,6 @@ class Auth:
|
|||
now = self.hs.get_clock().time_msec()
|
||||
return now < expiry
|
||||
|
||||
async def _look_up_user_by_access_token(self, token):
|
||||
ret = await self.store.get_user_by_access_token(token)
|
||||
if not ret:
|
||||
return None
|
||||
|
||||
# we use ret.get() below because *lots* of unit tests stub out
|
||||
# get_user_by_access_token in a way where it only returns a couple of
|
||||
# the fields.
|
||||
user_info = {
|
||||
"user": UserID.from_string(ret.get("name")),
|
||||
"token_id": ret.get("token_id", None),
|
||||
"is_guest": False,
|
||||
"shadow_banned": ret.get("shadow_banned"),
|
||||
"device_id": ret.get("device_id"),
|
||||
"valid_until_ms": ret.get("valid_until_ms"),
|
||||
}
|
||||
return user_info
|
||||
|
||||
def get_appservice_by_req(self, request):
|
||||
token = self.get_access_token_from_request(request)
|
||||
service = self.store.get_app_service_by_token(token)
|
||||
|
|
|
@ -984,17 +984,17 @@ class AuthHandler(BaseHandler):
|
|||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
result = provider.on_logged_out(
|
||||
user_id=str(user_info["user"]),
|
||||
device_id=user_info["device_id"],
|
||||
user_id=user_info.user_id,
|
||||
device_id=user_info.device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
# delete pushers associated with this access token
|
||||
if user_info["token_id"] is not None:
|
||||
if user_info.token_id is not None:
|
||||
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||
str(user_info["user"]), (user_info["token_id"],)
|
||||
user_info.user_id, (user_info.token_id,)
|
||||
)
|
||||
|
||||
async def delete_access_tokens_for_user(
|
||||
|
|
|
@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
|
|||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||
)
|
||||
user_data = await self.auth.get_user_by_access_token(guest_access_token)
|
||||
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
|
||||
if (
|
||||
not user_data.is_guest
|
||||
or UserID.from_string(user_data.user_id).localpart != localpart
|
||||
):
|
||||
raise AuthError(
|
||||
403,
|
||||
"Cannot register taken user ID without valid guest "
|
||||
|
@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
|
|||
# up when the access token is saved, but that's quite an
|
||||
# invasive change I'd rather do separately.
|
||||
user_tuple = await self.store.get_user_by_access_token(token)
|
||||
token_id = user_tuple["token_id"]
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
await self.pusher_pool.add_pusher(
|
||||
user_id=user_id,
|
||||
|
|
|
@ -18,6 +18,8 @@ import logging
|
|||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
|
@ -38,6 +40,27 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class TokenLookupResult:
|
||||
"""Result of looking up an access token.
|
||||
|
||||
Attributes:
|
||||
user_id: The user that this token authenticates as
|
||||
is_guest
|
||||
shadow_banned
|
||||
token_id: The ID of the access token looked up
|
||||
device_id: The device associated with the token, if any.
|
||||
valid_until_ms: The timestamp the token expires, if any.
|
||||
"""
|
||||
|
||||
user_id = attr.ib(type=str)
|
||||
is_guest = attr.ib(type=bool, default=False)
|
||||
shadow_banned = attr.ib(type=bool, default=False)
|
||||
token_id = attr.ib(type=Optional[int], default=None)
|
||||
device_id = attr.ib(type=Optional[str], default=None)
|
||||
valid_until_ms = attr.ib(type=Optional[int], default=None)
|
||||
|
||||
|
||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
@ -102,7 +125,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
return is_trial
|
||||
|
||||
@cached()
|
||||
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
|
||||
async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
|
||||
"""Get a user from the given access token.
|
||||
|
||||
Args:
|
||||
|
@ -331,9 +354,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
|
||||
sql = """
|
||||
SELECT users.name,
|
||||
SELECT users.name as user_id,
|
||||
users.is_guest,
|
||||
users.shadow_banned,
|
||||
access_tokens.id as token_id,
|
||||
|
@ -347,7 +370,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
txn.execute(sql, (token,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if rows:
|
||||
return rows[0]
|
||||
return TokenLookupResult(**rows[0])
|
||||
|
||||
return None
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ from synapse.api.errors import (
|
|||
MissingClientTokenError,
|
||||
ResourceLimitError,
|
||||
)
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
|
@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_user_by_req_user_valid_token(self):
|
||||
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
|
||||
user_info = TokenLookupResult(
|
||||
user_id=self.test_user, token_id=5, device_id="device"
|
||||
)
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value=defer.succeed(user_info)
|
||||
)
|
||||
|
@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||
|
||||
def test_get_user_by_req_user_missing_token(self):
|
||||
user_info = {"name": self.test_user, "token_id": "ditto"}
|
||||
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value=defer.succeed(user_info)
|
||||
)
|
||||
|
@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
def test_get_user_from_macaroon(self):
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value=defer.succeed(
|
||||
{"name": "@baldrick:matrix.org", "device_id": "device"}
|
||||
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
|
|||
user_info = yield defer.ensureDeferred(
|
||||
self.auth.get_user_by_access_token(macaroon.serialize())
|
||||
)
|
||||
user = user_info["user"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
self.assertEqual(user_id, user_info.user_id)
|
||||
|
||||
# TODO: device_id should come from the macaroon, but currently comes
|
||||
# from the db.
|
||||
self.assertEqual(user_info["device_id"], "device")
|
||||
self.assertEqual(user_info.device_id, "device")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_guest_user_from_macaroon(self):
|
||||
|
@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
user_info = yield defer.ensureDeferred(
|
||||
self.auth.get_user_by_access_token(serialized)
|
||||
)
|
||||
user = user_info["user"]
|
||||
is_guest = user_info["is_guest"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
self.assertTrue(is_guest)
|
||||
self.assertEqual(user_id, user_info.user_id)
|
||||
self.assertTrue(user_info.is_guest)
|
||||
self.store.get_user_by_id.assert_called_with(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
if token != tok:
|
||||
return defer.succeed(None)
|
||||
return defer.succeed(
|
||||
{
|
||||
"name": USER_ID,
|
||||
"is_guest": False,
|
||||
"token_id": 1234,
|
||||
"device_id": "DEVICE",
|
||||
}
|
||||
TokenLookupResult(
|
||||
user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
|
||||
)
|
||||
)
|
||||
|
||||
self.store.get_user_by_access_token = get_user
|
||||
|
|
|
@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
|||
# make sure that our device ID has changed
|
||||
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
|
||||
|
||||
self.assertEqual(user_info["device_id"], retrieved_device_id)
|
||||
self.assertEqual(user_info.device_id, retrieved_device_id)
|
||||
|
||||
# make sure the device has the display name that was set from the login
|
||||
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
|
||||
|
|
|
@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
self.info = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
|
||||
)
|
||||
self.token_id = self.info["token_id"]
|
||||
self.token_id = self.info.token_id
|
||||
|
||||
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||
user_tuple = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(self.access_token)
|
||||
)
|
||||
token_id = user_tuple["token_id"]
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.pusher = self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
|
|
|
@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
user_tuple = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple["token_id"]
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
|
@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
user_tuple = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple["token_id"]
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
|
@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
user_tuple = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple["token_id"]
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
|
@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
user_tuple = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple["token_id"]
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
|
@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
user_tuple = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple["token_id"]
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
|
|
|
@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
user_dict = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_dict["token_id"]
|
||||
token_id = user_dict.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
|
|
|
@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
self.store.get_user_by_access_token(self.tokens[1])
|
||||
)
|
||||
|
||||
self.assertDictContainsSubset(
|
||||
{"name": self.user_id, "device_id": self.device_id}, result
|
||||
)
|
||||
|
||||
self.assertTrue("token_id" in result)
|
||||
self.assertEqual(result.user_id, self.user_id)
|
||||
self.assertEqual(result.device_id, self.device_id)
|
||||
self.assertIsNotNone(result.token_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_delete_access_tokens(self):
|
||||
|
@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
user = yield defer.ensureDeferred(
|
||||
self.store.get_user_by_access_token(self.tokens[0])
|
||||
)
|
||||
self.assertEqual(self.user_id, user["name"])
|
||||
self.assertEqual(self.user_id, user.user_id)
|
||||
|
||||
# now delete the rest
|
||||
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
|
||||
|
|
Loading…
Reference in New Issue