Compare commits

..

12 Commits

Author SHA1 Message Date
Richard van der Hoff 3dc1871219
Merge pull request #8757 from matrix-org/rav/pass_site_to_make_request
Pass a Site into `make_request`
2020-11-16 18:22:24 +00:00
Richard van der Hoff f125895475
Move `wait_until_result` into `FakeChannel` (#8758)
FakeChannel has everything we need, and this more accurately models the real
flow.
2020-11-16 18:21:47 +00:00
Richard van der Hoff c3e3552ec4 fixup test 2020-11-16 15:51:47 +00:00
Andrew Morgan 4f76eef0e8
Generalise _locally_reject_invite (#8751)
`_locally_reject_invite` generates an out-of-band membership event which can be passed to clients, but not other homeservers.

This is used when we fail to reject an invite over federation. If this happens, we instead just generate a leave event locally and send it down /sync, allowing clients to reject invites even if we can't reach the remote homeserver.

A similar flow needs to be put in place for rescinding knocks. If we're unable to contact any remote server from the room we've tried to knock on, we'd still like to generate and store the leave event locally. Hence the need to reuse, and thus generalise, this method.

Separated from #6739.
2020-11-16 15:37:36 +00:00
Richard van der Hoff bebfb9a97b
Merge branch 'develop' into rav/pass_site_to_make_request 2020-11-16 15:22:40 +00:00
Richard van der Hoff 791d7cd6f0
Rename `create_test_json_resource` to `create_test_resource` (#8759)
The root resource isn't necessarily a JsonResource, so rename this method
accordingly, and update a couple of test classes to use the method rather than
directly manipulating self.resource.
2020-11-16 14:45:52 +00:00
Richard van der Hoff ebc405446e
Add a `custom_headers` param to `make_request` (#8760)
Some tests want to set some custom HTTP request headers, so provide a way to do
that before calling requestReceived().
2020-11-16 14:45:22 +00:00
Richard van der Hoff 0d33c53534 changelog 2020-11-15 23:09:03 +00:00
Richard van der Hoff cfd895a22e use global make_request() directly where we have a custom Resource
Where we want to render a request against a specific Resource, call the global
make_request() function rather than the one in HomeserverTestCase, allowing us
to pass in an appropriate `Site`.
2020-11-15 23:09:03 +00:00
Richard van der Hoff 70c0d47989 fix dict handling for make_request() 2020-11-15 23:09:03 +00:00
Richard van der Hoff 9debe657a3 pass a Site into make_request 2020-11-15 23:09:03 +00:00
Richard van der Hoff d3523e3e97 pass a Site into RestHelper 2020-11-15 23:09:03 +00:00
26 changed files with 312 additions and 164 deletions

1
changelog.d/8751.misc Normal file
View File

@ -0,0 +1 @@
Generalise `RoomMemberHandler._locally_reject_invite` to apply to more flows than just invite.

1
changelog.d/8757.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8758.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8759.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

1
changelog.d/8760.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

View File

@ -1104,32 +1104,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
return await self._locally_reject_invite(
return await self._generate_local_out_of_band_leave(
invite_event, txn_id, requester, content
)
async def _locally_reject_invite(
async def _generate_local_out_of_band_leave(
self,
invite_event: EventBase,
previous_membership_event: EventBase,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""Generate a local invite rejection
"""Generate a local leave event for a room
This is called after we fail to reject an invite via a remote server. It
generates an out-of-band membership event locally.
This can be called after we e.g fail to reject an invite via a remote server.
It generates an out-of-band membership event locally.
Args:
invite_event: the invite to be rejected
previous_membership_event: the previous membership event for this user
txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token
content: additional content to include in the rejection event.
requester: user making the request, according to the access token
content: additional content to include in the leave event.
Normally an empty dict.
"""
room_id = invite_event.room_id
target_user = invite_event.state_key
Returns:
A tuple containing (event_id, stream_id of the leave event)
"""
room_id = previous_membership_event.room_id
target_user = previous_membership_event.state_key
content["membership"] = Membership.LEAVE
@ -1141,12 +1143,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
# the auth events for the new event are the same as that of the previous event, plus
# the event itself.
#
# the prev_events are just the invite.
prev_event_ids = [invite_event.event_id]
auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
# the prev_events consist solely of the previous membership event.
prev_event_ids = [previous_membership_event.event_id]
auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,

View File

@ -15,6 +15,7 @@
from synapse.app.generic_worker import GenericWorkerServer
from tests.server import make_request, render
from tests.unittest import HomeserverTestCase
@ -55,10 +56,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"]
resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
render(request, resource, self.reactor)
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
@ -77,10 +78,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"]
resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
render(request, resource, self.reactor)
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)

View File

@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def
from tests.server import make_request, render
from tests.unittest import HomeserverTestCase
@ -66,16 +67,16 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/openid/userinfo"
request, channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
render(request, resource, self.reactor)
self.assertEqual(channel.code, 401)
@ -115,15 +116,15 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/openid/userinfo"
request, channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
render(request, resource, self.reactor)
self.assertEqual(channel.code, 401)

View File

@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
from tests.server import FakeSite, make_request, render
from tests.unittest import HomeserverTestCase
@ -43,20 +44,20 @@ class AdditionalResourceTests(HomeserverTestCase):
def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/")
self.render(request)
request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
render(request, resource, self.reactor)
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/")
self.render(request)
request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
render(request, resource, self.reactor)
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@ -240,8 +240,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
def create_test_json_resource(self):
"""Overrides `HomeserverTestCase.create_test_json_resource`.
def create_test_resource(self):
"""Overrides `HomeserverTestCase.create_test_resource`.
"""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all

View File

@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel
from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__)
@ -46,8 +46,11 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test that registration works when using a single client reader worker.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
request_1, channel_1 = self.make_request(
request_1, channel_1 = make_request(
self.reactor,
site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
@ -59,8 +62,12 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"]
# also complete the dummy auth
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
request_2, channel_2 = make_request(
self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)
@ -74,7 +81,10 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request(
site_1 = self._hs_to_site[worker_hs_1]
request_1, channel_1 = make_request(
self.reactor,
site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
@ -86,8 +96,13 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"]
# also complete the dummy auth
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
site_2 = self._hs_to_site[worker_hs_2]
request_2, channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)

View File

@ -28,7 +28,7 @@ from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport
from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__)
@ -67,14 +67,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""
request, channel = self.make_request(
resource = hs.get_media_repository_resource().children[b"download"]
request, channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
"/{}/{}".format(target, media_id),
shorthand=False,
access_token=self.access_token,
)
request.render(hs.get_media_repository_resource().children[b"download"])
request.render(resource)
self.pump()
clients = self.reactor.tcpClients

View File

@ -22,6 +22,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@ -148,6 +149,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"},
)
sync_hs_site = self._hs_to_site[sync_hs]
# Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test"
@ -178,7 +180,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Do an initial sync so that we're up to date.
request, channel = self.make_request("GET", "/sync", access_token=access_token)
request, channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"]
@ -203,8 +207,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
@ -230,7 +238,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(vector_clock_token),
access_token=access_token,
@ -254,8 +264,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
@ -269,7 +283,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token
@ -281,7 +297,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token
@ -295,7 +313,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Paginating forwards should give the same results
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1
@ -305,7 +325,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2,

View File

@ -30,12 +30,13 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
from tests.server import FakeSite, make_request
class VersionTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/server_version"
def create_test_json_resource(self):
def create_test_resource(self):
resource = JsonResource(self.hs)
VersionServlet(self.hs).register(resource)
return resource
@ -222,8 +223,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
@ -287,7 +293,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_name_and_media_id,
shorthand=False,
@ -462,7 +470,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id_2,
shorthand=False,

View File

@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest
from tests.server import FakeSite, make_request
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@ -124,7 +125,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name)
# Attempt to access media
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
@ -161,7 +164,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
)
# Attempt to access media
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
@ -535,7 +540,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id)
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,

View File

@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.consent import consent_resource
from tests import unittest
from tests.server import render
from tests.server import FakeSite, make_request, render
class ConsentResourceTestCase(unittest.HomeserverTestCase):
@ -61,7 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
request, channel = self.make_request("GET", "/consent?v=1", shorthand=False)
request, channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
)
render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
@ -81,8 +83,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user"
)
request, channel = self.make_request(
"GET", consent_uri, access_token=access_token, shorthand=False
request, channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
consent_uri,
access_token=access_token,
shorthand=False,
)
render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
@ -92,7 +99,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(resource),
"POST",
consent_uri + "&v=" + version,
access_token=access_token,
@ -103,8 +112,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Fetch the consent page, to get the consent version -- it should have
# changed
request, channel = self.make_request(
"GET", consent_uri, access_token=access_token, shorthand=False
request, channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
consent_uri,
access_token=access_token,
shorthand=False,
)
render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)

View File

@ -23,10 +23,11 @@ from typing import Any, Dict, Optional
import attr
from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
from tests.server import make_request, render
from tests.server import FakeSite, make_request, render
@attr.s
@ -36,7 +37,7 @@ class RestHelper:
"""
hs = attr.ib()
resource = attr.ib()
site = attr.ib(type=Site)
auth_user_id = attr.ib()
def create_room_as(
@ -52,9 +53,13 @@ class RestHelper:
path = path + "?access_token=%s" % tok
request, channel = make_request(
self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8")
self.hs.get_reactor(),
self.site,
"POST",
path,
json.dumps(content).encode("utf8"),
)
render(request, self.resource, self.hs.get_reactor())
render(request, self.site.resource, self.hs.get_reactor())
assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
@ -125,10 +130,14 @@ class RestHelper:
data.update(extra_data)
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(data).encode("utf8"),
)
render(request, self.resource, self.hs.get_reactor())
render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
@ -158,9 +167,13 @@ class RestHelper:
path = path + "?access_token=%s" % tok
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8")
self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(content).encode("utf8"),
)
render(request, self.resource, self.hs.get_reactor())
render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
@ -210,9 +223,11 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
request, channel = make_request(self.hs.get_reactor(), method, path, content)
request, channel = make_request(
self.hs.get_reactor(), self.site, method, path, content
)
render(request, self.resource, self.hs.get_reactor())
render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
@ -296,10 +311,13 @@ class RestHelper:
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
request, channel = make_request(
self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
)
request.requestHeaders.addRawHeader(
b"Content-Length", str(image_length).encode("UTF-8")
self.hs.get_reactor(),
FakeSite(resource),
"POST",
path,
content=image_data,
access_token=tok,
custom_headers=[(b"Content-Length", str(image_length))],
)
request.render(resource)
self.hs.get_reactor().pump([100])

View File

@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
from tests.server import FakeSite, make_request
from tests.unittest import override_config
@ -255,9 +256,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
# Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False)
request, channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"GET",
path,
shorthand=False,
)
request.render(self.submit_token_resource)
self.pump()
self.assertEquals(200, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
@ -271,7 +279,9 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
form_args.append(arg)
# Confirm the password reset
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"POST",
path,
content=urlencode(form_args).encode("utf8"),

View File

@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string
from tests import unittest
from tests.server import FakeChannel, wait_until_result
from tests.server import FakeChannel
from tests.utils import default_config
@ -41,7 +41,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
self.http_client = Mock()
return self.setup_test_homeserver(http_client=self.http_client)
def create_test_json_resource(self):
def create_test_resource(self):
return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
)
@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
% (server_name.encode("utf-8"), key_id.encode("utf-8")),
b"1.1",
)
wait_until_result(self.reactor, req)
channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body
return resp
@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
req.requestReceived(
b"POST", path.encode("utf-8"), b"1.1",
)
wait_until_result(self.reactor, req)
channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body
return resp

View File

@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
from tests.server import FakeSite, make_request
class MediaStorageTests(unittest.HomeserverTestCase):
@ -227,7 +228,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _req(self, content_disposition):
request, channel = self.make_request("GET", self.media_id, shorthand=False)
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
self.media_id,
shorthand=False,
)
request.render(self.download_resource)
self.pump()
@ -317,8 +324,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
request, channel = self.make_request(
"GET", self.media_id + params, shorthand=False
request, channel = make_request(
self.reactor,
FakeSite(self.thumbnail_resource),
"GET",
self.media_id + params,
shorthand=False,
)
request.render(self.thumbnail_resource)
self.pump()

View File

@ -20,11 +20,9 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase):
def setUp(self):
super().setUp()
def create_test_resource(self):
# replace the JsonResource with a HealthResource.
self.resource = HealthResource()
return HealthResource()
def test_health(self):
request, channel = self.make_request("GET", "/health", shorthand=False)

View File

@ -20,11 +20,9 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
def setUp(self):
super().setUp()
def create_test_resource(self):
# replace the JsonResource with a WellKnownResource
self.resource = WellKnownResource(self.hs)
return WellKnownResource(self.hs)
def test_well_known(self):
self.hs.config.public_baseurl = "https://tesths"

View File

@ -2,7 +2,7 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
from typing import Callable
from typing import Callable, Iterable, Optional, Tuple, Union
import attr
from typing_extensions import Deque
@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Site
from synapse.http.site import SynapseRequest
@ -117,6 +118,25 @@ class FakeChannel:
def transport(self):
return self
def await_result(self, timeout: int = 100) -> None:
"""
Wait until the request is finished.
"""
self._reactor.run()
x = 0
while not self.result.get("done"):
# If there's a producer, tell it to resume producing so we get content
if self._producer:
self._producer.resumeProducing()
x += 1
if x > timeout:
raise TimedOutException("Timed out waiting for request to finish.")
self._reactor.advance(0.1)
class FakeSite:
"""
@ -128,9 +148,21 @@ class FakeSite:
site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake")
def __init__(self, resource: IResource):
"""
Args:
resource: the resource to be used for rendering all requests
"""
self._resource = resource
def getResourceFor(self, request):
return self._resource
def make_request(
reactor,
site: Site,
method,
path,
content=b"",
@ -139,12 +171,17 @@ def make_request(
shorthand=True,
federation_auth_origin=None,
content_is_form=False,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
Args:
site: The twisted Site to associate with the Channel
method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such).
@ -157,6 +194,8 @@ def make_request(
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
custom_headers: (name, value) pairs to add as request headers
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
"""
@ -178,10 +217,11 @@ def make_request(
if not path.startswith(b"/"):
path = b"/" + path
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
if isinstance(content, str):
content = content.encode("utf8")
site = FakeSite()
channel = FakeChannel(site, reactor)
req = request(channel)
@ -211,35 +251,18 @@ def make_request(
# Assume the body is JSON
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
if custom_headers:
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
req.requestReceived(method, path, b"1.1")
return req, channel
def wait_until_result(clock, request, timeout=100):
"""
Wait until the request is finished.
"""
clock.run()
x = 0
while not request.finished:
# If there's a producer, tell it to resume producing so we get content
if request._channel._producer:
request._channel._producer.resumeProducing()
x += 1
if x > timeout:
raise TimedOutException("Timed out waiting for request to finish.")
clock.advance(0.1)
def render(request, resource, clock):
request.render(resource)
wait_until_result(clock, request)
request._channel.await_result()
@implementer(IReactorPluggableNameResolver)

View File

@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@ -408,17 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
# Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds())
request, channel = self.make_request(
headers1 = {b"User-Agent": b"Mozzila pizza"}
headers1.update(headers)
request, channel = make_request(
self.reactor,
self.site,
"GET",
"/_matrix/client/r0/admin/users/" + self.user_id,
access_token=access_token,
custom_headers=headers1.items(),
**make_request_args,
)
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
# Add the optional headers
for h, v in headers.items():
request.requestHeaders.addRawHeader(h, v)
self.render(request)
# Advance so the save loop occurs

View File

@ -26,6 +26,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
make_request,
render,
@ -62,7 +63,7 @@ class JsonResourceTests(unittest.TestCase):
)
request, channel = make_request(
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
)
render(request, res, self.reactor)
@ -83,7 +84,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500")
@ -108,7 +111,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500")
@ -127,7 +132,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"403")
@ -150,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
)
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"400")
@ -173,7 +182,9 @@ class JsonResourceTests(unittest.TestCase):
)
# The path was registered as GET, but this is a HEAD request.
request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
request, channel = make_request(
self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo"
)
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200")
@ -196,9 +207,6 @@ class OptionsResourceTests(unittest.TestCase):
def _make_request(self, method, path):
"""Create a request from the method/path and return a channel with the response."""
request, channel = make_request(self.reactor, method, path, shorthand=False)
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource.
site = SynapseSite(
"test",
@ -207,6 +215,12 @@ class OptionsResourceTests(unittest.TestCase):
self.resource,
"1.0",
)
request, channel = make_request(
self.reactor, site, method, path, shorthand=False
)
request.prepath = [] # This doesn't get set properly by make_request.
request.site = site
resource = site.getResourceFor(request)
@ -284,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200")
@ -303,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"301")
@ -325,7 +339,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"304")
@ -345,7 +359,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"HEAD", b"/path")
request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200")

View File

@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from twisted.web.resource import Resource
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
@ -239,10 +240,8 @@ class HomeserverTestCase(TestCase):
if not isinstance(self.hs, HomeServer):
raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
# Register the resources
self.resource = self.create_test_json_resource()
# create a site to wrap the resource.
# create the root resource, and a site to wrap it.
self.resource = self.create_test_resource()
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
site_tag=self.hs.config.server.server_name,
@ -253,7 +252,7 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.v1.utils import RestHelper
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
if hasattr(self, "user_id"):
if self.hijack_auth:
@ -323,15 +322,12 @@ class HomeserverTestCase(TestCase):
hs = self.setup_test_homeserver()
return hs
def create_test_json_resource(self):
def create_test_resource(self) -> Resource:
"""
Create a test JsonResource, with the relevant servlets registerd to it
Create a the root resource for the test server.
The default implementation calls each function in `servlets` to do the
registration.
Returns:
JsonResource:
The default implementation creates a JsonResource and calls each function in
`servlets` to register servletes against it
"""
resource = JsonResource(self.hs)
@ -429,11 +425,9 @@ class HomeserverTestCase(TestCase):
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
"""
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
return make_request(
self.reactor,
self.site,
method,
path,
content,