Add ability for access tokens to belong to one user but grant access to another user. (#8616)

We do it this way round so that only the "owner" can delete the access token (i.e. `/logout/all` by the "owner" also deletes that token, but `/logout/all` by the "target user" doesn't).

A future PR will add an API for creating such a token.

When the target user and authenticated entity are different the `Processed request` log line will be logged with a: `{@admin:server as @bob:server} ...`. I'm not convinced by that format (especially since it adds spaces in there, making it harder to use `cut -d ' '` to chop off the start of log lines). Suggestions welcome.
pull/8689/head
Erik Johnston 2020-10-29 15:58:44 +00:00 committed by GitHub
parent 22eeb6bc54
commit f21e24ffc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 197 additions and 138 deletions

1
changelog.d/8616.misc Normal file
View File

@ -0,0 +1 @@
Change schema to support access tokens belonging to one user but granting access to another.

View File

@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging import opentracing as opentracing from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -190,10 +191,6 @@ class Auth:
user_id, app_service = await self._get_appservice_user_id(request) user_id, app_service = await self._get_appservice_user_id(request)
if user_id: if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips: if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user_id, user_id=user_id,
@ -203,31 +200,38 @@ class Auth:
device_id="dummy-device", # stubbed device_id="dummy-device", # stubbed
) )
return synapse.types.create_requester(user_id, app_service=app_service) requester = synapse.types.create_requester(
user_id, app_service=app_service
)
request.requester = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("user_id", user_id)
opentracing.set_tag("appservice_id", app_service.id)
return requester
user_info = await self.get_user_by_access_token( user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired access_token, rights, allow_expired=allow_expired
) )
user = user_info["user"] token_id = user_info.token_id
token_id = user_info["token_id"] is_guest = user_info.is_guest
is_guest = user_info["is_guest"] shadow_banned = user_info.shadow_banned
shadow_banned = user_info["shadow_banned"]
# Deny the request if the user account has expired. # Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired: if self._account_validity.enabled and not allow_expired:
user_id = user.to_string() if await self.store.is_account_expired(
if await self.store.is_account_expired(user_id, self.clock.time_msec()): user_info.user_id, self.clock.time_msec()
):
raise AuthError( raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
) )
# device_id may not be present if get_user_by_access_token has been device_id = user_info.device_id
# stubbed out.
device_id = user_info.get("device_id")
if user and access_token and ip_addr: if access_token and ip_addr:
await self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user.to_string(), user_id=user_info.token_owner,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
user_agent=user_agent, user_agent=user_agent,
@ -241,19 +245,23 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN, errcode=Codes.GUEST_ACCESS_FORBIDDEN,
) )
request.authenticated_entity = user.to_string() requester = synapse.types.create_requester(
opentracing.set_tag("authenticated_entity", user.to_string()) user_info.user_id,
if device_id:
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
user,
token_id, token_id,
is_guest, is_guest,
shadow_banned, shadow_banned,
device_id, device_id,
app_service=app_service, app_service=app_service,
authenticated_entity=user_info.token_owner,
) )
request.requester = requester
opentracing.set_tag("authenticated_entity", user_info.token_owner)
opentracing.set_tag("user_id", user_info.user_id)
if device_id:
opentracing.set_tag("device_id", device_id)
return requester
except KeyError: except KeyError:
raise MissingClientTokenError() raise MissingClientTokenError()
@ -284,7 +292,7 @@ class Auth:
async def get_user_by_access_token( async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False, self, token: str, rights: str = "access", allow_expired: bool = False,
) -> dict: ) -> TokenLookupResult:
""" Validate access token and get user_id from it """ Validate access token and get user_id from it
Args: Args:
@ -293,13 +301,7 @@ class Auth:
allow this allow this
allow_expired: If False, raises an InvalidClientTokenError allow_expired: If False, raises an InvalidClientTokenError
if the token is expired if the token is expired
Returns:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises: Raises:
InvalidClientTokenError if a user by that token exists, but the token is InvalidClientTokenError if a user by that token exists, but the token is
expired expired
@ -309,9 +311,9 @@ class Auth:
if rights == "access": if rights == "access":
# first look in the database # first look in the database
r = await self._look_up_user_by_access_token(token) r = await self.store.get_user_by_access_token(token)
if r: if r:
valid_until_ms = r["valid_until_ms"] valid_until_ms = r.valid_until_ms
if ( if (
not allow_expired not allow_expired
and valid_until_ms is not None and valid_until_ms is not None
@ -328,7 +330,6 @@ class Auth:
# otherwise it needs to be a valid macaroon # otherwise it needs to be a valid macaroon
try: try:
user_id, guest = self._parse_and_validate_macaroon(token, rights) user_id, guest = self._parse_and_validate_macaroon(token, rights)
user = UserID.from_string(user_id)
if rights == "access": if rights == "access":
if not guest: if not guest:
@ -354,23 +355,17 @@ class Auth:
raise InvalidClientTokenError( raise InvalidClientTokenError(
"Guest access token used for regular user" "Guest access token used for regular user"
) )
ret = {
"user": user, ret = TokenLookupResult(
"is_guest": True, user_id=user_id,
"shadow_banned": False, is_guest=True,
"token_id": None,
# all guests get the same device id # all guests get the same device id
"device_id": GUEST_DEVICE_ID, device_id=GUEST_DEVICE_ID,
} )
elif rights == "delete_pusher": elif rights == "delete_pusher":
# We don't store these tokens in the database # We don't store these tokens in the database
ret = {
"user": user, ret = TokenLookupResult(user_id=user_id, is_guest=False)
"is_guest": False,
"shadow_banned": False,
"token_id": None,
"device_id": None,
}
else: else:
raise RuntimeError("Unknown rights setting %s", rights) raise RuntimeError("Unknown rights setting %s", rights)
return ret return ret
@ -479,31 +474,15 @@ class Auth:
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
return user_info
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request) token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warning("Unrecognised appservice access token.") logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError() raise InvalidClientTokenError()
request.authenticated_entity = service.sender request.requester = synapse.types.create_requester(
service.sender, app_service=service
)
return service return service
async def is_server_admin(self, user: UserID) -> bool: async def is_server_admin(self, user: UserID) -> bool:

View File

@ -52,11 +52,11 @@ class ApplicationService:
self, self,
token, token,
hostname, hostname,
id,
sender,
url=None, url=None,
namespaces=None, namespaces=None,
hs_token=None, hs_token=None,
sender=None,
id=None,
protocols=None, protocols=None,
rate_limited=True, rate_limited=True,
ip_range_whitelist=None, ip_range_whitelist=None,

View File

@ -154,7 +154,7 @@ class Authenticator:
) )
logger.debug("Request from %s", origin) logger.debug("Request from %s", origin)
request.authenticated_entity = origin request.requester = origin
# If we get a valid signed request from the other side, its probably # If we get a valid signed request from the other side, its probably
# alive # alive

View File

@ -991,17 +991,17 @@ class AuthHandler(BaseHandler):
# This might return an awaitable, if it does block the log out # This might return an awaitable, if it does block the log out
# until it completes. # until it completes.
result = provider.on_logged_out( result = provider.on_logged_out(
user_id=str(user_info["user"]), user_id=user_info.user_id,
device_id=user_info["device_id"], device_id=user_info.device_id,
access_token=access_token, access_token=access_token,
) )
if inspect.isawaitable(result): if inspect.isawaitable(result):
await result await result
# delete pushers associated with this access token # delete pushers associated with this access token
if user_info["token_id"] is not None: if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token( await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],) user_info.user_id, (user_info.token_id,)
) )
async def delete_access_tokens_for_user( async def delete_access_tokens_for_user(

View File

@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
) )
user_data = await self.auth.get_user_by_access_token(guest_access_token) user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart: if (
not user_data.is_guest
or UserID.from_string(user_data.user_id).localpart != localpart
):
raise AuthError( raise AuthError(
403, 403,
"Cannot register taken user ID without valid guest " "Cannot register taken user ID without valid guest "
@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an # up when the access token is saved, but that's quite an
# invasive change I'd rather do separately. # invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token) user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
await self.pusher_pool.add_pusher( await self.pusher_pool.add_pusher(
user_id=user_id, user_id=user_id,

View File

@ -14,7 +14,7 @@
import contextlib import contextlib
import logging import logging
import time import time
from typing import Optional from typing import Optional, Union
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
from synapse.http import redact_uri from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.types import Requester
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -54,9 +55,12 @@ class SynapseRequest(Request):
Request.__init__(self, channel, *args, **kw) Request.__init__(self, channel, *args, **kw)
self.site = channel.site self.site = channel.site
self._channel = channel # this is used by the tests self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0.0 self.start_time = 0.0
# The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object.
self.requester = None # type: Optional[Union[Requester, str]]
# we can't yet create the logcontext, as we don't know the method. # we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext] self.logcontext = None # type: Optional[LoggingContext]
@ -271,11 +275,23 @@ class SynapseRequest(Request):
# to the client (nb may be negative) # to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time response_send_time = self.finish_time - self._processing_finished_time
# need to decode as it could be raw utf-8 bytes # Convert the requester into a string that we can log
# from a IDN servname in an auth header authenticated_entity = None
authenticated_entity = self.authenticated_entity if isinstance(self.requester, str):
if authenticated_entity is not None and isinstance(authenticated_entity, bytes): authenticated_entity = self.requester
authenticated_entity = authenticated_entity.decode("utf-8", "replace") elif isinstance(self.requester, Requester):
authenticated_entity = self.requester.authenticated_entity
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we log both.
if self.requester.user.to_string() != authenticated_entity:
authenticated_entity = "{},{}".format(
authenticated_entity, self.requester.user.to_string(),
)
elif self.requester is not None:
# This shouldn't happen, but we log it so we don't lose information
# and can see that we're doing something wrong.
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
# ...or could be raw utf-8 bytes in the User-Agent header. # ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically # N.B. if you don't do this, the logger explodes cryptically

View File

@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
if requester.user: request.requester = requester
request.authenticated_entity = requester.user.to_string()
logger.info("remote_join: %s into room: %s", user_id, room_id) logger.info("remote_join: %s into room: %s", user_id, room_id)
@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
if requester.user: request.requester = requester
request.authenticated_entity = requester.user.to_string()
# hopefully we're now on the master, so this won't recurse! # hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite( event_id, stream_id = await self.member_handler.remote_reject_invite(

View File

@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit = content["ratelimit"] ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]] extra_users = [UserID.from_string(u) for u in content["extra_users"]]
if requester.user: request.requester = requester
request.authenticated_entity = requester.user.to_string()
logger.info( logger.info(
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id "Got event to send with ID: %s into room: %s", event.event_id, event.room_id

View File

@ -18,6 +18,8 @@ import logging
import re import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import attr
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(frozen=True, slots=True)
class TokenLookupResult:
"""Result of looking up an access token.
Attributes:
user_id: The user that this token authenticates as
is_guest
shadow_banned
token_id: The ID of the access token looked up
device_id: The device associated with the token, if any.
valid_until_ms: The timestamp the token expires, if any.
token_owner: The "owner" of the token. This is either the same as the
user, or a server admin who is logged in as the user.
"""
user_id = attr.ib(type=str)
is_guest = attr.ib(type=bool, default=False)
shadow_banned = attr.ib(type=bool, default=False)
token_id = attr.ib(type=Optional[int], default=None)
device_id = attr.ib(type=Optional[str], default=None)
valid_until_ms = attr.ib(type=Optional[int], default=None)
token_owner = attr.ib(type=str)
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
def _default_token_owner(self):
return self.user_id
class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return is_trial return is_trial
@cached() @cached()
async def get_user_by_access_token(self, token: str) -> Optional[dict]: async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token. """Get a user from the given access token.
Args: Args:
token: The access token of a user. token: The access token of a user.
Returns: Returns:
None, if the token did not match, otherwise dict None, if the token did not match, otherwise a `TokenLookupResult`
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
""" """
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token "get_user_by_access_token", self._query_for_auth, token
@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """ sql = """
SELECT users.name, SELECT users.name as user_id,
users.is_guest, users.is_guest,
users.shadow_banned, users.shadow_banned,
access_tokens.id as token_id, access_tokens.id as token_id,
access_tokens.device_id, access_tokens.device_id,
access_tokens.valid_until_ms access_tokens.valid_until_ms,
access_tokens.user_id as token_owner
FROM users FROM users
INNER JOIN access_tokens on users.name = access_tokens.user_id INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ? WHERE token = ?
""" """
txn.execute(sql, (token,)) txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if rows: if rows:
return rows[0] return TokenLookupResult(**rows[0])
return None return None

View File

@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Whether the access token is an admin token for controlling another user.
ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;

View File

@ -29,6 +29,7 @@ from typing import (
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
Union,
) )
import attr import attr
@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
# define a version of typing.Collection that works on python 3.5 # define a version of typing.Collection that works on python 3.5
@ -74,6 +76,7 @@ class Requester(
"shadow_banned", "shadow_banned",
"device_id", "device_id",
"app_service", "app_service",
"authenticated_entity",
], ],
) )
): ):
@ -104,6 +107,7 @@ class Requester(
"shadow_banned": self.shadow_banned, "shadow_banned": self.shadow_banned,
"device_id": self.device_id, "device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None, "app_server_id": self.app_service.id if self.app_service else None,
"authenticated_entity": self.authenticated_entity,
} }
@staticmethod @staticmethod
@ -129,16 +133,18 @@ class Requester(
shadow_banned=input["shadow_banned"], shadow_banned=input["shadow_banned"],
device_id=input["device_id"], device_id=input["device_id"],
app_service=appservice, app_service=appservice,
authenticated_entity=input["authenticated_entity"],
) )
def create_requester( def create_requester(
user_id, user_id: Union[str, "UserID"],
access_token_id=None, access_token_id: Optional[int] = None,
is_guest=False, is_guest: Optional[bool] = False,
shadow_banned=False, shadow_banned: Optional[bool] = False,
device_id=None, device_id: Optional[str] = None,
app_service=None, app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
): ):
""" """
Create a new ``Requester`` object Create a new ``Requester`` object
@ -151,14 +157,27 @@ def create_requester(
shadow_banned (bool): True if the user making this request is shadow-banned. shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user app_service (ApplicationService|None): the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
Returns: Returns:
Requester Requester
""" """
if not isinstance(user_id, UserID): if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id) user_id = UserID.from_string(user_id)
if authenticated_entity is None:
authenticated_entity = user_id.to_string()
return Requester( return Requester(
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service user_id,
access_token_id,
is_guest,
shadow_banned,
device_id,
app_service,
authenticated_entity,
) )

View File

@ -29,6 +29,7 @@ from synapse.api.errors import (
MissingClientTokenError, MissingClientTokenError,
ResourceLimitError, ResourceLimitError,
) )
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info) return_value=defer.succeed(user_info)
) )
@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"} user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info) return_value=defer.succeed(user_info)
) )
@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self): def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed( return_value=defer.succeed(
{"name": "@baldrick:matrix.org", "device_id": "device"} TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
) )
) )
@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred( user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize()) self.auth.get_user_by_access_token(macaroon.serialize())
) )
user = user_info["user"] self.assertEqual(user_id, user_info.user_id)
self.assertEqual(UserID.from_string(user_id), user)
# TODO: device_id should come from the macaroon, but currently comes # TODO: device_id should come from the macaroon, but currently comes
# from the db. # from the db.
self.assertEqual(user_info["device_id"], "device") self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self):
@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred( user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized) self.auth.get_user_by_access_token(serialized)
) )
user = user_info["user"] self.assertEqual(user_id, user_info.user_id)
is_guest = user_info["is_guest"] self.assertTrue(user_info.is_guest)
self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest)
self.store.get_user_by_id.assert_called_with(user_id) self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
if token != tok: if token != tok:
return defer.succeed(None) return defer.succeed(None)
return defer.succeed( return defer.succeed(
{ TokenLookupResult(
"name": USER_ID, user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
"is_guest": False, )
"token_id": 1234,
"device_id": "DEVICE",
}
) )
self.store.get_user_by_access_token = get_user self.store.get_user_by_access_token = get_user

View File

@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService( appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=True, None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
) )
as_requester = create_requester("@user:example.com", app_service=appservice) as_requester = create_requester("@user:example.com", app_service=appservice)
@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self): def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService( appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=False, None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
) )
as_requester = create_requester("@user:example.com", app_service=appservice) as_requester = create_requester("@user:example.com", app_service=appservice)

View File

@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.service = ApplicationService( self.service = ApplicationService(
id="unique_identifier", id="unique_identifier",
sender="@as:test",
url="some_url", url="some_url",
token="some_token", token="some_token",
hostname="matrix.org", # only used by get_groups_for_user hostname="matrix.org", # only used by get_groups_for_user

View File

@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that our device ID has changed # make sure that our device ID has changed
user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
self.assertEqual(user_info["device_id"], retrieved_device_id) self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the device has the display name that was set from the login # make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))

View File

@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.info = self.get_success( self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,) self.hs.get_datastore().get_user_by_access_token(self.access_token,)
) )
self.token_id = self.info["token_id"] self.token_id = self.info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id) self.requester = create_requester(self.user_id, access_token_id=self.token_id)

View File

@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token) self.hs.get_datastore().get_user_by_access_token(self.access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.pusher = self.get_success( self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success( user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_dict["token_id"] token_id = user_dict.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.server_name, self.hs.config.server_name,
id="1234", id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
) )
self.hs.get_datastore().services_cache.append(appservice) self.hs.get_datastore().services_cache.append(appservice)

View File

@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1]) self.store.get_user_by_access_token(self.tokens[1])
) )
self.assertDictContainsSubset( self.assertEqual(result.user_id, self.user_id)
{"name": self.user_id, "device_id": self.device_id}, result self.assertEqual(result.device_id, self.device_id)
) self.assertIsNotNone(result.token_id)
self.assertTrue("token_id" in result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred( user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0]) self.store.get_user_by_access_token(self.tokens[0])
) )
self.assertEqual(self.user_id, user["name"]) self.assertEqual(self.user_id, user.user_id)
# now delete the rest # now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))