Compare commits

...

12 Commits

Author SHA1 Message Date
Richard van der Hoff 3dc1871219
Merge pull request #8757 from matrix-org/rav/pass_site_to_make_request
Pass a Site into `make_request`
2020-11-16 18:22:24 +00:00
Richard van der Hoff f125895475
Move `wait_until_result` into `FakeChannel` (#8758)
FakeChannel has everything we need, and this more accurately models the real
flow.
2020-11-16 18:21:47 +00:00
Richard van der Hoff c3e3552ec4 fixup test 2020-11-16 15:51:47 +00:00
Andrew Morgan 4f76eef0e8
Generalise _locally_reject_invite (#8751)
`_locally_reject_invite` generates an out-of-band membership event which can be passed to clients, but not other homeservers.

This is used when we fail to reject an invite over federation. If this happens, we instead just generate a leave event locally and send it down /sync, allowing clients to reject invites even if we can't reach the remote homeserver.

A similar flow needs to be put in place for rescinding knocks. If we're unable to contact any remote server from the room we've tried to knock on, we'd still like to generate and store the leave event locally. Hence the need to reuse, and thus generalise, this method.

Separated from #6739.
2020-11-16 15:37:36 +00:00
Richard van der Hoff bebfb9a97b
Merge branch 'develop' into rav/pass_site_to_make_request 2020-11-16 15:22:40 +00:00
Richard van der Hoff 791d7cd6f0
Rename `create_test_json_resource` to `create_test_resource` (#8759)
The root resource isn't necessarily a JsonResource, so rename this method
accordingly, and update a couple of test classes to use the method rather than
directly manipulating self.resource.
2020-11-16 14:45:52 +00:00
Richard van der Hoff ebc405446e
Add a `custom_headers` param to `make_request` (#8760)
Some tests want to set some custom HTTP request headers, so provide a way to do
that before calling requestReceived().
2020-11-16 14:45:22 +00:00
Richard van der Hoff 0d33c53534 changelog 2020-11-15 23:09:03 +00:00
Richard van der Hoff cfd895a22e use global make_request() directly where we have a custom Resource
Where we want to render a request against a specific Resource, call the global
make_request() function rather than the one in HomeserverTestCase, allowing us
to pass in an appropriate `Site`.
2020-11-15 23:09:03 +00:00
Richard van der Hoff 70c0d47989 fix dict handling for make_request() 2020-11-15 23:09:03 +00:00
Richard van der Hoff 9debe657a3 pass a Site into make_request 2020-11-15 23:09:03 +00:00
Richard van der Hoff d3523e3e97 pass a Site into RestHelper 2020-11-15 23:09:03 +00:00
26 changed files with 312 additions and 164 deletions

1
changelog.d/8751.misc Normal file
View File

@ -0,0 +1 @@
Generalise `RoomMemberHandler._locally_reject_invite` to apply to more flows than just invite.

1
changelog.d/8757.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8758.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8759.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8760.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

View File

@ -1104,32 +1104,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# #
logger.warning("Failed to reject invite: %s", e) logger.warning("Failed to reject invite: %s", e)
return await self._locally_reject_invite( return await self._generate_local_out_of_band_leave(
invite_event, txn_id, requester, content invite_event, txn_id, requester, content
) )
async def _locally_reject_invite( async def _generate_local_out_of_band_leave(
self, self,
invite_event: EventBase, previous_membership_event: EventBase,
txn_id: Optional[str], txn_id: Optional[str],
requester: Requester, requester: Requester,
content: JsonDict, content: JsonDict,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
"""Generate a local invite rejection """Generate a local leave event for a room
This is called after we fail to reject an invite via a remote server. It This can be called after we e.g fail to reject an invite via a remote server.
generates an out-of-band membership event locally. It generates an out-of-band membership event locally.
Args: Args:
invite_event: the invite to be rejected previous_membership_event: the previous membership event for this user
txn_id: optional transaction ID supplied by the client txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token requester: user making the request, according to the access token
content: additional content to include in the rejection event. content: additional content to include in the leave event.
Normally an empty dict. Normally an empty dict.
"""
room_id = invite_event.room_id Returns:
target_user = invite_event.state_key A tuple containing (event_id, stream_id of the leave event)
"""
room_id = previous_membership_event.room_id
target_user = previous_membership_event.state_key
content["membership"] = Membership.LEAVE content["membership"] = Membership.LEAVE
@ -1141,12 +1143,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user, "state_key": target_user,
} }
# the auth events for the new event are the same as that of the invite, plus # the auth events for the new event are the same as that of the previous event, plus
# the invite itself. # the event itself.
# #
# the prev_events are just the invite. # the prev_events consist solely of the previous membership event.
prev_event_ids = [invite_event.event_id] prev_event_ids = [previous_membership_event.event_id]
auth_event_ids = invite_event.auth_event_ids() + prev_event_ids auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event( event, context = await self.event_creation_handler.create_event(
requester, requester,

View File

@ -15,6 +15,7 @@
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import GenericWorkerServer
from tests.server import make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -55,10 +56,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"] resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status") request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
self.render(request) render(request, resource, self.reactor)
# 400 + unrecognised, because nothing is registered # 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
@ -77,10 +78,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"] resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status") request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
self.render(request) render(request, resource, self.reactor)
# 401, because the stub servlet still checks authentication # 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)

View File

@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def from synapse.config.server import parse_listener_def
from tests.server import make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -66,16 +67,16 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
try: try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"] resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError: except KeyError:
if expectation == "no_resource": if expectation == "no_resource":
return return
raise raise
request, channel = self.make_request( request, channel = make_request(
"GET", "/_matrix/federation/v1/openid/userinfo" self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
) )
self.render(request) render(request, resource, self.reactor)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)
@ -115,15 +116,15 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
try: try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"] resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError: except KeyError:
if expectation == "no_resource": if expectation == "no_resource":
return return
raise raise
request, channel = self.make_request( request, channel = make_request(
"GET", "/_matrix/federation/v1/openid/userinfo" self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
) )
self.render(request) render(request, resource, self.reactor)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)

View File

@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json
from tests.server import FakeSite, make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -43,20 +44,20 @@ class AdditionalResourceTests(HomeserverTestCase):
def test_async(self): def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/") request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.render(request) render(request, resource, self.reactor)
self.assertEqual(request.code, 200) self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self): def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/") request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.render(request) render(request, resource, self.reactor)
self.assertEqual(request.code, 200) self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@ -240,8 +240,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765), lambda: self._handle_http_replication_attempt(self.hs, 8765),
) )
def create_test_json_resource(self): def create_test_resource(self):
"""Overrides `HomeserverTestCase.create_test_json_resource`. """Overrides `HomeserverTestCase.create_test_resource`.
""" """
# We override this so that it automatically registers all the HTTP # We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all # replication servlets, without having to explicitly do that in all

View File

@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,8 +46,11 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test that registration works when using a single client reader worker. """Test that registration works when using a single client reader worker.
""" """
worker_hs = self.make_worker_hs("synapse.app.client_reader") worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
request_1, channel_1 = self.make_request( request_1, channel_1 = make_request(
self.reactor,
site,
"POST", "POST",
"register", "register",
{"username": "user", "type": "m.login.password", "password": "bar"}, {"username": "user", "type": "m.login.password", "password": "bar"},
@ -59,8 +62,12 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"] session = channel_1.json_body["session"]
# also complete the dummy auth # also complete the dummy auth
request_2, channel_2 = self.make_request( request_2, channel_2 = make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs, request_2) self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200) self.assertEqual(request_2.code, 200)
@ -74,7 +81,10 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request( site_1 = self._hs_to_site[worker_hs_1]
request_1, channel_1 = make_request(
self.reactor,
site_1,
"POST", "POST",
"register", "register",
{"username": "user", "type": "m.login.password", "password": "bar"}, {"username": "user", "type": "m.login.password", "password": "bar"},
@ -86,8 +96,13 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"] session = channel_1.json_body["session"]
# also complete the dummy auth # also complete the dummy auth
request_2, channel_2 = self.make_request( site_2 = self._hs_to_site[worker_hs_2]
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} request_2, channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs_2, request_2) self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200) self.assertEqual(request_2.code, 200)

View File

@ -28,7 +28,7 @@ from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -67,14 +67,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for The channel for the *client* request and the *outbound* request for
the media which the caller should respond to. the media which the caller should respond to.
""" """
resource = hs.get_media_repository_resource().children[b"download"]
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(resource),
"GET", "GET",
"/{}/{}".format(target, media_id), "/{}/{}".format(target, media_id),
shorthand=False, shorthand=False,
access_token=self.access_token, access_token=self.access_token,
) )
request.render(hs.get_media_repository_resource().children[b"download"]) request.render(resource)
self.pump() self.pump()
clients = self.reactor.tcpClients clients = self.reactor.tcpClients

View File

@ -22,6 +22,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.utils import USE_POSTGRES_FOR_TESTS from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -148,6 +149,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync_hs = self.make_worker_hs( sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"}, "synapse.app.generic_worker", {"worker_name": "sync"},
) )
sync_hs_site = self._hs_to_site[sync_hs]
# Specially selected room IDs that get persisted on different workers. # Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test" room_id1 = "!foo:test"
@ -178,7 +180,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
) )
# Do an initial sync so that we're up to date. # Do an initial sync so that we're up to date.
request, channel = self.make_request("GET", "/sync", access_token=access_token) request, channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"] next_batch = channel.json_body["next_batch"]
@ -203,8 +207,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the # Check that syncing still gets the new event, despite the gap in the
# stream IDs. # stream IDs.
request, channel = self.make_request( request, channel = make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
) )
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
@ -230,7 +238,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"] first_event_in_room2 = response["event_id"]
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/sync?since={}".format(vector_clock_token), "/sync?since={}".format(vector_clock_token),
access_token=access_token, access_token=access_token,
@ -254,8 +264,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
request, channel = self.make_request( request, channel = make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
) )
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
@ -269,7 +283,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as # Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly # no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion. # filtering results based on the vector clock portion.
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format( "/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token room_id1, prev_batch1, vector_clock_token
@ -281,7 +297,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back on the second room should produce the first event # Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken. # again. This tests that pagination isn't completely broken.
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format( "/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token room_id2, prev_batch2, vector_clock_token
@ -295,7 +313,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
) )
# Paginating forwards should give the same results # Paginating forwards should give the same results
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format( "/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1 room_id1, vector_clock_token, prev_batch1
@ -305,7 +325,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"]) self.assertListEqual([], channel.json_body["chunk"])
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format( "/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2, room_id2, vector_clock_token, prev_batch2,

View File

@ -30,12 +30,13 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups from synapse.rest.client.v2_alpha import groups
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class VersionTestCase(unittest.HomeserverTestCase): class VersionTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/server_version" url = "/_synapse/admin/v1/server_version"
def create_test_json_resource(self): def create_test_resource(self):
resource = JsonResource(self.hs) resource = JsonResource(self.hs)
VersionServlet(self.hs).register(resource) VersionServlet(self.hs).register(resource)
return resource return resource
@ -222,8 +223,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
def _ensure_quarantined(self, admin_user_tok, server_and_media_id): def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it.""" """Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request( request, channel = make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok, self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=admin_user_tok,
) )
request.render(self.download_resource) request.render(self.download_resource)
self.pump(1.0) self.pump(1.0)
@ -287,7 +293,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/") server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media # Attempt to access the media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET", "GET",
server_name_and_media_id, server_name_and_media_id,
shorthand=False, shorthand=False,
@ -462,7 +470,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1) self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media # Attempt to access each piece of media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET", "GET",
server_and_media_id_2, server_and_media_id_2,
shorthand=False, shorthand=False,

View File

@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@ -124,7 +125,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name) self.assertEqual(server_name, self.server_name)
# Attempt to access media # Attempt to access media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -161,7 +164,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
) )
# Attempt to access media # Attempt to access media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -535,7 +540,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1] media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id) local_path = self.filepaths.local_media_filepath(media_id)
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,

View File

@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.consent import consent_resource from synapse.rest.consent import consent_resource
from tests import unittest from tests import unittest
from tests.server import render from tests.server import FakeSite, make_request, render
class ConsentResourceTestCase(unittest.HomeserverTestCase): class ConsentResourceTestCase(unittest.HomeserverTestCase):
@ -61,7 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self): def test_render_public_consent(self):
"""You can observe the terms form without specifying a user""" """You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs) resource = consent_resource.ConsentResource(self.hs)
request, channel = self.make_request("GET", "/consent?v=1", shorthand=False) request, channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
)
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -81,8 +83,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "") uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user" + "&u=user"
) )
request, channel = self.make_request( request, channel = make_request(
"GET", consent_uri, access_token=access_token, shorthand=False self.reactor,
FakeSite(resource),
"GET",
consent_uri,
access_token=access_token,
shorthand=False,
) )
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -92,7 +99,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False") self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed # POST to the consent page, saying we've agreed
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(resource),
"POST", "POST",
consent_uri + "&v=" + version, consent_uri + "&v=" + version,
access_token=access_token, access_token=access_token,
@ -103,8 +112,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Fetch the consent page, to get the consent version -- it should have # Fetch the consent page, to get the consent version -- it should have
# changed # changed
request, channel = self.make_request( request, channel = make_request(
"GET", consent_uri, access_token=access_token, shorthand=False self.reactor,
FakeSite(resource),
"GET",
consent_uri,
access_token=access_token,
shorthand=False,
) )
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)

View File

@ -23,10 +23,11 @@ from typing import Any, Dict, Optional
import attr import attr
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership
from tests.server import make_request, render from tests.server import FakeSite, make_request, render
@attr.s @attr.s
@ -36,7 +37,7 @@ class RestHelper:
""" """
hs = attr.ib() hs = attr.ib()
resource = attr.ib() site = attr.ib(type=Site)
auth_user_id = attr.ib() auth_user_id = attr.ib()
def create_room_as( def create_room_as(
@ -52,9 +53,13 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") self.hs.get_reactor(),
self.site,
"POST",
path,
json.dumps(content).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert channel.result["code"] == b"%d" % expect_code, channel.result assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id self.auth_user_id = temp_id
@ -125,10 +130,14 @@ class RestHelper:
data.update(extra_data) data.update(extra_data)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(data).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -158,9 +167,13 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(content).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -210,9 +223,11 @@ class RestHelper:
if body is not None: if body is not None:
content = json.dumps(body).encode("utf8") content = json.dumps(body).encode("utf8")
request, channel = make_request(self.hs.get_reactor(), method, path, content) request, channel = make_request(
self.hs.get_reactor(), self.site, method, path, content
)
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -296,10 +311,13 @@ class RestHelper:
image_length = len(image_data) image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok self.hs.get_reactor(),
) FakeSite(resource),
request.requestHeaders.addRawHeader( "POST",
b"Content-Length", str(image_length).encode("UTF-8") path,
content=image_data,
access_token=tok,
custom_headers=[(b"Content-Length", str(image_length))],
) )
request.render(resource) request.render(resource)
self.hs.get_reactor().pump([100]) self.hs.get_reactor().pump([100])

View File

@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
from tests.unittest import override_config from tests.unittest import override_config
@ -255,9 +256,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "") path = link.replace("https://example.com", "")
# Load the password reset confirmation page # Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False) request, channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"GET",
path,
shorthand=False,
)
request.render(self.submit_token_resource) request.render(self.submit_token_resource)
self.pump() self.pump()
self.assertEquals(200, channel.code, channel.result) self.assertEquals(200, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the # Now POST to the same endpoint, mimicking the same behaviour as clicking the
@ -271,7 +279,9 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
form_args.append(arg) form_args.append(arg)
# Confirm the password reset # Confirm the password reset
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"POST", "POST",
path, path,
content=urlencode(form_args).encode("utf8"), content=urlencode(form_args).encode("utf8"),

View File

@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.server import FakeChannel, wait_until_result from tests.server import FakeChannel
from tests.utils import default_config from tests.utils import default_config
@ -41,7 +41,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
self.http_client = Mock() self.http_client = Mock()
return self.setup_test_homeserver(http_client=self.http_client) return self.setup_test_homeserver(http_client=self.http_client)
def create_test_json_resource(self): def create_test_resource(self):
return create_resource_tree( return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
) )
@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
% (server_name.encode("utf-8"), key_id.encode("utf-8")), % (server_name.encode("utf-8"), key_id.encode("utf-8")),
b"1.1", b"1.1",
) )
wait_until_result(self.reactor, req) channel.await_result()
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
resp = channel.json_body resp = channel.json_body
return resp return resp
@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
req.requestReceived( req.requestReceived(
b"POST", path.encode("utf-8"), b"1.1", b"POST", path.encode("utf-8"), b"1.1",
) )
wait_until_result(self.reactor, req) channel.await_result()
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
resp = channel.json_body resp = channel.json_body
return resp return resp

View File

@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class MediaStorageTests(unittest.HomeserverTestCase): class MediaStorageTests(unittest.HomeserverTestCase):
@ -227,7 +228,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _req(self, content_disposition): def _req(self, content_disposition):
request, channel = self.make_request("GET", self.media_id, shorthand=False) request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
self.media_id,
shorthand=False,
)
request.render(self.download_resource) request.render(self.download_resource)
self.pump() self.pump()
@ -317,8 +324,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _test_thumbnail(self, method, expected_body, expected_found): def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
request, channel = self.make_request( request, channel = make_request(
"GET", self.media_id + params, shorthand=False self.reactor,
FakeSite(self.thumbnail_resource),
"GET",
self.media_id + params,
shorthand=False,
) )
request.render(self.thumbnail_resource) request.render(self.thumbnail_resource)
self.pump() self.pump()

View File

@ -20,11 +20,9 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase): class HealthCheckTests(unittest.HomeserverTestCase):
def setUp(self): def create_test_resource(self):
super().setUp()
# replace the JsonResource with a HealthResource. # replace the JsonResource with a HealthResource.
self.resource = HealthResource() return HealthResource()
def test_health(self): def test_health(self):
request, channel = self.make_request("GET", "/health", shorthand=False) request, channel = self.make_request("GET", "/health", shorthand=False)

View File

@ -20,11 +20,9 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase): class WellKnownTests(unittest.HomeserverTestCase):
def setUp(self): def create_test_resource(self):
super().setUp()
# replace the JsonResource with a WellKnownResource # replace the JsonResource with a WellKnownResource
self.resource = WellKnownResource(self.hs) return WellKnownResource(self.hs)
def test_well_known(self): def test_well_known(self):
self.hs.config.public_baseurl = "https://tesths" self.hs.config.public_baseurl = "https://tesths"

View File

@ -2,7 +2,7 @@ import json
import logging import logging
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import Callable from typing import Callable, Iterable, Optional, Tuple, Union
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque
@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote from twisted.web.http import unquote
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Site from twisted.web.server import Site
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -117,6 +118,25 @@ class FakeChannel:
def transport(self): def transport(self):
return self return self
def await_result(self, timeout: int = 100) -> None:
"""
Wait until the request is finished.
"""
self._reactor.run()
x = 0
while not self.result.get("done"):
# If there's a producer, tell it to resume producing so we get content
if self._producer:
self._producer.resumeProducing()
x += 1
if x > timeout:
raise TimedOutException("Timed out waiting for request to finish.")
self._reactor.advance(0.1)
class FakeSite: class FakeSite:
""" """
@ -128,9 +148,21 @@ class FakeSite:
site_tag = "test" site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake") access_logger = logging.getLogger("synapse.access.http.fake")
def __init__(self, resource: IResource):
"""
Args:
resource: the resource to be used for rendering all requests
"""
self._resource = resource
def getResourceFor(self, request):
return self._resource
def make_request( def make_request(
reactor, reactor,
site: Site,
method, method,
path, path,
content=b"", content=b"",
@ -139,12 +171,17 @@ def make_request(
shorthand=True, shorthand=True,
federation_auth_origin=None, federation_auth_origin=None,
content_is_form=False, content_is_form=False,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
): ):
""" """
Make a web request using the given method and path, feed it the Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
Args: Args:
site: The twisted Site to associate with the Channel
method (bytes/unicode): The HTTP request method ("verb"). method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such). escaped UTF-8 & spaces and such).
@ -157,6 +194,8 @@ def make_request(
content_is_form: Whether the content is URL encoded form data. Adds the content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header. 'Content-Type': 'application/x-www-form-urlencoded' header.
custom_headers: (name, value) pairs to add as request headers
Returns: Returns:
Tuple[synapse.http.site.SynapseRequest, channel] Tuple[synapse.http.site.SynapseRequest, channel]
""" """
@ -178,10 +217,11 @@ def make_request(
if not path.startswith(b"/"): if not path.startswith(b"/"):
path = b"/" + path path = b"/" + path
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
if isinstance(content, str): if isinstance(content, str):
content = content.encode("utf8") content = content.encode("utf8")
site = FakeSite()
channel = FakeChannel(site, reactor) channel = FakeChannel(site, reactor)
req = request(channel) req = request(channel)
@ -211,35 +251,18 @@ def make_request(
# Assume the body is JSON # Assume the body is JSON
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
if custom_headers:
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
req.requestReceived(method, path, b"1.1") req.requestReceived(method, path, b"1.1")
return req, channel return req, channel
def wait_until_result(clock, request, timeout=100):
"""
Wait until the request is finished.
"""
clock.run()
x = 0
while not request.finished:
# If there's a producer, tell it to resume producing so we get content
if request._channel._producer:
request._channel._producer.resumeProducing()
x += 1
if x > timeout:
raise TimedOutException("Timed out waiting for request to finish.")
clock.advance(0.1)
def render(request, resource, clock): def render(request, resource, clock):
request.render(resource) request.render(resource)
wait_until_result(clock, request) request._channel.await_result()
@implementer(IReactorPluggableNameResolver) @implementer(IReactorPluggableNameResolver)

View File

@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login
from tests import unittest from tests import unittest
from tests.server import make_request
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
@ -408,17 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
# Advance to a known time # Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds()) self.reactor.advance(123456 - self.reactor.seconds())
request, channel = self.make_request( headers1 = {b"User-Agent": b"Mozzila pizza"}
headers1.update(headers)
request, channel = make_request(
self.reactor,
self.site,
"GET", "GET",
"/_matrix/client/r0/admin/users/" + self.user_id, "/_matrix/client/r0/admin/users/" + self.user_id,
access_token=access_token, access_token=access_token,
custom_headers=headers1.items(),
**make_request_args, **make_request_args,
) )
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
# Add the optional headers
for h, v in headers.items():
request.requestHeaders.addRawHeader(h, v)
self.render(request) self.render(request)
# Advance so the save loop occurs # Advance so the save loop occurs

View File

@ -26,6 +26,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import ( from tests.server import (
FakeSite,
ThreadedMemoryReactorClock, ThreadedMemoryReactorClock,
make_request, make_request,
render, render,
@ -62,7 +63,7 @@ class JsonResourceTests(unittest.TestCase):
) )
request, channel = make_request( request, channel = make_request(
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
) )
render(request, res, self.reactor) render(request, res, self.reactor)
@ -83,7 +84,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -108,7 +111,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -127,7 +132,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
@ -150,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
@ -173,7 +182,9 @@ class JsonResourceTests(unittest.TestCase):
) )
# The path was registered as GET, but this is a HEAD request. # The path was registered as GET, but this is a HEAD request.
request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
@ -196,9 +207,6 @@ class OptionsResourceTests(unittest.TestCase):
def _make_request(self, method, path): def _make_request(self, method, path):
"""Create a request from the method/path and return a channel with the response.""" """Create a request from the method/path and return a channel with the response."""
request, channel = make_request(self.reactor, method, path, shorthand=False)
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource. # Create a site and query for the resource.
site = SynapseSite( site = SynapseSite(
"test", "test",
@ -207,6 +215,12 @@ class OptionsResourceTests(unittest.TestCase):
self.resource, self.resource,
"1.0", "1.0",
) )
request, channel = make_request(
self.reactor, site, method, path, shorthand=False
)
request.prepath = [] # This doesn't get set properly by make_request.
request.site = site request.site = site
resource = site.getResourceFor(request) resource = site.getResourceFor(request)
@ -284,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
@ -303,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"301") self.assertEqual(channel.result["code"], b"301")
@ -325,7 +339,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"304") self.assertEqual(channel.result["code"], b"304")
@ -345,7 +359,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"HEAD", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")

View File

@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest from twisted.trial import unittest
from twisted.web.resource import Resource
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -239,10 +240,8 @@ class HomeserverTestCase(TestCase):
if not isinstance(self.hs, HomeServer): if not isinstance(self.hs, HomeServer):
raise Exception("A homeserver wasn't returned, but %r" % (self.hs,)) raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
# Register the resources # create the root resource, and a site to wrap it.
self.resource = self.create_test_json_resource() self.resource = self.create_test_resource()
# create a site to wrap the resource.
self.site = SynapseSite( self.site = SynapseSite(
logger_name="synapse.access.http.fake", logger_name="synapse.access.http.fake",
site_tag=self.hs.config.server.server_name, site_tag=self.hs.config.server.server_name,
@ -253,7 +252,7 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.v1.utils import RestHelper from tests.rest.client.v1.utils import RestHelper
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
if hasattr(self, "user_id"): if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth:
@ -323,15 +322,12 @@ class HomeserverTestCase(TestCase):
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
return hs return hs
def create_test_json_resource(self): def create_test_resource(self) -> Resource:
""" """
Create a test JsonResource, with the relevant servlets registerd to it Create a the root resource for the test server.
The default implementation calls each function in `servlets` to do the The default implementation creates a JsonResource and calls each function in
registration. `servlets` to register servletes against it
Returns:
JsonResource:
""" """
resource = JsonResource(self.hs) resource = JsonResource(self.hs)
@ -429,11 +425,9 @@ class HomeserverTestCase(TestCase):
Returns: Returns:
Tuple[synapse.http.site.SynapseRequest, channel] Tuple[synapse.http.site.SynapseRequest, channel]
""" """
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
return make_request( return make_request(
self.reactor, self.reactor,
self.site,
method, method,
path, path,
content, content,