fix up various test cases

A few test cases were relying on being able to mount non-client servlets on the
test resource. it's better to give them their own Resources.
pull/8858/head
Richard van der Hoff 2020-12-02 15:26:25 +00:00
parent 693516e756
commit 7ea85302f3
5 changed files with 38 additions and 17 deletions

View File

@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args: Args:
hs (synapse.server.HomeServer): homeserver hs (synapse.server.HomeServer): homeserver
resource (TransportLayerServer): resource class to register to resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register. servlet_groups (list[str], optional): List of servlet groups to register.

View File

@ -15,18 +15,20 @@
import json import json
from typing import Dict
from mock import ANY, Mock, call from mock import ANY, Mock, call
from twisted.internet import defer from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import register_federation_servlets
# Some local users to test with # Some local users to test with
U_APPLE = UserID.from_string("@apple:test") U_APPLE = UserID.from_string("@apple:test")
@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase): class TypingNotificationsTestCase(unittest.HomeserverTestCase):
servlets = [register_federation_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the # we mock out the keyring so as to skip the authentication check on the
# federation API call. # federation API call.
@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return hs return hs
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier() mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event self.on_new_event = mock_notifier.on_new_event

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import attr import attr
@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
from synapse.app.generic_worker import ( from synapse.app.generic_worker import (
GenericWorkerReplicationHandler, GenericWorkerReplicationHandler,
@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis: if not hiredis:
skip = "Requires hiredis" skip = "Requires hiredis"
servlets = [
streams.register_servlets,
]
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
# build a replication server # build a replication server
server_factory = ReplicationStreamProtocolFactory(hs) server_factory = ReplicationStreamProtocolFactory(hs)
@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None self._client_transport = None
self._server_transport = None self._server_transport = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
return d
def _get_worker_hs_config(self) -> dict: def _get_worker_hs_config(self) -> dict:
config = self.default_config() config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker" config["worker_app"] = "synapse.app.generic_worker"

View File

@ -216,8 +216,9 @@ def make_request(
and not path.startswith(b"/_matrix") and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse") and not path.startswith(b"/_synapse")
): ):
if path.startswith(b"/"):
path = path[1:]
path = b"/_matrix/client/r0/" + path path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
if not path.startswith(b"/"): if not path.startswith(b"/"):
path = b"/" + path path = b"/" + path

View File

@ -705,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`. A federating homeserver that authenticates incoming requests as `other.example.com`.
""" """
def prepare(self, reactor, clock, homeserver): def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
return d
class TestTransportLayerServer(JsonResource):
"""A test implementation of TransportLayerServer
authenticates incoming requests as `other.example.com`.
"""
def __init__(self, hs):
super().__init__(hs)
class Authenticator: class Authenticator:
def authenticate_request(self, request, content): def authenticate_request(self, request, content):
return succeed("other.example.com") return succeed("other.example.com")
authenticator = Authenticator()
ratelimiter = FederationRateLimiter( ratelimiter = FederationRateLimiter(
clock, hs.get_clock(),
FederationRateLimitConfig( FederationRateLimitConfig(
window_size=1, window_size=1,
sleep_limit=1, sleep_limit=1,
@ -720,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000, concurrent_requests=1000,
), ),
) )
federation_server.register_servlets(
homeserver, self.resource, Authenticator(), ratelimiter
)
return super().prepare(reactor, clock, homeserver) federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config): def override_config(extra_config):