Fix remaining mypy issues due to Twisted upgrade. (#9608)
parent
026503fa3b
commit
d29b71aa50
|
@ -0,0 +1 @@
|
|||
Fix incorrect type hints.
|
|
@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union
|
|||
|
||||
from twisted.internet import protocol
|
||||
|
||||
class RedisProtocol:
|
||||
class RedisProtocol(protocol.Protocol):
|
||||
def publish(self, channel: str, message: bytes): ...
|
||||
async def ping(self) -> None: ...
|
||||
async def set(
|
||||
|
|
|
@ -45,7 +45,9 @@ from twisted.internet.interfaces import (
|
|||
IHostResolution,
|
||||
IReactorPluggableNameResolver,
|
||||
IResolutionReceiver,
|
||||
ITCPTransport,
|
||||
)
|
||||
from twisted.internet.protocol import connectionDone
|
||||
from twisted.internet.task import Cooperator
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception):
|
|||
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||
"""A protocol which immediately errors upon receiving data."""
|
||||
|
||||
transport = None # type: Optional[ITCPTransport]
|
||||
|
||||
def __init__(self, deferred: defer.Deferred):
|
||||
self.deferred = deferred
|
||||
|
||||
|
@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
|||
self.deferred.errback(BodyExceededMaxSize())
|
||||
# Close the connection (forcefully) since all the data will get
|
||||
# discarded anyway.
|
||||
assert self.transport is not None
|
||||
self.transport.abortConnection()
|
||||
|
||||
def dataReceived(self, data: bytes) -> None:
|
||||
self._maybe_fail()
|
||||
|
||||
def connectionLost(self, reason: Failure) -> None:
|
||||
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||
self._maybe_fail()
|
||||
|
||||
|
||||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
||||
|
||||
transport = None # type: Optional[ITCPTransport]
|
||||
|
||||
def __init__(
|
||||
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
|
||||
):
|
||||
|
@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
|||
self.deferred.errback(BodyExceededMaxSize())
|
||||
# Close the connection (forcefully) since all the data will get
|
||||
# discarded anyway.
|
||||
assert self.transport is not None
|
||||
self.transport.abortConnection()
|
||||
|
||||
def connectionLost(self, reason: Failure) -> None:
|
||||
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||
# If the maximum size was already exceeded, there's nothing to do.
|
||||
if self.deferred.called:
|
||||
return
|
||||
|
|
|
@ -302,7 +302,7 @@ class ReplicationCommandHandler:
|
|||
hs, outbound_redis_connection
|
||||
)
|
||||
hs.get_reactor().connectTCP(
|
||||
hs.config.redis.redis_host,
|
||||
hs.config.redis.redis_host.encode(),
|
||||
hs.config.redis.redis_port,
|
||||
self._factory,
|
||||
)
|
||||
|
@ -311,7 +311,7 @@ class ReplicationCommandHandler:
|
|||
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self._factory)
|
||||
hs.get_reactor().connectTCP(host.encode(), port, self._factory)
|
||||
|
||||
def get_streams(self) -> Dict[str, Stream]:
|
||||
"""Get a map from stream name to all streams."""
|
||||
|
|
|
@ -56,6 +56,7 @@ from prometheus_client import Counter
|
|||
from zope.interface import Interface, implementer
|
||||
|
||||
from twisted.internet import task
|
||||
from twisted.internet.tcp import Connection
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
|
@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
(if they send a `PING` command)
|
||||
"""
|
||||
|
||||
# The transport is going to be an ITCPTransport, but that doesn't have the
|
||||
# (un)registerProducer methods, those are only on the implementation.
|
||||
transport = None # type: Connection
|
||||
|
||||
delimiter = b"\n"
|
||||
|
||||
# Valid commands we expect to receive
|
||||
|
@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
|
||||
connected_connections.append(self) # Register connection for metrics
|
||||
|
||||
assert self.transport is not None
|
||||
self.transport.registerProducer(self, True) # For the *Producing callbacks
|
||||
|
||||
self._send_pending_commands()
|
||||
|
@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
logger.info(
|
||||
"[%s] Failed to close connection gracefully, aborting", self.id()
|
||||
)
|
||||
assert self.transport is not None
|
||||
self.transport.abortConnection()
|
||||
else:
|
||||
if now - self.last_sent_command >= PING_TIME:
|
||||
|
@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
def close(self):
|
||||
logger.warning("[%s] Closing connection", self.id())
|
||||
self.time_we_closed = self.clock.time_msec()
|
||||
assert self.transport is not None
|
||||
self.transport.loseConnection()
|
||||
self.on_connection_closed()
|
||||
|
||||
|
@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
def connectionLost(self, reason):
|
||||
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
|
||||
if isinstance(reason, Failure):
|
||||
assert reason.type is not None
|
||||
connection_close_counter.labels(reason.type.__name__).inc()
|
||||
else:
|
||||
connection_close_counter.labels(reason.__class__.__name__).inc()
|
||||
|
|
|
@ -365,6 +365,6 @@ def lazyConnection(
|
|||
factory.continueTrying = reconnect
|
||||
|
||||
reactor = hs.get_reactor()
|
||||
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
|
||||
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
|
||||
|
||||
return factory.handler
|
||||
|
|
|
@ -13,9 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||
from twisted.internet.protocol import Protocol
|
||||
|
@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
# Set up client side protocol
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
request_factory = OneShotRequestFactory()
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
|
||||
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
|
@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
server_to_client_transport.loseConnection()
|
||||
client_to_server_transport.loseConnection()
|
||||
|
||||
return request_factory.request
|
||||
return channel.request
|
||||
|
||||
def assert_request_is_get_repl_stream_updates(
|
||||
self, request: SynapseRequest, stream_name: str
|
||||
|
@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
if self.hs.config.redis.redis_enabled:
|
||||
# Handle attempts to connect to fake redis server.
|
||||
self.reactor.add_tcp_client_callback(
|
||||
"localhost",
|
||||
b"localhost",
|
||||
6379,
|
||||
self.connect_any_redis_attempts,
|
||||
)
|
||||
|
@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
# Set up client side protocol
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
request_factory = OneShotRequestFactory()
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
|
||||
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
|
@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
clients = self.reactor.tcpClients
|
||||
while clients:
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, "localhost")
|
||||
self.assertEqual(host, b"localhost")
|
||||
self.assertEqual(port, 6379)
|
||||
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
|||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
|
||||
@attr.s()
|
||||
class OneShotRequestFactory:
|
||||
"""A simple request factory that generates a single `SynapseRequest` and
|
||||
stores it for future use. Can only be used once.
|
||||
"""
|
||||
|
||||
request = attr.ib(default=None)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert self.request is None
|
||||
|
||||
self.request = SynapseRequest(*args, **kwargs)
|
||||
return self.request
|
||||
|
||||
|
||||
class _PushHTTPChannel(HTTPChannel):
|
||||
"""A HTTPChannel that wraps pull producers to push producers.
|
||||
|
||||
|
@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
|
||||
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
|
||||
):
|
||||
super().__init__()
|
||||
self.reactor = reactor
|
||||
|
@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
|
|||
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
|
||||
return False
|
||||
|
||||
def requestDone(self, request):
|
||||
# Store the request for inspection.
|
||||
self.request = request
|
||||
super().requestDone(request)
|
||||
|
||||
|
||||
class _PullToPushProducer:
|
||||
"""A push producer that wraps a pull producer."""
|
||||
|
@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
|
|||
class FakeRedisPubSubProtocol(Protocol):
|
||||
"""A connection from a client talking to the fake Redis server."""
|
||||
|
||||
transport = None # type: Optional[FakeTransport]
|
||||
|
||||
def __init__(self, server: FakeRedisPubSubServer):
|
||||
self._server = server
|
||||
self._reader = hiredis.Reader()
|
||||
|
@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
|
||||
def send(self, msg):
|
||||
"""Send a message back to the client."""
|
||||
assert self.transport is not None
|
||||
|
||||
raw = self.encode(msg).encode("utf-8")
|
||||
|
||||
self.transport.write(raw)
|
||||
|
|
|
@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
|
|||
IReactorPluggableNameResolver,
|
||||
IReactorTCP,
|
||||
IResolverSimple,
|
||||
ITransport,
|
||||
)
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
||||
|
@ -467,6 +468,7 @@ def get_clock():
|
|||
return clock, hs_clock
|
||||
|
||||
|
||||
@implementer(ITransport)
|
||||
@attr.s(cmp=False)
|
||||
class FakeTransport:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue