Merge pull request #1110 from matrix-org/markjh/e2e_timeout
Add a timeout parameter for end2end key queries.pull/1112/head
commit
76b09c29b0
|
@ -176,7 +176,7 @@ class FederationClient(FederationBase):
|
|||
)
|
||||
|
||||
@log_function
|
||||
def query_client_keys(self, destination, content):
|
||||
def query_client_keys(self, destination, content, timeout):
|
||||
"""Query device keys for a device hosted on a remote server.
|
||||
|
||||
Args:
|
||||
|
@ -188,10 +188,12 @@ class FederationClient(FederationBase):
|
|||
response
|
||||
"""
|
||||
sent_queries_counter.inc("client_device_keys")
|
||||
return self.transport_layer.query_client_keys(destination, content)
|
||||
return self.transport_layer.query_client_keys(
|
||||
destination, content, timeout
|
||||
)
|
||||
|
||||
@log_function
|
||||
def claim_client_keys(self, destination, content):
|
||||
def claim_client_keys(self, destination, content, timeout):
|
||||
"""Claims one-time keys for a device hosted on a remote server.
|
||||
|
||||
Args:
|
||||
|
@ -203,7 +205,9 @@ class FederationClient(FederationBase):
|
|||
response
|
||||
"""
|
||||
sent_queries_counter.inc("client_one_time_keys")
|
||||
return self.transport_layer.claim_client_keys(destination, content)
|
||||
return self.transport_layer.claim_client_keys(
|
||||
destination, content, timeout
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
|
|
@ -298,7 +298,7 @@ class TransportLayerClient(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def query_client_keys(self, destination, query_content):
|
||||
def query_client_keys(self, destination, query_content, timeout):
|
||||
"""Query the device keys for a list of user ids hosted on a remote
|
||||
server.
|
||||
|
||||
|
@ -327,12 +327,13 @@ class TransportLayerClient(object):
|
|||
destination=destination,
|
||||
path=path,
|
||||
data=query_content,
|
||||
timeout=timeout,
|
||||
)
|
||||
defer.returnValue(content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def claim_client_keys(self, destination, query_content):
|
||||
def claim_client_keys(self, destination, query_content, timeout):
|
||||
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||
|
||||
Request:
|
||||
|
@ -363,6 +364,7 @@ class TransportLayerClient(object):
|
|||
destination=destination,
|
||||
path=path,
|
||||
data=query_content,
|
||||
timeout=timeout,
|
||||
)
|
||||
defer.returnValue(content)
|
||||
|
||||
|
|
|
@ -13,14 +13,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api import errors
|
||||
import synapse.types
|
||||
from synapse.api.errors import SynapseError, CodeMessageException
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -30,7 +30,6 @@ class E2eKeysHandler(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.server_name = hs.hostname
|
||||
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
# query request requires an object POST, but we abuse the
|
||||
|
@ -40,7 +39,7 @@ class E2eKeysHandler(object):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_devices(self, query_body):
|
||||
def query_devices(self, query_body, timeout):
|
||||
""" Handle a device key query from a client
|
||||
|
||||
{
|
||||
|
@ -63,27 +62,50 @@ class E2eKeysHandler(object):
|
|||
|
||||
# separate users by domain.
|
||||
# make a map from domain to user_id to device_ids
|
||||
queries_by_domain = collections.defaultdict(dict)
|
||||
local_query = {}
|
||||
remote_queries = {}
|
||||
|
||||
for user_id, device_ids in device_keys_query.items():
|
||||
user = synapse.types.UserID.from_string(user_id)
|
||||
queries_by_domain[user.domain][user_id] = device_ids
|
||||
if self.is_mine_id(user_id):
|
||||
local_query[user_id] = device_ids
|
||||
else:
|
||||
domain = get_domain_from_id(user_id)
|
||||
remote_queries.setdefault(domain, {})[user_id] = device_ids
|
||||
|
||||
# do the queries
|
||||
# TODO: do these in parallel
|
||||
failures = {}
|
||||
results = {}
|
||||
for destination, destination_query in queries_by_domain.items():
|
||||
if destination == self.server_name:
|
||||
res = yield self.query_local_devices(destination_query)
|
||||
else:
|
||||
res = yield self.federation.query_client_keys(
|
||||
destination, {"device_keys": destination_query}
|
||||
)
|
||||
res = res["device_keys"]
|
||||
for user_id, keys in res.items():
|
||||
if user_id in destination_query:
|
||||
if local_query:
|
||||
local_result = yield self.query_local_devices(local_query)
|
||||
for user_id, keys in local_result.items():
|
||||
if user_id in local_query:
|
||||
results[user_id] = keys
|
||||
|
||||
defer.returnValue((200, {"device_keys": results}))
|
||||
@defer.inlineCallbacks
|
||||
def do_remote_query(destination):
|
||||
destination_query = remote_queries[destination]
|
||||
try:
|
||||
remote_result = yield self.federation.query_client_keys(
|
||||
destination,
|
||||
{"device_keys": destination_query},
|
||||
timeout=timeout
|
||||
)
|
||||
for user_id, keys in remote_result["device_keys"].items():
|
||||
if user_id in destination_query:
|
||||
results[user_id] = keys
|
||||
except CodeMessageException as e:
|
||||
failures[destination] = {
|
||||
"status": e.code, "message": e.message
|
||||
}
|
||||
|
||||
yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(do_remote_query)(destination)
|
||||
for destination in remote_queries
|
||||
]))
|
||||
|
||||
defer.returnValue((200, {
|
||||
"device_keys": results, "failures": failures,
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_local_devices(self, query):
|
||||
|
@ -104,7 +126,7 @@ class E2eKeysHandler(object):
|
|||
if not self.is_mine_id(user_id):
|
||||
logger.warning("Request for keys for non-local user %s",
|
||||
user_id)
|
||||
raise errors.SynapseError(400, "Not a user here")
|
||||
raise SynapseError(400, "Not a user here")
|
||||
|
||||
if not device_ids:
|
||||
local_query.append((user_id, None))
|
||||
|
|
|
@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, destination, path, data={}, json_data_callback=None,
|
||||
long_retries=False):
|
||||
long_retries=False, timeout=None):
|
||||
""" Sends the specifed json data using PUT
|
||||
|
||||
Args:
|
||||
|
@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
|
|||
use as the request body.
|
||||
long_retries (bool): A boolean that indicates whether we should
|
||||
retry for a short or long time.
|
||||
timeout(int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
|
@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
|
|||
body_callback=body_callback,
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
long_retries=long_retries,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
|
@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
|
|||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json(self, destination, path, data={}, long_retries=True):
|
||||
def post_json(self, destination, path, data={}, long_retries=True,
|
||||
timeout=None):
|
||||
""" Sends the specifed json data using POST
|
||||
|
||||
Args:
|
||||
|
@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
|
|||
the request body. This will be encoded as JSON.
|
||||
long_retries (bool): A boolean that indicates whether we should
|
||||
retry for a short or long time.
|
||||
timeout(int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
|
@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
|
|||
body_callback=body_callback,
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
long_retries=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
|
|
|
@ -19,11 +19,12 @@ import simplejson as json
|
|||
from canonicaljson import encode_canonical_json
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import synapse.server
|
||||
import synapse.types
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import SynapseError, CodeMessageException
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_json_object_from_request, parse_integer
|
||||
)
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
|
|||
device_id = requester.device_id
|
||||
|
||||
if device_id is None:
|
||||
raise synapse.api.errors.SynapseError(
|
||||
raise SynapseError(
|
||||
400,
|
||||
"To upload keys, you must pass device_id when authenticating"
|
||||
)
|
||||
|
@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id, device_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = yield self.e2e_keys_handler.query_devices(body)
|
||||
result = yield self.e2e_keys_handler.query_devices(body, timeout)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
auth_user_id = requester.user.to_string()
|
||||
user_id = user_id if user_id else auth_user_id
|
||||
device_ids = [device_id] if device_id else []
|
||||
result = yield self.e2e_keys_handler.query_devices(
|
||||
{"device_keys": {user_id: device_ids}}
|
||||
{"device_keys": {user_id: device_ids}},
|
||||
timeout,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
|
@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.is_mine = hs.is_mine
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, device_id, algorithm):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
result = yield self.handle_request(
|
||||
{"one_time_keys": {user_id: {device_id: algorithm}}}
|
||||
{"one_time_keys": {user_id: {device_id: algorithm}}},
|
||||
timeout,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id, device_id, algorithm):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = yield self.handle_request(body)
|
||||
result = yield self.handle_request(body, timeout)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_request(self, body):
|
||||
def handle_request(self, body, timeout):
|
||||
local_query = []
|
||||
remote_queries = {}
|
||||
|
||||
for user_id, device_keys in body.get("one_time_keys", {}).items():
|
||||
user = UserID.from_string(user_id)
|
||||
if self.is_mine(user):
|
||||
if self.is_mine_id(user_id):
|
||||
for device_id, algorithm in device_keys.items():
|
||||
local_query.append((user_id, device_id, algorithm))
|
||||
else:
|
||||
remote_queries.setdefault(user.domain, {})[user_id] = (
|
||||
device_keys
|
||||
)
|
||||
domain = get_domain_from_id(user_id)
|
||||
remote_queries.setdefault(domain, {})[user_id] = device_keys
|
||||
|
||||
results = yield self.store.claim_e2e_one_time_keys(local_query)
|
||||
|
||||
json_result = {}
|
||||
failures = {}
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, json_bytes in keys.items():
|
||||
|
@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
|
|||
key_id: json.loads(json_bytes)
|
||||
}
|
||||
|
||||
for destination, device_keys in remote_queries.items():
|
||||
@defer.inlineCallbacks
|
||||
def claim_client_keys(destination):
|
||||
device_keys = remote_queries[destination]
|
||||
try:
|
||||
remote_result = yield self.federation.claim_client_keys(
|
||||
destination, {"one_time_keys": device_keys}
|
||||
destination,
|
||||
{"one_time_keys": device_keys},
|
||||
timeout=timeout
|
||||
)
|
||||
for user_id, keys in remote_result["one_time_keys"].items():
|
||||
if user_id in device_keys:
|
||||
json_result[user_id] = keys
|
||||
except CodeMessageException as e:
|
||||
failures[destination] = {
|
||||
"status": e.code, "message": e.message
|
||||
}
|
||||
|
||||
defer.returnValue((200, {"one_time_keys": json_result}))
|
||||
yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(claim_client_keys)(destination)
|
||||
for destination in remote_queries
|
||||
]))
|
||||
|
||||
defer.returnValue((200, {
|
||||
"one_time_keys": json_result,
|
||||
"failures": failures
|
||||
}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
|
Loading…
Reference in New Issue