Use `getClientAddress` instead of `getClientIP`. (#12599)
getClientIP was deprecated in Twisted 18.4.0, which also added getClientAddress. The Synapse minimum version for Twisted is currently 18.9.0, so all supported versions have the new API.pull/12634/head
parent
116a4c8340
commit
7fbf42499d
|
@ -0,0 +1 @@
|
||||||
|
Use `getClientAddress` instead of the deprecated `getClientIP`.
|
|
@ -187,7 +187,7 @@ class Auth:
|
||||||
Once get_user_by_req has set up the opentracing span, this does the actual work.
|
Once get_user_by_req has set up the opentracing span, this does the actual work.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
ip_addr = request.getClientIP()
|
ip_addr = request.getClientAddress().host
|
||||||
user_agent = get_request_user_agent(request)
|
user_agent = get_request_user_agent(request)
|
||||||
|
|
||||||
access_token = self.get_access_token_from_request(request)
|
access_token = self.get_access_token_from_request(request)
|
||||||
|
@ -356,7 +356,7 @@ class Auth:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
if app_service.ip_range_whitelist:
|
if app_service.ip_range_whitelist:
|
||||||
ip_address = IPAddress(request.getClientIP())
|
ip_address = IPAddress(request.getClientAddress().host)
|
||||||
if ip_address not in app_service.ip_range_whitelist:
|
if ip_address not in app_service.ip_range_whitelist:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
|
|
|
@ -551,7 +551,7 @@ class AuthHandler:
|
||||||
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
||||||
|
|
||||||
user_agent = get_request_user_agent(request)
|
user_agent = get_request_user_agent(request)
|
||||||
clientip = request.getClientIP()
|
clientip = request.getClientAddress().host
|
||||||
|
|
||||||
await self.store.add_user_agent_ip_to_ui_auth_session(
|
await self.store.add_user_agent_ip_to_ui_auth_session(
|
||||||
session.session_id, user_agent, clientip
|
session.session_id, user_agent, clientip
|
||||||
|
|
|
@ -92,7 +92,7 @@ class IdentityHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
||||||
None, (medium, request.getClientIP())
|
None, (medium, request.getClientAddress().host)
|
||||||
)
|
)
|
||||||
await self._3pid_validation_ratelimiter_address.ratelimit(
|
await self._3pid_validation_ratelimiter_address.ratelimit(
|
||||||
None, (medium, address)
|
None, (medium, address)
|
||||||
|
|
|
@ -468,7 +468,7 @@ class SsoHandler:
|
||||||
auth_provider_id,
|
auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
get_request_user_agent(request),
|
get_request_user_agent(request),
|
||||||
request.getClientIP(),
|
request.getClientAddress().host,
|
||||||
)
|
)
|
||||||
new_user = True
|
new_user = True
|
||||||
elif self._sso_update_profile_information:
|
elif self._sso_update_profile_information:
|
||||||
|
@ -928,7 +928,7 @@ class SsoHandler:
|
||||||
session.auth_provider_id,
|
session.auth_provider_id,
|
||||||
session.remote_user_id,
|
session.remote_user_id,
|
||||||
get_request_user_agent(request),
|
get_request_user_agent(request),
|
||||||
request.getClientIP(),
|
request.getClientAddress().host,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -238,7 +238,7 @@ class SynapseRequest(Request):
|
||||||
request_id,
|
request_id,
|
||||||
request=ContextRequest(
|
request=ContextRequest(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
ip_address=self.getClientIP(),
|
ip_address=self.getClientAddress().host,
|
||||||
site_tag=self.synapse_site.site_tag,
|
site_tag=self.synapse_site.site_tag,
|
||||||
# The requester is going to be unknown at this point.
|
# The requester is going to be unknown at this point.
|
||||||
requester=None,
|
requester=None,
|
||||||
|
@ -381,7 +381,7 @@ class SynapseRequest(Request):
|
||||||
|
|
||||||
self.synapse_site.access_logger.debug(
|
self.synapse_site.access_logger.debug(
|
||||||
"%s - %s - Received request: %s %s",
|
"%s - %s - Received request: %s %s",
|
||||||
self.getClientIP(),
|
self.getClientAddress().host,
|
||||||
self.synapse_site.site_tag,
|
self.synapse_site.site_tag,
|
||||||
self.get_method(),
|
self.get_method(),
|
||||||
self.get_redacted_uri(),
|
self.get_redacted_uri(),
|
||||||
|
@ -429,7 +429,7 @@ class SynapseRequest(Request):
|
||||||
"%s - %s - {%s}"
|
"%s - %s - {%s}"
|
||||||
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
||||||
' %sB %s "%s %s %s" "%s" [%d dbevts]',
|
' %sB %s "%s %s %s" "%s" [%d dbevts]',
|
||||||
self.getClientIP(),
|
self.getClientAddress().host,
|
||||||
self.synapse_site.site_tag,
|
self.synapse_site.site_tag,
|
||||||
requester,
|
requester,
|
||||||
processing_time,
|
processing_time,
|
||||||
|
|
|
@ -884,7 +884,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||||
tags.HTTP_METHOD: request.get_method(),
|
tags.HTTP_METHOD: request.get_method(),
|
||||||
tags.HTTP_URL: request.get_redacted_uri(),
|
tags.HTTP_URL: request.get_redacted_uri(),
|
||||||
tags.PEER_HOST_IPV6: request.getClientIP(),
|
tags.PEER_HOST_IPV6: request.getClientAddress().host,
|
||||||
}
|
}
|
||||||
|
|
||||||
request_name = request.request_metrics.name
|
request_name = request.request_metrics.name
|
||||||
|
|
|
@ -112,7 +112,7 @@ class AuthRestServlet(RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth_handler.add_oob_auth(
|
await self.auth_handler.add_oob_auth(
|
||||||
LoginType.RECAPTCHA, authdict, request.getClientIP()
|
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
|
||||||
)
|
)
|
||||||
except LoginError as e:
|
except LoginError as e:
|
||||||
# Authentication failed, let user try again
|
# Authentication failed, let user try again
|
||||||
|
@ -132,7 +132,7 @@ class AuthRestServlet(RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth_handler.add_oob_auth(
|
await self.auth_handler.add_oob_auth(
|
||||||
LoginType.TERMS, authdict, request.getClientIP()
|
LoginType.TERMS, authdict, request.getClientAddress().host
|
||||||
)
|
)
|
||||||
except LoginError as e:
|
except LoginError as e:
|
||||||
# Authentication failed, let user try again
|
# Authentication failed, let user try again
|
||||||
|
@ -161,7 +161,9 @@ class AuthRestServlet(RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth_handler.add_oob_auth(
|
await self.auth_handler.add_oob_auth(
|
||||||
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
|
LoginType.REGISTRATION_TOKEN,
|
||||||
|
authdict,
|
||||||
|
request.getClientAddress().host,
|
||||||
)
|
)
|
||||||
except LoginError as e:
|
except LoginError as e:
|
||||||
html = self.registration_token_template.render(
|
html = self.registration_token_template.render(
|
||||||
|
|
|
@ -176,7 +176,7 @@ class LoginRestServlet(RestServlet):
|
||||||
|
|
||||||
if appservice.is_rate_limited():
|
if appservice.is_rate_limited():
|
||||||
await self._address_ratelimiter.ratelimit(
|
await self._address_ratelimiter.ratelimit(
|
||||||
None, request.getClientIP()
|
None, request.getClientAddress().host
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self._do_appservice_login(
|
result = await self._do_appservice_login(
|
||||||
|
@ -188,19 +188,25 @@ class LoginRestServlet(RestServlet):
|
||||||
self.jwt_enabled
|
self.jwt_enabled
|
||||||
and login_submission["type"] == LoginRestServlet.JWT_TYPE
|
and login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||||
):
|
):
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(
|
||||||
|
None, request.getClientAddress().host
|
||||||
|
)
|
||||||
result = await self._do_jwt_login(
|
result = await self._do_jwt_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(
|
||||||
|
None, request.getClientAddress().host
|
||||||
|
)
|
||||||
result = await self._do_token_login(
|
result = await self._do_token_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
|
await self._address_ratelimiter.ratelimit(
|
||||||
|
None, request.getClientAddress().host
|
||||||
|
)
|
||||||
result = await self._do_other_login(
|
result = await self._do_other_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
|
|
@ -352,7 +352,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
||||||
if self.inhibit_user_in_use_error:
|
if self.inhibit_user_in_use_error:
|
||||||
return 200, {"available": True}
|
return 200, {"available": True}
|
||||||
|
|
||||||
ip = request.getClientIP()
|
ip = request.getClientAddress().host
|
||||||
with self.ratelimiter.ratelimit(ip) as wait_deferred:
|
with self.ratelimiter.ratelimit(ip) as wait_deferred:
|
||||||
await wait_deferred
|
await wait_deferred
|
||||||
|
|
||||||
|
@ -394,7 +394,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||||
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
|
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
|
||||||
|
|
||||||
if not self.hs.config.registration.enable_registration:
|
if not self.hs.config.registration.enable_registration:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -441,7 +441,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
client_addr = request.getClientIP()
|
client_addr = request.getClientAddress().host
|
||||||
|
|
||||||
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||||
|
|
||||||
|
|
|
@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||||
|
@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "192.168.10.10"
|
request.getClientAddress.return_value.host = "192.168.10.10"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||||
|
@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "131.111.8.42"
|
request.getClientAddress.return_value.host = "131.111.8.42"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
f = self.get_failure(
|
f = self.get_failure(
|
||||||
|
@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_device = simple_async_mock({"hidden": False})
|
self.store.get_device = simple_async_mock({"hidden": False})
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
||||||
|
@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_device = simple_async_mock(None)
|
self.store.get_device = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
||||||
|
@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = simple_async_mock(None)
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
self.get_success(self.auth.get_user_by_req(request))
|
self.get_success(self.auth.get_user_by_req(request))
|
||||||
|
@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = simple_async_mock(None)
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
self.get_success(self.auth.get_user_by_req(request))
|
self.get_success(self.auth.get_user_by_req(request))
|
||||||
|
|
|
@ -204,7 +204,7 @@ def _mock_request():
|
||||||
mock = Mock(
|
mock = Mock(
|
||||||
spec=[
|
spec=[
|
||||||
"finish",
|
"finish",
|
||||||
"getClientIP",
|
"getClientAddress",
|
||||||
"getHeader",
|
"getHeader",
|
||||||
"setHeader",
|
"setHeader",
|
||||||
"setResponseCode",
|
"setResponseCode",
|
||||||
|
|
|
@ -1300,7 +1300,7 @@ def _build_callback_request(
|
||||||
"getCookie",
|
"getCookie",
|
||||||
"cookies",
|
"cookies",
|
||||||
"requestHeaders",
|
"requestHeaders",
|
||||||
"getClientIP",
|
"getClientAddress",
|
||||||
"getHeader",
|
"getHeader",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -1310,5 +1310,5 @@ def _build_callback_request(
|
||||||
request.args = {}
|
request.args = {}
|
||||||
request.args[b"code"] = [code.encode("utf-8")]
|
request.args[b"code"] = [code.encode("utf-8")]
|
||||||
request.args[b"state"] = [state.encode("utf-8")]
|
request.args[b"state"] = [state.encode("utf-8")]
|
||||||
request.getClientIP.return_value = ip_address
|
request.getClientAddress.return_value.host = ip_address
|
||||||
return request
|
return request
|
||||||
|
|
|
@ -352,7 +352,7 @@ def _mock_request():
|
||||||
mock = Mock(
|
mock = Mock(
|
||||||
spec=[
|
spec=[
|
||||||
"finish",
|
"finish",
|
||||||
"getClientIP",
|
"getClientAddress",
|
||||||
"getHeader",
|
"getHeader",
|
||||||
"setHeader",
|
"setHeader",
|
||||||
"setResponseCode",
|
"setResponseCode",
|
||||||
|
|
|
@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(port, 8765)
|
self.assertEqual(port, 8765)
|
||||||
|
|
||||||
# Set up client side protocol
|
# Set up client side protocol
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
|
||||||
|
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
|
||||||
|
|
||||||
# Set up the server side protocol
|
# Set up the server side protocol
|
||||||
channel = self.site.buildProtocol(None)
|
server_address = IPv4Address("TCP", host, port)
|
||||||
|
channel = self.site.buildProtocol((host, port))
|
||||||
|
|
||||||
# hook into the channel's request factory so that we can keep a record
|
# hook into the channel's request factory so that we can keep a record
|
||||||
# of the requests
|
# of the requests
|
||||||
|
@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Connect client to server and vice versa.
|
# Connect client to server and vice versa.
|
||||||
client_to_server_transport = FakeTransport(
|
client_to_server_transport = FakeTransport(
|
||||||
channel, self.reactor, client_protocol
|
channel, self.reactor, client_protocol, server_address, client_address
|
||||||
)
|
)
|
||||||
client_protocol.makeConnection(client_to_server_transport)
|
client_protocol.makeConnection(client_to_server_transport)
|
||||||
|
|
||||||
server_to_client_transport = FakeTransport(
|
server_to_client_transport = FakeTransport(
|
||||||
client_protocol, self.reactor, channel
|
client_protocol, self.reactor, channel, client_address, server_address
|
||||||
)
|
)
|
||||||
channel.makeConnection(server_to_client_transport)
|
channel.makeConnection(server_to_client_transport)
|
||||||
|
|
||||||
|
@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(port, repl_port)
|
self.assertEqual(port, repl_port)
|
||||||
|
|
||||||
# Set up client side protocol
|
# Set up client side protocol
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
|
||||||
|
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
|
||||||
|
|
||||||
# Set up the server side protocol
|
# Set up the server side protocol
|
||||||
channel = self._hs_to_site[hs].buildProtocol(None)
|
server_address = IPv4Address("TCP", host, port)
|
||||||
|
channel = self._hs_to_site[hs].buildProtocol((host, port))
|
||||||
|
|
||||||
# Connect client to server and vice versa.
|
# Connect client to server and vice versa.
|
||||||
client_to_server_transport = FakeTransport(
|
client_to_server_transport = FakeTransport(
|
||||||
channel, self.reactor, client_protocol
|
channel, self.reactor, client_protocol, server_address, client_address
|
||||||
)
|
)
|
||||||
client_protocol.makeConnection(client_to_server_transport)
|
client_protocol.makeConnection(client_to_server_transport)
|
||||||
|
|
||||||
server_to_client_transport = FakeTransport(
|
server_to_client_transport = FakeTransport(
|
||||||
client_protocol, self.reactor, channel
|
client_protocol, self.reactor, channel, client_address, server_address
|
||||||
)
|
)
|
||||||
channel.makeConnection(server_to_client_transport)
|
channel.makeConnection(server_to_client_transport)
|
||||||
|
|
||||||
|
|
|
@ -181,7 +181,7 @@ class FakeChannel:
|
||||||
self.resource_usage = _self.logcontext.get_resource_usage()
|
self.resource_usage = _self.logcontext.get_resource_usage()
|
||||||
|
|
||||||
def getPeer(self):
|
def getPeer(self):
|
||||||
# We give an address so that getClientIP returns a non null entry,
|
# We give an address so that getClientAddress/getClientIP returns a non null entry,
|
||||||
# causing us to record the MAU
|
# causing us to record the MAU
|
||||||
return address.IPv4Address("TCP", self._ip, 3423)
|
return address.IPv4Address("TCP", self._ip, 3423)
|
||||||
|
|
||||||
|
@ -562,7 +562,10 @@ class FakeTransport:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_peer_address: Optional[IAddress] = attr.ib(default=None)
|
_peer_address: Optional[IAddress] = attr.ib(default=None)
|
||||||
"""The value to be returend by getPeer"""
|
"""The value to be returned by getPeer"""
|
||||||
|
|
||||||
|
_host_address: Optional[IAddress] = attr.ib(default=None)
|
||||||
|
"""The value to be returned by getHost"""
|
||||||
|
|
||||||
disconnecting = False
|
disconnecting = False
|
||||||
disconnected = False
|
disconnected = False
|
||||||
|
@ -571,11 +574,11 @@ class FakeTransport:
|
||||||
producer = attr.ib(default=None)
|
producer = attr.ib(default=None)
|
||||||
autoflush = attr.ib(default=True)
|
autoflush = attr.ib(default=True)
|
||||||
|
|
||||||
def getPeer(self):
|
def getPeer(self) -> Optional[IAddress]:
|
||||||
return self._peer_address
|
return self._peer_address
|
||||||
|
|
||||||
def getHost(self):
|
def getHost(self) -> Optional[IAddress]:
|
||||||
return None
|
return self._host_address
|
||||||
|
|
||||||
def loseConnection(self, reason=None):
|
def loseConnection(self, reason=None):
|
||||||
if not self.disconnecting:
|
if not self.disconnecting:
|
||||||
|
|
Loading…
Reference in New Issue