Compare commits

...

16 Commits

Author SHA1 Message Date
Erik Johnston 4c38c71966 Merge remote-tracking branch 'origin/develop' into erikj/make_room_admin 2020-11-17 13:43:46 +00:00
Erik Johnston f737368a26
Add admin API for logging in as a user (#8617) 2020-11-17 10:51:25 +00:00
Richard van der Hoff 3dc1871219
Merge pull request #8757 from matrix-org/rav/pass_site_to_make_request
Pass a Site into `make_request`
2020-11-16 18:22:24 +00:00
Richard van der Hoff f125895475
Move `wait_until_result` into `FakeChannel` (#8758)
FakeChannel has everything we need, and this more accurately models the real
flow.
2020-11-16 18:21:47 +00:00
Richard van der Hoff c3e3552ec4 fixup test 2020-11-16 15:51:47 +00:00
Andrew Morgan 4f76eef0e8
Generalise _locally_reject_invite (#8751)
`_locally_reject_invite` generates an out-of-band membership event which can be passed to clients, but not other homeservers.

This is used when we fail to reject an invite over federation. If this happens, we instead just generate a leave event locally and send it down /sync, allowing clients to reject invites even if we can't reach the remote homeserver.

A similar flow needs to be put in place for rescinding knocks. If we're unable to contact any remote server from the room we've tried to knock on, we'd still like to generate and store the leave event locally. Hence the need to reuse, and thus generalise, this method.

Separated from #6739.
2020-11-16 15:37:36 +00:00
Richard van der Hoff bebfb9a97b
Merge branch 'develop' into rav/pass_site_to_make_request 2020-11-16 15:22:40 +00:00
Richard van der Hoff 791d7cd6f0
Rename `create_test_json_resource` to `create_test_resource` (#8759)
The root resource isn't necessarily a JsonResource, so rename this method
accordingly, and update a couple of test classes to use the method rather than
directly manipulating self.resource.
2020-11-16 14:45:52 +00:00
Richard van der Hoff ebc405446e
Add a `custom_headers` param to `make_request` (#8760)
Some tests want to set some custom HTTP request headers, so provide a way to do
that before calling requestReceived().
2020-11-16 14:45:22 +00:00
Richard van der Hoff 0d33c53534 changelog 2020-11-15 23:09:03 +00:00
Richard van der Hoff cfd895a22e use global make_request() directly where we have a custom Resource
Where we want to render a request against a specific Resource, call the global
make_request() function rather than the one in HomeserverTestCase, allowing us
to pass in an appropriate `Site`.
2020-11-15 23:09:03 +00:00
Richard van der Hoff 70c0d47989 fix dict handling for make_request() 2020-11-15 23:09:03 +00:00
Richard van der Hoff 9debe657a3 pass a Site into make_request 2020-11-15 23:09:03 +00:00
Richard van der Hoff d3523e3e97 pass a Site into RestHelper 2020-11-15 23:09:03 +00:00
Adrian Wannenmacher f1de4bb58b
Clarify the usecase for an msisdn delegate (#8734)
Signed-off-by: Adrian Wannenmacher <tfld@tfld.dev>
2020-11-14 23:09:36 +00:00
Andrew Morgan e8d0853739
Generalise _maybe_store_room_on_invite (#8754)
There's a handy function called maybe_store_room_on_invite which allows us to create an entry in the rooms table for a room and its version for which we aren't joined to yet, but we can reference when ingesting events about.

This is currently used for invites where we receive some stripped state about the room and pass it down via /sync to the client, without us being in the room yet.

There is a similar requirement for knocking, where we will eventually do the same thing, and need an entry in the rooms table as well. Thus, reusing this function works, however its name needs to be generalised a bit.

Separated out from #6739.
2020-11-13 16:24:04 +00:00
57 changed files with 812 additions and 268 deletions

1
changelog.d/8617.feature Normal file
View File

@ -0,0 +1 @@
Add admin API for logging in as a user.

1
changelog.d/8734.doc Normal file
View File

@ -0,0 +1 @@
Clarify the usecase for an msisdn delegate. Contributed by Adrian Wannenmacher.

1
changelog.d/8751.misc Normal file
View File

@ -0,0 +1 @@
Generalise `RoomMemberHandler._locally_reject_invite` to apply to more flows than just invite.

1
changelog.d/8754.misc Normal file
View File

@ -0,0 +1 @@
Generalise `RoomStore.maybe_store_room_on_invite` to handle other, non-invite membership events.

1
changelog.d/8757.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8758.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8759.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8760.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

View File

@ -424,6 +424,41 @@ The following fields are returned in the JSON response body:
- ``next_token``: integer - Indication for pagination. See above. - ``next_token``: integer - Indication for pagination. See above.
- ``total`` - integer - Total number of media. - ``total`` - integer - Total number of media.
Login as a user
===============
Get an access token that can be used to authenticate as that user. Useful for
when admins wish to do actions on behalf of a user.
The API is::
POST /_synapse/admin/v1/users/<user_id>/login
{}
An optional ``valid_until_ms`` field can be specified in the request body as an
integer timestamp that specifies when the token should expire. By default tokens
do not expire.
A response body like the following is returned:
.. code:: json
{
"access_token": "<opaque_access_token_string>"
}
This API does *not* generate a new device for the user, and so will not appear
their ``/devices`` list, and in general the target user should not be able to
tell they have been logged in as.
To expire the token call the standard ``/logout`` API with the token.
Note: The token will expire if the *admin* user calls ``/logout/all`` from any
of their devices, but the token will *not* expire if the target user does the
same.
User devices User devices
============ ============

View File

@ -1230,8 +1230,9 @@ account_validity:
# email will be globally disabled. # email will be globally disabled.
# #
# Additionally, if `msisdn` is not set, registration and password resets via msisdn # Additionally, if `msisdn` is not set, registration and password resets via msisdn
# will be disabled regardless. This is due to Synapse currently not supporting any # will be disabled regardless, and users will not be able to associate an msisdn
# method of sending SMS messages on its own. # identifier to their account. This is due to Synapse currently not supporting
# any method of sending SMS messages on its own.
# #
# To enable using an identity server for operations regarding a particular third-party # To enable using an identity server for operations regarding a particular third-party
# identifier type, set the value to the URL of that identity server as shown in the # identifier type, set the value to the URL of that identity server as shown in the

View File

@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.types import Requester
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,24 +35,47 @@ class AuthBlocking:
self._max_mau_value = hs.config.max_mau_value self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname
async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
):
"""Checks if the user should be rejected for some external reason, """Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag such as monthly active user limiting or global disable flag
Args: Args:
user_id(str|None): If present, checks for presence against existing user_id: If present, checks for presence against existing
MAU cohort MAU cohort
threepid(dict|None): If present, checks for presence against configured threepid: If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and threepid is on the reserved list. user_id and
threepid should never be set at the same time. threepid should never be set at the same time.
user_type(str|None): If present, is used to decide whether to check against user_type: If present, is used to decide whether to check against
certain blocking reasons like MAU. certain blocking reasons like MAU.
requester: If present, and the authenticated entity is a user, checks for
presence against existing MAU cohort. Passing in both a `user_id` and
`requester` is an error.
""" """
if requester and user_id:
raise Exception(
"Passed in both 'user_id' and 'requester' to 'check_auth_blocking'"
)
if requester:
if requester.authenticated_entity.startswith("@"):
user_id = requester.authenticated_entity
elif requester.authenticated_entity == self._server_name:
# We never block the server from doing actions on behalf of
# users.
return
# Never fail an auth check for the server notices users or support user # Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking # This can be a problem where event creation is prohibited due to blocking

View File

@ -347,8 +347,9 @@ class RegistrationConfig(Config):
# email will be globally disabled. # email will be globally disabled.
# #
# Additionally, if `msisdn` is not set, registration and password resets via msisdn # Additionally, if `msisdn` is not set, registration and password resets via msisdn
# will be disabled regardless. This is due to Synapse currently not supporting any # will be disabled regardless, and users will not be able to associate an msisdn
# method of sending SMS messages on its own. # identifier to their account. This is due to Synapse currently not supporting
# any method of sending SMS messages on its own.
# #
# To enable using an identity server for operations regarding a particular third-party # To enable using an identity server for operations regarding a particular third-party
# identifier type, set the value to the URL of that identity server as shown in the # identifier type, set the value to the URL of that identity server as shown in the

View File

@ -169,7 +169,9 @@ class BaseHandler:
# and having homeservers have their own users leave keeps more # and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having # of that decision-making and control local to the guest-having
# homeserver. # homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True) requester = synapse.types.create_requester(
target_user, is_guest=True, authenticated_entity=self.server_name
)
handler = self.hs.get_room_member_handler() handler = self.hs.get_room_member_handler()
await handler.update_membership( await handler.update_membership(
requester, requester,

View File

@ -698,8 +698,12 @@ class AuthHandler(BaseHandler):
} }
async def get_access_token_for_user_id( async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] self,
): user_id: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
) -> str:
""" """
Creates a new access token for the user with the given user ID. Creates a new access token for the user with the given user ID.
@ -725,13 +729,25 @@ class AuthHandler(BaseHandler):
fmt_expiry = time.strftime( fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0) " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
) )
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
if puppets_user_id:
logger.info(
"Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
)
else:
logger.info(
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
)
await self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id) access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user( await self.store.add_access_token_to_user(
user_id, access_token, device_id, valid_until_ms user_id=user_id,
token=access_token,
device_id=device_id,
valid_until_ms=valid_until_ms,
puppets_user_id=puppets_user_id,
) )
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,

View File

@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler):
self._room_member_handler = hs.get_room_member_handler() self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler() self._identity_handler = hs.get_identity_handler()
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
# Flag that indicates whether the process to part users from rooms is running # Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False self._user_parter_running = False
@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
for room in pending_invites: for room in pending_invites:
try: try:
await self._room_member_handler.update_membership( await self._room_member_handler.update_membership(
create_requester(user), create_requester(user, authenticated_entity=self._server_name),
user, user,
room.room_id, room.room_id,
"leave", "leave",
@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("User parter parting %r from %r", user_id, room_id) logger.info("User parter parting %r from %r", user_id, room_id)
try: try:
await self._room_member_handler.update_membership( await self._room_member_handler.update_membership(
create_requester(user), create_requester(user, authenticated_entity=self._server_name),
user, user,
room_id, room_id,
"leave", "leave",

View File

@ -67,7 +67,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer
from synapse.replication.http.federation import ( from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet, ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet, ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnInviteRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet,
) )
from synapse.state import StateResolutionStore from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@ -152,12 +152,14 @@ class FederationHandler(BaseHandler):
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs hs
) )
self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client( self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
hs hs
) )
else: else:
self._device_list_updater = hs.get_device_handler().device_list_updater self._device_list_updater = hs.get_device_handler().device_list_updater
self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite self._maybe_store_room_on_outlier_membership = (
self.store.maybe_store_room_on_outlier_membership
)
# When joining a room we need to queue any events for that room up. # When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples. # For each room, a list of (pdu, origin) tuples.
@ -1617,7 +1619,7 @@ class FederationHandler(BaseHandler):
# keep a record of the room version, if we don't yet know it. # keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a # (this may get overwritten if we later get a different room version in a
# join dance). # join dance).
await self._maybe_store_room_on_invite( await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=room_version room_id=event.room_id, room_version=room_version
) )

View File

@ -472,7 +472,7 @@ class EventCreationHandler:
Returns: Returns:
Tuple of created event, Context Tuple of created event, Context
""" """
await self.auth.check_auth_blocking(requester.user.to_string()) await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"] room_version = event_dict["content"]["room_version"]
@ -619,7 +619,13 @@ class EventCreationHandler:
if requester.app_service is not None: if requester.app_service is not None:
return return
user_id = requester.user.to_string() user_id = requester.authenticated_entity
if not user_id.startswith("@"):
# The authenticated entity might not be a user, e.g. if it's the
# server puppetting the user.
return
user = UserID.from_string(user_id)
# exempt the system notices user # exempt the system notices user
if ( if (
@ -639,9 +645,7 @@ class EventCreationHandler:
if u["consent_version"] == self.config.user_consent_version: if u["consent_version"] == self.config.user_consent_version:
return return
consent_uri = self._consent_uri_builder.build_user_consent_uri( consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
requester.user.localpart
)
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@ -1252,7 +1256,7 @@ class EventCreationHandler:
for user_id in members: for user_id in members:
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
continue continue
requester = create_requester(user_id) requester = create_requester(user_id, authenticated_entity=self.server_name)
try: try:
event, context = await self.create_event( event, context = await self.create_event(
requester, requester,
@ -1273,11 +1277,6 @@ class EventCreationHandler:
requester, event, context, ratelimit=False, ignore_shadow_ban=True, requester, event, context, ratelimit=False, ignore_shadow_ban=True,
) )
return True return True
except ConsentNotGivenError:
logger.info(
"Failed to send dummy event into room %s for user %s due to "
"lack of consent. Will try another user" % (room_id, user_id)
)
except AuthError: except AuthError:
logger.info( logger.info(
"Failed to send dummy event into room %s for user %s due to " "Failed to send dummy event into room %s for user %s due to "

View File

@ -206,7 +206,9 @@ class ProfileHandler(BaseHandler):
# the join event to update the displayname in the rooms. # the join event to update the displayname in the rooms.
# This must be done by the target user himself. # This must be done by the target user himself.
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity,
)
await self.store.set_profile_displayname( await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set target_user.localpart, displayname_to_set
@ -286,7 +288,9 @@ class ProfileHandler(BaseHandler):
# Same like set_displayname # Same like set_displayname
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)

View File

@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter = hs.get_registration_ratelimiter() self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
@ -317,7 +318,8 @@ class RegistrationHandler(BaseHandler):
requires_join = False requires_join = False
if self.hs.config.registration.auto_join_user_id: if self.hs.config.registration.auto_join_user_id:
fake_requester = create_requester( fake_requester = create_requester(
self.hs.config.registration.auto_join_user_id self.hs.config.registration.auto_join_user_id,
authenticated_entity=self._server_name,
) )
# If the room requires an invite, add the user to the list of invites. # If the room requires an invite, add the user to the list of invites.
@ -329,7 +331,9 @@ class RegistrationHandler(BaseHandler):
# being necessary this will occur after the invite was sent. # being necessary this will occur after the invite was sent.
requires_join = True requires_join = True
else: else:
fake_requester = create_requester(user_id) fake_requester = create_requester(
user_id, authenticated_entity=self._server_name
)
# Choose whether to federate the new room. # Choose whether to federate the new room.
if not self.hs.config.registration.autocreate_auto_join_rooms_federated: if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@ -362,7 +366,9 @@ class RegistrationHandler(BaseHandler):
# created it, then ensure the first user joins it. # created it, then ensure the first user joins it.
if requires_join: if requires_join:
await room_member_handler.update_membership( await room_member_handler.update_membership(
requester=create_requester(user_id), requester=create_requester(
user_id, authenticated_entity=self._server_name
),
target=UserID.from_string(user_id), target=UserID.from_string(user_id),
room_id=info["room_id"], room_id=info["room_id"],
# Since it was just created, there are no remote hosts. # Since it was just created, there are no remote hosts.
@ -370,11 +376,6 @@ class RegistrationHandler(BaseHandler):
action="join", action="join",
ratelimit=False, ratelimit=False,
) )
except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do.
logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e: except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e) logger.error("Failed to join new user to %r: %r", r, e)
@ -426,7 +427,8 @@ class RegistrationHandler(BaseHandler):
if requires_invite: if requires_invite:
await room_member_handler.update_membership( await room_member_handler.update_membership(
requester=create_requester( requester=create_requester(
self.hs.config.registration.auto_join_user_id self.hs.config.registration.auto_join_user_id,
authenticated_entity=self._server_name,
), ),
target=UserID.from_string(user_id), target=UserID.from_string(user_id),
room_id=room_id, room_id=room_id,
@ -437,7 +439,9 @@ class RegistrationHandler(BaseHandler):
# Send the join. # Send the join.
await room_member_handler.update_membership( await room_member_handler.update_membership(
requester=create_requester(user_id), requester=create_requester(
user_id, authenticated_entity=self._server_name
),
target=UserID.from_string(user_id), target=UserID.from_string(user_id),
room_id=room_id, room_id=room_id,
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,

View File

@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
await self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(requester=requester)
if ( if (
self._server_notices_mxid is not None self._server_notices_mxid is not None
@ -1257,7 +1257,9 @@ class RoomShutdownHandler:
400, "User must be our own: %s" % (new_room_user_id,) 400, "User must be our own: %s" % (new_room_user_id,)
) )
room_creator_requester = create_requester(new_room_user_id) room_creator_requester = create_requester(
new_room_user_id, authenticated_entity=requester_user_id
)
info, stream_id = await self._room_creation_handler.create_room( info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester, room_creator_requester,
@ -1297,7 +1299,9 @@ class RoomShutdownHandler:
try: try:
# Kick users from room # Kick users from room
target_requester = create_requester(user_id) target_requester = create_requester(
user_id, authenticated_entity=requester_user_id
)
_, stream_id = await self.room_member_handler.update_membership( _, stream_id = await self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,

View File

@ -965,6 +965,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("user_left_room") self.distributor.declare("user_left_room")
self._server_name = hs.hostname
async def _is_remote_room_too_complex( async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str] self, room_id: str, remote_room_hosts: List[str]
@ -1059,7 +1060,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id return event_id, stream_id
# The room is too large. Leave. # The room is too large. Leave.
requester = types.create_requester(user, None, False, False, None) requester = types.create_requester(
user, authenticated_entity=self._server_name
)
await self.update_membership( await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave" requester=requester, target=user, room_id=room_id, action="leave"
) )
@ -1104,32 +1107,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# #
logger.warning("Failed to reject invite: %s", e) logger.warning("Failed to reject invite: %s", e)
return await self._locally_reject_invite( return await self._generate_local_out_of_band_leave(
invite_event, txn_id, requester, content invite_event, txn_id, requester, content
) )
async def _locally_reject_invite( async def _generate_local_out_of_band_leave(
self, self,
invite_event: EventBase, previous_membership_event: EventBase,
txn_id: Optional[str], txn_id: Optional[str],
requester: Requester, requester: Requester,
content: JsonDict, content: JsonDict,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
"""Generate a local invite rejection """Generate a local leave event for a room
This is called after we fail to reject an invite via a remote server. It This can be called after we e.g fail to reject an invite via a remote server.
generates an out-of-band membership event locally. It generates an out-of-band membership event locally.
Args: Args:
invite_event: the invite to be rejected previous_membership_event: the previous membership event for this user
txn_id: optional transaction ID supplied by the client txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token requester: user making the request, according to the access token
content: additional content to include in the rejection event. content: additional content to include in the leave event.
Normally an empty dict. Normally an empty dict.
"""
room_id = invite_event.room_id Returns:
target_user = invite_event.state_key A tuple containing (event_id, stream_id of the leave event)
"""
room_id = previous_membership_event.room_id
target_user = previous_membership_event.state_key
content["membership"] = Membership.LEAVE content["membership"] = Membership.LEAVE
@ -1141,12 +1146,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user, "state_key": target_user,
} }
# the auth events for the new event are the same as that of the invite, plus # the auth events for the new event are the same as that of the previous event, plus
# the invite itself. # the event itself.
# #
# the prev_events are just the invite. # the prev_events consist solely of the previous membership event.
prev_event_ids = [invite_event.event_id] prev_event_ids = [previous_membership_event.event_id]
auth_event_ids = invite_event.auth_event_ids() + prev_event_ids auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event( event, context = await self.event_creation_handler.create_event(
requester, requester,

View File

@ -31,6 +31,7 @@ from synapse.types import (
Collection, Collection,
JsonDict, JsonDict,
MutableStateMap, MutableStateMap,
Requester,
RoomStreamToken, RoomStreamToken,
StateMap, StateMap,
StreamToken, StreamToken,
@ -260,6 +261,7 @@ class SyncHandler:
async def wait_for_sync_for_user( async def wait_for_sync_for_user(
self, self,
requester: Requester,
sync_config: SyncConfig, sync_config: SyncConfig,
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
timeout: int = 0, timeout: int = 0,
@ -273,7 +275,7 @@ class SyncHandler:
# not been exceeded (if not part of the group by this point, almost certain # not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur) # auth_blocking will occur)
user_id = sync_config.user.to_string() user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap( res = await self.response_cache.wrap(
sync_config.request_key, sync_config.request_key,

View File

@ -49,6 +49,7 @@ class ModuleApi:
self._store = hs.get_datastore() self._store = hs.get_datastore()
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._auth_handler = auth_handler self._auth_handler = auth_handler
self._server_name = hs.hostname
# We expose these as properties below in order to attach a helpful docstring. # We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
@ -336,7 +337,9 @@ class ModuleApi:
SynapseError if the event was not allowed. SynapseError if the event was not allowed.
""" """
# Create a requester object # Create a requester object
requester = create_requester(event_dict["sender"]) requester = create_requester(
event_dict["sender"], authenticated_entity=self._server_name
)
# Create and send the event # Create and send the event
( (

View File

@ -254,20 +254,20 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return 200, {} return 200, {}
class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
"""Called to clean up any data in DB for a given room, ready for the """Called to clean up any data in DB for a given room, ready for the
server to join the room. server to join the room.
Request format: Request format:
POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id POST /_synapse/replication/store_room_on_outlier_membership/:room_id/:txn_id
{ {
"room_version": "1", "room_version": "1",
} }
""" """
NAME = "store_room_on_invite" NAME = "store_room_on_outlier_membership"
PATH_ARGS = ("room_id",) PATH_ARGS = ("room_id",)
def __init__(self, hs): def __init__(self, hs):
@ -282,7 +282,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
async def _handle_request(self, request, room_id): async def _handle_request(self, request, room_id):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_invite(room_id, room_version) await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {} return 200, {}
@ -291,4 +291,4 @@ def register_servlets(hs, http_server):
ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server) ReplicationCleanRoomRestServlet(hs).register(http_server)
ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server) ReplicationStoreRoomOnOutlierMembershipRestServlet(hs).register(http_server)

View File

@ -62,6 +62,7 @@ from synapse.rest.admin.users import (
UserRestServletV2, UserRestServletV2,
UsersRestServlet, UsersRestServlet,
UsersRestServletV2, UsersRestServletV2,
UserTokenRestServlet,
WhoisRestServlet, WhoisRestServlet,
) )
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
@ -224,6 +225,7 @@ def register_servlets(hs, http_server):
UserAdminServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server)
UserMediaRestServlet(hs).register(http_server) UserMediaRestServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server)
UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server) UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server)

View File

@ -312,7 +312,9 @@ class JoinRoomAliasServlet(RestServlet):
400, "%s was not legal room ID or room alias" % (room_identifier,) 400, "%s was not legal room ID or room alias" % (room_identifier,)
) )
fake_requester = create_requester(target_user) fake_requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
# send invite if room has "JoinRules.INVITE" # send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id) room_state = await self.state_handler.get_current_state(room_id)

View File

@ -16,7 +16,7 @@ import hashlib
import hmac import hmac
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
@ -37,6 +37,9 @@ from synapse.rest.admin._base import (
) )
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_GET_PUSHERS_ALLOWED_KEYS = { _GET_PUSHERS_ALLOWED_KEYS = {
@ -828,3 +831,52 @@ class UserMediaRestServlet(RestServlet):
ret["next_token"] = start + len(media) ret["next_token"] = start + len(media)
return 200, ret return 200, ret
class UserTokenRestServlet(RestServlet):
"""An admin API for logging in as a user.
Example:
POST /_synapse/admin/v1/users/@test:example.com/login
{}
200 OK
{
"access_token": "<some_token>"
}
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request, user_id):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Only local users can be logged in as")
body = parse_json_object_from_request(request, allow_empty_body=True)
valid_until_ms = body.get("valid_until_ms")
if valid_until_ms and not isinstance(valid_until_ms, int):
raise SynapseError(400, "'valid_until_ms' parameter must be an int")
if auth_user.to_string() == user_id:
raise SynapseError(400, "Cannot use admin API to login as self")
token = await self.auth_handler.get_access_token_for_user_id(
user_id=auth_user.to_string(),
device_id=None,
valid_until_ms=valid_until_ms,
puppets_user_id=user_id,
)
return 200, {"access_token": token}

View File

@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
) )
with context: with context:
sync_result = await self.sync_handler.wait_for_sync_for_user( sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config, sync_config,
since_token=since_token, since_token=since_token,
timeout=timeout, timeout=timeout,

View File

@ -39,6 +39,7 @@ class ServerNoticesManager:
self._room_member_handler = hs.get_room_member_handler() self._room_member_handler = hs.get_room_member_handler()
self._event_creation_handler = hs.get_event_creation_handler() self._event_creation_handler = hs.get_event_creation_handler()
self._is_mine_id = hs.is_mine_id self._is_mine_id = hs.is_mine_id
self._server_name = hs.hostname
self._notifier = hs.get_notifier() self._notifier = hs.get_notifier()
self.server_notices_mxid = self._config.server_notices_mxid self.server_notices_mxid = self._config.server_notices_mxid
@ -72,7 +73,9 @@ class ServerNoticesManager:
await self.maybe_invite_user_to_room(user_id, room_id) await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid) requester = create_requester(
system_mxid, authenticated_entity=self._server_name
)
logger.info("Sending server notice to %s", user_id) logger.info("Sending server notice to %s", user_id)
@ -145,7 +148,9 @@ class ServerNoticesManager:
"avatar_url": self._config.server_notices_mxid_avatar_url, "avatar_url": self._config.server_notices_mxid_avatar_url,
} }
requester = create_requester(self.server_notices_mxid) requester = create_requester(
self.server_notices_mxid, authenticated_entity=self._server_name
)
info, _ = await self._room_creation_handler.create_room( info, _ = await self._room_creation_handler.create_room(
requester, requester,
config={ config={
@ -174,7 +179,9 @@ class ServerNoticesManager:
user_id: The ID of the user to invite. user_id: The ID of the user to invite.
room_id: The ID of the room to invite the user to. room_id: The ID of the room to invite the user to.
""" """
requester = create_requester(self.server_notices_mxid) requester = create_requester(
self.server_notices_mxid, authenticated_entity=self._server_name
)
# Check whether the user has already joined or been invited to this room. If # Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them. # that's the case, there is no need to re-invite them.

View File

@ -1110,6 +1110,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
token: str, token: str,
device_id: Optional[str], device_id: Optional[str],
valid_until_ms: Optional[int], valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
) -> int: ) -> int:
"""Adds an access token for the given user. """Adds an access token for the given user.
@ -1133,6 +1134,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"token": token, "token": token,
"device_id": device_id, "device_id": device_id,
"valid_until_ms": valid_until_ms, "valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
}, },
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )

View File

@ -1240,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
):
""" """
When we receive an invite over federation, store the version of the room if we When we receive an invite or any other event over federation that may relate to a room
don't already know the room version. we are not in, store the version of the room if we don't already know the room version.
""" """
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
desc="maybe_store_room_on_invite", desc="maybe_store_room_on_outlier_membership",
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
values={}, values={},

View File

@ -282,7 +282,11 @@ class AuthTestCase(unittest.TestCase):
) )
) )
self.store.add_access_token_to_user.assert_called_with( self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None user_id=USER_ID,
token=token,
device_id="DEVICE",
valid_until_ms=None,
puppets_user_id=None,
) )
def get_user(tok): def get_user(tok):

View File

@ -15,6 +15,7 @@
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import GenericWorkerServer
from tests.server import make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -55,10 +56,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"] resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status") request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
self.render(request) render(request, resource, self.reactor)
# 400 + unrecognised, because nothing is registered # 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
@ -77,10 +78,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"] resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status") request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
self.render(request) render(request, resource, self.reactor)
# 401, because the stub servlet still checks authentication # 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)

View File

@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def from synapse.config.server import parse_listener_def
from tests.server import make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -66,16 +67,16 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
try: try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"] resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError: except KeyError:
if expectation == "no_resource": if expectation == "no_resource":
return return
raise raise
request, channel = self.make_request( request, channel = make_request(
"GET", "/_matrix/federation/v1/openid/userinfo" self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
) )
self.render(request) render(request, resource, self.reactor)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)
@ -115,15 +116,15 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
try: try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"] resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError: except KeyError:
if expectation == "no_resource": if expectation == "no_resource":
return return
raise raise
request, channel = self.make_request( request, channel = make_request(
"GET", "/_matrix/federation/v1/openid/userinfo" self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
) )
self.render(request) render(request, resource, self.reactor)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)

View File

@ -16,7 +16,7 @@
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import UserID from synapse.types import UserID, create_requester
import tests.unittest import tests.unittest
import tests.utils import tests.utils
@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
user_id1 = "@user1:test" user_id1 = "@user1:test"
user_id2 = "@user2:test" user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1) sync_config = self._generate_sync_config(user_id1)
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time self.reactor.advance(100) # So we get not 0 time
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors # Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1)) self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) self.get_success(
self.sync_handler.wait_for_sync_for_user(requester, sync_config)
)
# Test that global lock works # Test that global lock works
self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled = True
e = self.get_failure( e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError self.sync_handler.wait_for_sync_for_user(requester, sync_config),
ResourceLimitError,
) )
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.auth_blocking._hs_disabled = False self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2) sync_config = self._generate_sync_config(user_id2)
requester = create_requester(user_id2)
e = self.get_failure( e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError self.sync_handler.wait_for_sync_for_user(requester, sync_config),
ResourceLimitError,
) )
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

View File

@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json
from tests.server import FakeSite, make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -43,20 +44,20 @@ class AdditionalResourceTests(HomeserverTestCase):
def test_async(self): def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/") request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.render(request) render(request, resource, self.reactor)
self.assertEqual(request.code, 200) self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self): def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/") request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.render(request) render(request, resource, self.reactor)
self.assertEqual(request.code, 200) self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@ -94,12 +94,13 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertFalse(hasattr(event, "state_key")) self.assertFalse(hasattr(event, "state_key"))
self.assertDictEqual(event.content, content) self.assertDictEqual(event.content, content)
expected_requester = create_requester(
user_id, authenticated_entity=self.hs.hostname
)
# Check that the event was sent # Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
create_requester(user_id), expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
event_dict,
ratelimit=False,
ignore_shadow_ban=True,
) )
# Create and send a state event # Create and send a state event
@ -128,7 +129,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the event was sent # Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
create_requester(user_id), expected_requester,
{ {
"type": "m.room.power_levels", "type": "m.room.power_levels",
"content": content, "content": content,

View File

@ -240,8 +240,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765), lambda: self._handle_http_replication_attempt(self.hs, 8765),
) )
def create_test_json_resource(self): def create_test_resource(self):
"""Overrides `HomeserverTestCase.create_test_json_resource`. """Overrides `HomeserverTestCase.create_test_resource`.
""" """
# We override this so that it automatically registers all the HTTP # We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all # replication servlets, without having to explicitly do that in all

View File

@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,8 +46,11 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test that registration works when using a single client reader worker. """Test that registration works when using a single client reader worker.
""" """
worker_hs = self.make_worker_hs("synapse.app.client_reader") worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
request_1, channel_1 = self.make_request( request_1, channel_1 = make_request(
self.reactor,
site,
"POST", "POST",
"register", "register",
{"username": "user", "type": "m.login.password", "password": "bar"}, {"username": "user", "type": "m.login.password", "password": "bar"},
@ -59,8 +62,12 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"] session = channel_1.json_body["session"]
# also complete the dummy auth # also complete the dummy auth
request_2, channel_2 = self.make_request( request_2, channel_2 = make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs, request_2) self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200) self.assertEqual(request_2.code, 200)
@ -74,7 +81,10 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request( site_1 = self._hs_to_site[worker_hs_1]
request_1, channel_1 = make_request(
self.reactor,
site_1,
"POST", "POST",
"register", "register",
{"username": "user", "type": "m.login.password", "password": "bar"}, {"username": "user", "type": "m.login.password", "password": "bar"},
@ -86,8 +96,13 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"] session = channel_1.json_body["session"]
# also complete the dummy auth # also complete the dummy auth
request_2, channel_2 = self.make_request( site_2 = self._hs_to_site[worker_hs_2]
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} request_2, channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs_2, request_2) self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200) self.assertEqual(request_2.code, 200)

View File

@ -28,7 +28,7 @@ from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -67,14 +67,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for The channel for the *client* request and the *outbound* request for
the media which the caller should respond to. the media which the caller should respond to.
""" """
resource = hs.get_media_repository_resource().children[b"download"]
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(resource),
"GET", "GET",
"/{}/{}".format(target, media_id), "/{}/{}".format(target, media_id),
shorthand=False, shorthand=False,
access_token=self.access_token, access_token=self.access_token,
) )
request.render(hs.get_media_repository_resource().children[b"download"]) request.render(resource)
self.pump() self.pump()
clients = self.reactor.tcpClients clients = self.reactor.tcpClients

View File

@ -22,6 +22,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.utils import USE_POSTGRES_FOR_TESTS from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -148,6 +149,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync_hs = self.make_worker_hs( sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"}, "synapse.app.generic_worker", {"worker_name": "sync"},
) )
sync_hs_site = self._hs_to_site[sync_hs]
# Specially selected room IDs that get persisted on different workers. # Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test" room_id1 = "!foo:test"
@ -178,7 +180,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
) )
# Do an initial sync so that we're up to date. # Do an initial sync so that we're up to date.
request, channel = self.make_request("GET", "/sync", access_token=access_token) request, channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"] next_batch = channel.json_body["next_batch"]
@ -203,8 +207,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the # Check that syncing still gets the new event, despite the gap in the
# stream IDs. # stream IDs.
request, channel = self.make_request( request, channel = make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
) )
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
@ -230,7 +238,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"] first_event_in_room2 = response["event_id"]
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/sync?since={}".format(vector_clock_token), "/sync?since={}".format(vector_clock_token),
access_token=access_token, access_token=access_token,
@ -254,8 +264,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
request, channel = self.make_request( request, channel = make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
) )
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
@ -269,7 +283,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as # Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly # no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion. # filtering results based on the vector clock portion.
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format( "/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token room_id1, prev_batch1, vector_clock_token
@ -281,7 +297,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back on the second room should produce the first event # Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken. # again. This tests that pagination isn't completely broken.
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format( "/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token room_id2, prev_batch2, vector_clock_token
@ -295,7 +313,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
) )
# Paginating forwards should give the same results # Paginating forwards should give the same results
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format( "/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1 room_id1, vector_clock_token, prev_batch1
@ -305,7 +325,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"]) self.assertListEqual([], channel.json_body["chunk"])
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format( "/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2, room_id2, vector_clock_token, prev_batch2,

View File

@ -30,12 +30,13 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups from synapse.rest.client.v2_alpha import groups
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class VersionTestCase(unittest.HomeserverTestCase): class VersionTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/server_version" url = "/_synapse/admin/v1/server_version"
def create_test_json_resource(self): def create_test_resource(self):
resource = JsonResource(self.hs) resource = JsonResource(self.hs)
VersionServlet(self.hs).register(resource) VersionServlet(self.hs).register(resource)
return resource return resource
@ -222,8 +223,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
def _ensure_quarantined(self, admin_user_tok, server_and_media_id): def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it.""" """Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request( request, channel = make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok, self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=admin_user_tok,
) )
request.render(self.download_resource) request.render(self.download_resource)
self.pump(1.0) self.pump(1.0)
@ -287,7 +293,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/") server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media # Attempt to access the media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET", "GET",
server_name_and_media_id, server_name_and_media_id,
shorthand=False, shorthand=False,
@ -462,7 +470,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1) self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media # Attempt to access each piece of media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET", "GET",
server_and_media_id_2, server_and_media_id_2,
shorthand=False, shorthand=False,

View File

@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@ -124,7 +125,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name) self.assertEqual(server_name, self.server_name)
# Attempt to access media # Attempt to access media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -161,7 +164,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
) )
# Attempt to access media # Attempt to access media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -535,7 +540,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1] media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id) local_path = self.filepaths.local_media_filepath(media_id)
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,

View File

@ -24,8 +24,8 @@ from mock import Mock
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.rest.client.v1 import login, profile, room from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import devices, sync
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -1638,3 +1638,244 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertIn("last_access_ts", m) self.assertIn("last_access_ts", m)
self.assertIn("quarantined_by", m) self.assertIn("quarantined_by", m)
self.assertIn("safe_from_quarantine", m) self.assertIn("safe_from_quarantine", m)
class UserTokenRestTestCase(unittest.HomeserverTestCase):
"""Test for /_synapse/admin/v1/users/<user>/login
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
room.register_servlets,
devices.register_servlets,
logout.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.other_user_tok = self.login("user", "pass")
self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote(
self.other_user
)
def _get_token(self) -> str:
request, channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["access_token"]
def test_no_auth(self):
"""Try to login as a user without authentication.
"""
request, channel = self.make_request("POST", self.url, b"{}")
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self):
"""Try to login as a user as a non-admin user.
"""
request, channel = self.make_request(
"POST", self.url, b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
def test_send_event(self):
"""Test that sending event as a user works.
"""
# Create a room.
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
# Login in as the user
puppet_token = self._get_token()
# Test that sending works, and generates the event as the right user.
resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
event_id = resp["event_id"]
event = self.get_success(self.store.get_event(event_id))
self.assertEqual(event.sender, self.other_user)
def test_devices(self):
"""Tests that logging in as a user doesn't create a new device for them.
"""
# Login in as the user
self._get_token()
# Check that we don't see a new device in our devices list
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
def test_logout(self):
"""Test that calling `/logout` with the token works.
"""
# Login in as the user
puppet_token = self._get_token()
# Test that we can successfully make a request
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout with the puppet token
request, channel = self.make_request(
"POST", "logout", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should no longer work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens should still work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def test_user_logout_all(self):
"""Tests that the target user calling `/logout/all` does *not* expire
the token.
"""
# Login in as the user
puppet_token = self._get_token()
# Test that we can successfully make a request
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout all with the real user token
request, channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should still work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens shouldn't
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
def test_admin_logout_all(self):
"""Tests that the admin user calling `/logout/all` does expire the
token.
"""
# Login in as the user
puppet_token = self._get_token()
# Test that we can successfully make a request
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout all with the admin user token
request, channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should no longer work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens should still work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@unittest.override_config(
{
"public_baseurl": "https://example.org/",
"user_consent": {
"version": "1.0",
"policy_name": "My Cool Privacy Policy",
"template_dir": "/",
"require_at_registration": True,
"block_events_error": "You should accept the policy",
},
"form_secret": "123secret",
}
)
def test_consent(self):
"""Test that sending a message is not subject to the privacy policies.
"""
# Have the admin user accept the terms.
self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
# First, cheekily accept the terms and create a room
self.get_success(self.store.user_set_consent_version(self.other_user, "1.0"))
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok)
# Now unaccept it and check that we can't send an event
self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
self.helper.send_event(
room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
)
# Login in as the user
puppet_token = self._get_token()
# Sending an event on their behalf should work fine
self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0}
)
def test_mau_limit(self):
# Create a room as the admin user. This will bump the monthly active users to 1.
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
# Trying to join as the other user should fail due to reaching MAU limit.
self.helper.join(
room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
)
# Logging in as the other user and joining a room should work, even
# though the MAU limit would stop the user doing so.
puppet_token = self._get_token()
self.helper.join(room_id, user=self.other_user, tok=puppet_token)

View File

@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.consent import consent_resource from synapse.rest.consent import consent_resource
from tests import unittest from tests import unittest
from tests.server import render from tests.server import FakeSite, make_request, render
class ConsentResourceTestCase(unittest.HomeserverTestCase): class ConsentResourceTestCase(unittest.HomeserverTestCase):
@ -61,7 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self): def test_render_public_consent(self):
"""You can observe the terms form without specifying a user""" """You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs) resource = consent_resource.ConsentResource(self.hs)
request, channel = self.make_request("GET", "/consent?v=1", shorthand=False) request, channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
)
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -81,8 +83,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "") uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user" + "&u=user"
) )
request, channel = self.make_request( request, channel = make_request(
"GET", consent_uri, access_token=access_token, shorthand=False self.reactor,
FakeSite(resource),
"GET",
consent_uri,
access_token=access_token,
shorthand=False,
) )
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -92,7 +99,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False") self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed # POST to the consent page, saying we've agreed
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(resource),
"POST", "POST",
consent_uri + "&v=" + version, consent_uri + "&v=" + version,
access_token=access_token, access_token=access_token,
@ -103,8 +112,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Fetch the consent page, to get the consent version -- it should have # Fetch the consent page, to get the consent version -- it should have
# changed # changed
request, channel = self.make_request( request, channel = make_request(
"GET", consent_uri, access_token=access_token, shorthand=False self.reactor,
FakeSite(resource),
"GET",
consent_uri,
access_token=access_token,
shorthand=False,
) )
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)

View File

@ -23,10 +23,11 @@ from typing import Any, Dict, Optional
import attr import attr
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership
from tests.server import make_request, render from tests.server import FakeSite, make_request, render
@attr.s @attr.s
@ -36,7 +37,7 @@ class RestHelper:
""" """
hs = attr.ib() hs = attr.ib()
resource = attr.ib() site = attr.ib(type=Site)
auth_user_id = attr.ib() auth_user_id = attr.ib()
def create_room_as( def create_room_as(
@ -52,9 +53,13 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") self.hs.get_reactor(),
self.site,
"POST",
path,
json.dumps(content).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert channel.result["code"] == b"%d" % expect_code, channel.result assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id self.auth_user_id = temp_id
@ -125,10 +130,14 @@ class RestHelper:
data.update(extra_data) data.update(extra_data)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(data).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -158,9 +167,13 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(content).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -210,9 +223,11 @@ class RestHelper:
if body is not None: if body is not None:
content = json.dumps(body).encode("utf8") content = json.dumps(body).encode("utf8")
request, channel = make_request(self.hs.get_reactor(), method, path, content) request, channel = make_request(
self.hs.get_reactor(), self.site, method, path, content
)
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -296,10 +311,13 @@ class RestHelper:
image_length = len(image_data) image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok self.hs.get_reactor(),
) FakeSite(resource),
request.requestHeaders.addRawHeader( "POST",
b"Content-Length", str(image_length).encode("UTF-8") path,
content=image_data,
access_token=tok,
custom_headers=[(b"Content-Length", str(image_length))],
) )
request.render(resource) request.render(resource)
self.hs.get_reactor().pump([100]) self.hs.get_reactor().pump([100])

View File

@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
from tests.unittest import override_config from tests.unittest import override_config
@ -255,9 +256,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "") path = link.replace("https://example.com", "")
# Load the password reset confirmation page # Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False) request, channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"GET",
path,
shorthand=False,
)
request.render(self.submit_token_resource) request.render(self.submit_token_resource)
self.pump() self.pump()
self.assertEquals(200, channel.code, channel.result) self.assertEquals(200, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the # Now POST to the same endpoint, mimicking the same behaviour as clicking the
@ -271,7 +279,9 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
form_args.append(arg) form_args.append(arg)
# Confirm the password reset # Confirm the password reset
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"POST", "POST",
path, path,
content=urlencode(form_args).encode("utf8"), content=urlencode(form_args).encode("utf8"),

View File

@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.server import FakeChannel, wait_until_result from tests.server import FakeChannel
from tests.utils import default_config from tests.utils import default_config
@ -41,7 +41,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
self.http_client = Mock() self.http_client = Mock()
return self.setup_test_homeserver(http_client=self.http_client) return self.setup_test_homeserver(http_client=self.http_client)
def create_test_json_resource(self): def create_test_resource(self):
return create_resource_tree( return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
) )
@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
% (server_name.encode("utf-8"), key_id.encode("utf-8")), % (server_name.encode("utf-8"), key_id.encode("utf-8")),
b"1.1", b"1.1",
) )
wait_until_result(self.reactor, req) channel.await_result()
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
resp = channel.json_body resp = channel.json_body
return resp return resp
@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
req.requestReceived( req.requestReceived(
b"POST", path.encode("utf-8"), b"1.1", b"POST", path.encode("utf-8"), b"1.1",
) )
wait_until_result(self.reactor, req) channel.await_result()
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
resp = channel.json_body resp = channel.json_body
return resp return resp

View File

@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class MediaStorageTests(unittest.HomeserverTestCase): class MediaStorageTests(unittest.HomeserverTestCase):
@ -227,7 +228,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _req(self, content_disposition): def _req(self, content_disposition):
request, channel = self.make_request("GET", self.media_id, shorthand=False) request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
self.media_id,
shorthand=False,
)
request.render(self.download_resource) request.render(self.download_resource)
self.pump() self.pump()
@ -317,8 +324,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _test_thumbnail(self, method, expected_body, expected_found): def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
request, channel = self.make_request( request, channel = make_request(
"GET", self.media_id + params, shorthand=False self.reactor,
FakeSite(self.thumbnail_resource),
"GET",
self.media_id + params,
shorthand=False,
) )
request.render(self.thumbnail_resource) request.render(self.thumbnail_resource)
self.pump() self.pump()

View File

@ -20,11 +20,9 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase): class HealthCheckTests(unittest.HomeserverTestCase):
def setUp(self): def create_test_resource(self):
super().setUp()
# replace the JsonResource with a HealthResource. # replace the JsonResource with a HealthResource.
self.resource = HealthResource() return HealthResource()
def test_health(self): def test_health(self):
request, channel = self.make_request("GET", "/health", shorthand=False) request, channel = self.make_request("GET", "/health", shorthand=False)

View File

@ -20,11 +20,9 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase): class WellKnownTests(unittest.HomeserverTestCase):
def setUp(self): def create_test_resource(self):
super().setUp()
# replace the JsonResource with a WellKnownResource # replace the JsonResource with a WellKnownResource
self.resource = WellKnownResource(self.hs) return WellKnownResource(self.hs)
def test_well_known(self): def test_well_known(self):
self.hs.config.public_baseurl = "https://tesths" self.hs.config.public_baseurl = "https://tesths"

View File

@ -2,7 +2,7 @@ import json
import logging import logging
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import Callable from typing import Callable, Iterable, Optional, Tuple, Union
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque
@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote from twisted.web.http import unquote
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Site from twisted.web.server import Site
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -117,6 +118,25 @@ class FakeChannel:
def transport(self): def transport(self):
return self return self
def await_result(self, timeout: int = 100) -> None:
"""
Wait until the request is finished.
"""
self._reactor.run()
x = 0
while not self.result.get("done"):
# If there's a producer, tell it to resume producing so we get content
if self._producer:
self._producer.resumeProducing()
x += 1
if x > timeout:
raise TimedOutException("Timed out waiting for request to finish.")
self._reactor.advance(0.1)
class FakeSite: class FakeSite:
""" """
@ -128,9 +148,21 @@ class FakeSite:
site_tag = "test" site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake") access_logger = logging.getLogger("synapse.access.http.fake")
def __init__(self, resource: IResource):
"""
Args:
resource: the resource to be used for rendering all requests
"""
self._resource = resource
def getResourceFor(self, request):
return self._resource
def make_request( def make_request(
reactor, reactor,
site: Site,
method, method,
path, path,
content=b"", content=b"",
@ -139,12 +171,17 @@ def make_request(
shorthand=True, shorthand=True,
federation_auth_origin=None, federation_auth_origin=None,
content_is_form=False, content_is_form=False,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
): ):
""" """
Make a web request using the given method and path, feed it the Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
Args: Args:
site: The twisted Site to associate with the Channel
method (bytes/unicode): The HTTP request method ("verb"). method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such). escaped UTF-8 & spaces and such).
@ -157,6 +194,8 @@ def make_request(
content_is_form: Whether the content is URL encoded form data. Adds the content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header. 'Content-Type': 'application/x-www-form-urlencoded' header.
custom_headers: (name, value) pairs to add as request headers
Returns: Returns:
Tuple[synapse.http.site.SynapseRequest, channel] Tuple[synapse.http.site.SynapseRequest, channel]
""" """
@ -178,10 +217,11 @@ def make_request(
if not path.startswith(b"/"): if not path.startswith(b"/"):
path = b"/" + path path = b"/" + path
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
if isinstance(content, str): if isinstance(content, str):
content = content.encode("utf8") content = content.encode("utf8")
site = FakeSite()
channel = FakeChannel(site, reactor) channel = FakeChannel(site, reactor)
req = request(channel) req = request(channel)
@ -211,35 +251,18 @@ def make_request(
# Assume the body is JSON # Assume the body is JSON
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
if custom_headers:
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
req.requestReceived(method, path, b"1.1") req.requestReceived(method, path, b"1.1")
return req, channel return req, channel
def wait_until_result(clock, request, timeout=100):
"""
Wait until the request is finished.
"""
clock.run()
x = 0
while not request.finished:
# If there's a producer, tell it to resume producing so we get content
if request._channel._producer:
request._channel._producer.resumeProducing()
x += 1
if x > timeout:
raise TimedOutException("Timed out waiting for request to finish.")
clock.advance(0.1)
def render(request, resource, clock): def render(request, resource, clock):
request.render(resource) request.render(resource)
wait_until_result(clock, request) request._channel.await_result()
@implementer(IReactorPluggableNameResolver) @implementer(IReactorPluggableNameResolver)

View File

@ -309,36 +309,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
) )
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
def test_send_dummy_event_without_consent(self):
self._create_extremity_rich_graph()
self._enable_consent_checking()
# Pump the reactor repeatedly so that the background updates have a
# chance to run. Attempt to add dummy event with user that has not consented
# Check that dummy event send fails.
self.pump(10 * 60)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
# Create new user, and add consent
user2 = self.register_user("user2", "password")
token2 = self.login("user2", "password")
self.get_success(
self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
)
self.helper.join(self.room_id, user2, tok=token2)
# Background updates should now cause a dummy event to be added to the graph
self.pump(10 * 60)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
def test_expiry_logic(self): def test_expiry_logic(self):
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()

View File

@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login
from tests import unittest from tests import unittest
from tests.server import make_request
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
@ -408,17 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
# Advance to a known time # Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds()) self.reactor.advance(123456 - self.reactor.seconds())
request, channel = self.make_request( headers1 = {b"User-Agent": b"Mozzila pizza"}
headers1.update(headers)
request, channel = make_request(
self.reactor,
self.site,
"GET", "GET",
"/_matrix/client/r0/admin/users/" + self.user_id, "/_matrix/client/r0/admin/users/" + self.user_id,
access_token=access_token, access_token=access_token,
custom_headers=headers1.items(),
**make_request_args, **make_request_args,
) )
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
# Add the optional headers
for h, v in headers.items():
request.requestHeaders.addRawHeader(h, v)
self.render(request) self.render(request)
# Advance so the save loop occurs # Advance so the save loop occurs

View File

@ -26,6 +26,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import ( from tests.server import (
FakeSite,
ThreadedMemoryReactorClock, ThreadedMemoryReactorClock,
make_request, make_request,
render, render,
@ -62,7 +63,7 @@ class JsonResourceTests(unittest.TestCase):
) )
request, channel = make_request( request, channel = make_request(
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
) )
render(request, res, self.reactor) render(request, res, self.reactor)
@ -83,7 +84,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -108,7 +111,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -127,7 +132,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
@ -150,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
@ -173,7 +182,9 @@ class JsonResourceTests(unittest.TestCase):
) )
# The path was registered as GET, but this is a HEAD request. # The path was registered as GET, but this is a HEAD request.
request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
@ -196,9 +207,6 @@ class OptionsResourceTests(unittest.TestCase):
def _make_request(self, method, path): def _make_request(self, method, path):
"""Create a request from the method/path and return a channel with the response.""" """Create a request from the method/path and return a channel with the response."""
request, channel = make_request(self.reactor, method, path, shorthand=False)
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource. # Create a site and query for the resource.
site = SynapseSite( site = SynapseSite(
"test", "test",
@ -207,6 +215,12 @@ class OptionsResourceTests(unittest.TestCase):
self.resource, self.resource,
"1.0", "1.0",
) )
request, channel = make_request(
self.reactor, site, method, path, shorthand=False
)
request.prepath = [] # This doesn't get set properly by make_request.
request.site = site request.site = site
resource = site.getResourceFor(request) resource = site.getResourceFor(request)
@ -284,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
@ -303,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"301") self.assertEqual(channel.result["code"], b"301")
@ -325,7 +339,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"304") self.assertEqual(channel.result["code"], b"304")
@ -345,7 +359,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"HEAD", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")

View File

@ -169,6 +169,7 @@ class StateTestCase(unittest.TestCase):
"get_state_handler", "get_state_handler",
"get_clock", "get_clock",
"get_state_resolution_handler", "get_state_resolution_handler",
"hostname",
] ]
) )
hs.config = default_config("tesths", True) hs.config = default_config("tesths", True)

View File

@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest from twisted.trial import unittest
from twisted.web.resource import Resource
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -239,10 +240,8 @@ class HomeserverTestCase(TestCase):
if not isinstance(self.hs, HomeServer): if not isinstance(self.hs, HomeServer):
raise Exception("A homeserver wasn't returned, but %r" % (self.hs,)) raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
# Register the resources # create the root resource, and a site to wrap it.
self.resource = self.create_test_json_resource() self.resource = self.create_test_resource()
# create a site to wrap the resource.
self.site = SynapseSite( self.site = SynapseSite(
logger_name="synapse.access.http.fake", logger_name="synapse.access.http.fake",
site_tag=self.hs.config.server.server_name, site_tag=self.hs.config.server.server_name,
@ -253,7 +252,7 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.v1.utils import RestHelper from tests.rest.client.v1.utils import RestHelper
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
if hasattr(self, "user_id"): if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth:
@ -323,15 +322,12 @@ class HomeserverTestCase(TestCase):
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
return hs return hs
def create_test_json_resource(self): def create_test_resource(self) -> Resource:
""" """
Create a test JsonResource, with the relevant servlets registerd to it Create a the root resource for the test server.
The default implementation calls each function in `servlets` to do the The default implementation creates a JsonResource and calls each function in
registration. `servlets` to register servletes against it
Returns:
JsonResource:
""" """
resource = JsonResource(self.hs) resource = JsonResource(self.hs)
@ -429,11 +425,9 @@ class HomeserverTestCase(TestCase):
Returns: Returns:
Tuple[synapse.http.site.SynapseRequest, channel] Tuple[synapse.http.site.SynapseRequest, channel]
""" """
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
return make_request( return make_request(
self.reactor, self.reactor,
self.site,
method, method,
path, path,
content, content,