diff --git a/changelog.d/8760.misc b/changelog.d/8760.misc new file mode 100644 index 0000000000..54502e9b90 --- /dev/null +++ b/changelog.d/8760.misc @@ -0,0 +1 @@ +Refactor test utilities for injecting HTTP requests. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index afaf9f7b85..1b2d0497a6 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -296,10 +296,12 @@ class RestHelper: image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) request, channel = make_request( - self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok - ) - request.requestHeaders.addRawHeader( - b"Content-Length", str(image_length).encode("UTF-8") + self.hs.get_reactor(), + "POST", + path, + content=image_data, + access_token=tok, + custom_headers=[(b"Content-Length", str(image_length))], ) request.render(resource) self.hs.get_reactor().pump([100]) diff --git a/tests/server.py b/tests/server.py index 3dd2cfc072..ef03109a6c 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable +from typing import Callable, Iterable, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -139,6 +139,9 @@ def make_request( shorthand=True, federation_auth_origin=None, content_is_form=False, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): """ Make a web request using the given method and path, feed it the @@ -157,6 +160,8 @@ def make_request( content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. + custom_headers: (name, value) pairs to add as request headers + Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ @@ -211,6 +216,10 @@ def make_request( # Assume the body is JSON req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + if custom_headers: + for k, v in custom_headers: + req.requestHeaders.addRawHeader(k, v) + req.requestReceived(method, path, b"1.1") return req, channel diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index e96ca1c8ca..efca43ec78 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.server import make_request from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -408,17 +409,17 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): # Advance to a known time self.reactor.advance(123456 - self.reactor.seconds()) - request, channel = self.make_request( + headers1 = {b"User-Agent": b"Mozzila pizza"} + headers1.update(headers) + + request, channel = make_request( + self.reactor, "GET", "/_matrix/client/r0/admin/users/" + self.user_id, access_token=access_token, + custom_headers=headers1.items(), **make_request_args, ) - request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") - - # Add the optional headers - for h, v in headers.items(): - request.requestHeaders.addRawHeader(h, v) self.render(request) # Advance so the save loop occurs