Convert replication code to async/await. (#7987)
parent
db5970ac6d
commit
3b415e23a5
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -548,7 +548,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
address (str|None): the IP address used to perform the registration.
|
address (str|None): the IP address used to perform the registration.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Awaitable
|
||||||
"""
|
"""
|
||||||
if self.hs.config.worker_app:
|
if self.hs.config.worker_app:
|
||||||
return self._register_client(
|
return self._register_client(
|
||||||
|
|
|
@ -20,8 +20,6 @@ import urllib
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException,
|
CodeMessageException,
|
||||||
HttpResponseException,
|
HttpResponseException,
|
||||||
|
@ -101,7 +99,7 @@ class ReplicationEndpoint(object):
|
||||||
assert self.METHOD in ("PUT", "POST", "GET")
|
assert self.METHOD in ("PUT", "POST", "GET")
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _serialize_payload(**kwargs):
|
async def _serialize_payload(**kwargs):
|
||||||
"""Static method that is called when creating a request.
|
"""Static method that is called when creating a request.
|
||||||
|
|
||||||
Concrete implementations should have explicit parameters (rather than
|
Concrete implementations should have explicit parameters (rather than
|
||||||
|
@ -110,9 +108,8 @@ class ReplicationEndpoint(object):
|
||||||
argument list.
|
argument list.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict]|dict: If POST/PUT request then dictionary must be
|
dict: If POST/PUT request then dictionary must be JSON serialisable,
|
||||||
JSON serialisable, otherwise must be appropriate for adding as
|
otherwise must be appropriate for adding as query args.
|
||||||
query args.
|
|
||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -144,8 +141,7 @@ class ReplicationEndpoint(object):
|
||||||
instance_map = hs.config.worker.instance_map
|
instance_map = hs.config.worker.instance_map
|
||||||
|
|
||||||
@trace(opname="outgoing_replication_request")
|
@trace(opname="outgoing_replication_request")
|
||||||
@defer.inlineCallbacks
|
async def send_request(instance_name="master", **kwargs):
|
||||||
def send_request(instance_name="master", **kwargs):
|
|
||||||
if instance_name == local_instance_name:
|
if instance_name == local_instance_name:
|
||||||
raise Exception("Trying to send HTTP request to self")
|
raise Exception("Trying to send HTTP request to self")
|
||||||
if instance_name == "master":
|
if instance_name == "master":
|
||||||
|
@ -159,7 +155,7 @@ class ReplicationEndpoint(object):
|
||||||
"Instance %r not in 'instance_map' config" % (instance_name,)
|
"Instance %r not in 'instance_map' config" % (instance_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
data = yield cls._serialize_payload(**kwargs)
|
data = await cls._serialize_payload(**kwargs)
|
||||||
|
|
||||||
url_args = [
|
url_args = [
|
||||||
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
|
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
|
||||||
|
@ -197,7 +193,7 @@ class ReplicationEndpoint(object):
|
||||||
headers = {} # type: Dict[bytes, List[bytes]]
|
headers = {} # type: Dict[bytes, List[bytes]]
|
||||||
inject_active_span_byte_dict(headers, None, check_destination=False)
|
inject_active_span_byte_dict(headers, None, check_destination=False)
|
||||||
try:
|
try:
|
||||||
result = yield request_func(uri, data, headers=headers)
|
result = await request_func(uri, data, headers=headers)
|
||||||
break
|
break
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
|
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
|
||||||
|
@ -207,7 +203,7 @@ class ReplicationEndpoint(object):
|
||||||
|
|
||||||
# If we timed out we probably don't need to worry about backing
|
# If we timed out we probably don't need to worry about backing
|
||||||
# off too much, but lets just wait a little anyway.
|
# off too much, but lets just wait a little anyway.
|
||||||
yield clock.sleep(1)
|
await clock.sleep(1)
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
# We convert to SynapseError as we know that it was a SynapseError
|
# We convert to SynapseError as we know that it was a SynapseError
|
||||||
# on the master process that we should send to the client. (And
|
# on the master process that we should send to the client. (And
|
||||||
|
|
|
@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(user_id):
|
async def _serialize_payload(user_id):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request(self, request, user_id):
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.events import make_event_from_dict
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
self.federation_handler = hs.get_handlers().federation_handler
|
self.federation_handler = hs.get_handlers().federation_handler
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@defer.inlineCallbacks
|
async def _serialize_payload(store, event_and_contexts, backfilled):
|
||||||
def _serialize_payload(store, event_and_contexts, backfilled):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
store
|
store
|
||||||
|
@ -78,9 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
"""
|
"""
|
||||||
event_payloads = []
|
event_payloads = []
|
||||||
for event, context in event_and_contexts:
|
for event, context in event_and_contexts:
|
||||||
serialized_context = yield defer.ensureDeferred(
|
serialized_context = await context.serialize(event, store)
|
||||||
context.serialize(event, store)
|
|
||||||
)
|
|
||||||
|
|
||||||
event_payloads.append(
|
event_payloads.append(
|
||||||
{
|
{
|
||||||
|
@ -156,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
||||||
self.registry = hs.get_federation_registry()
|
self.registry = hs.get_federation_registry()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(edu_type, origin, content):
|
async def _serialize_payload(edu_type, origin, content):
|
||||||
return {"origin": origin, "content": content}
|
return {"origin": origin, "content": content}
|
||||||
|
|
||||||
async def _handle_request(self, request, edu_type):
|
async def _handle_request(self, request, edu_type):
|
||||||
|
@ -199,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
||||||
self.registry = hs.get_federation_registry()
|
self.registry = hs.get_federation_registry()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(query_type, args):
|
async def _serialize_payload(query_type, args):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
query_type (str)
|
query_type (str)
|
||||||
|
@ -240,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(room_id, args):
|
async def _serialize_payload(room_id, args):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id (str)
|
||||||
|
@ -275,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(room_id, room_version):
|
async def _serialize_payload(room_id, room_version):
|
||||||
return {"room_version": room_version.identifier}
|
return {"room_version": room_version.identifier}
|
||||||
|
|
||||||
async def _handle_request(self, request, room_id):
|
async def _handle_request(self, request, room_id):
|
||||||
|
|
|
@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
|
async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
device_id (str|None): Device ID to use, if None a new one is
|
device_id (str|None): Device ID to use, if None a new one is
|
||||||
|
|
|
@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
|
async def _serialize_payload(
|
||||||
|
requester, room_id, user_id, remote_room_hosts, content
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
requester(Requester)
|
requester(Requester)
|
||||||
|
@ -112,7 +114,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||||
self.member_handler = hs.get_room_member_handler()
|
self.member_handler = hs.get_room_member_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload( # type: ignore
|
async def _serialize_payload( # type: ignore
|
||||||
invite_event_id: str,
|
invite_event_id: str,
|
||||||
txn_id: Optional[str],
|
txn_id: Optional[str],
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
|
@ -174,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(room_id, user_id, change):
|
async def _serialize_payload(room_id, user_id, change):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id (str)
|
||||||
|
|
|
@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
|
||||||
self._presence_handler = hs.get_presence_handler()
|
self._presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(user_id):
|
async def _serialize_payload(user_id):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request(self, request, user_id):
|
||||||
|
@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||||
self._presence_handler = hs.get_presence_handler()
|
self._presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(user_id, state, ignore_status_msg=False):
|
async def _serialize_payload(user_id, state, ignore_status_msg=False):
|
||||||
return {
|
return {
|
||||||
"state": state,
|
"state": state,
|
||||||
"ignore_status_msg": ignore_status_msg,
|
"ignore_status_msg": ignore_status_msg,
|
||||||
|
|
|
@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(
|
async def _serialize_payload(
|
||||||
user_id,
|
user_id,
|
||||||
password_hash,
|
password_hash,
|
||||||
was_guest,
|
was_guest,
|
||||||
|
@ -105,7 +105,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(user_id, auth_result, access_token):
|
async def _serialize_payload(user_id, auth_result, access_token):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID that consented
|
user_id (str): The user ID that consented
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.events import make_event_from_dict
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@defer.inlineCallbacks
|
async def _serialize_payload(
|
||||||
def _serialize_payload(
|
|
||||||
event_id, store, event, context, requester, ratelimit, extra_users
|
event_id, store, event, context, requester, ratelimit, extra_users
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
extra_users (list(UserID)): Any extra users to notify about event
|
extra_users (list(UserID)): Any extra users to notify about event
|
||||||
"""
|
"""
|
||||||
|
|
||||||
serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
|
serialized_context = await context.serialize(event, store)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"event": event.get_pdu_json(),
|
"event": event.get_pdu_json(),
|
||||||
|
|
|
@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||||
self.streams = hs.get_replication_streams()
|
self.streams = hs.get_replication_streams()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(stream_name, from_token, upto_token):
|
async def _serialize_payload(stream_name, from_token, upto_token):
|
||||||
return {"from_token": from_token, "upto_token": upto_token}
|
return {"from_token": from_token, "upto_token": upto_token}
|
||||||
|
|
||||||
async def _handle_request(self, request, stream_name):
|
async def _handle_request(self, request, stream_name):
|
||||||
|
|
Loading…
Reference in New Issue