Merge pull request #1110 from matrix-org/markjh/e2e_timeout

Add a timeout parameter for end2end key queries.
pull/1112/head
Mark Haines 2016-09-13 10:50:45 +01:00 committed by GitHub
commit 76b09c29b0
5 changed files with 114 additions and 54 deletions

View File

@ -176,7 +176,7 @@ class FederationClient(FederationBase):
)
@log_function
def query_client_keys(self, destination, content):
def query_client_keys(self, destination, content, timeout):
"""Query device keys for a device hosted on a remote server.
Args:
@ -188,10 +188,12 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.inc("client_device_keys")
return self.transport_layer.query_client_keys(destination, content)
return self.transport_layer.query_client_keys(
destination, content, timeout
)
@log_function
def claim_client_keys(self, destination, content):
def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server.
Args:
@ -203,7 +205,9 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.inc("client_one_time_keys")
return self.transport_layer.claim_client_keys(destination, content)
return self.transport_layer.claim_client_keys(
destination, content, timeout
)
@defer.inlineCallbacks
@log_function

View File

@ -298,7 +298,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def query_client_keys(self, destination, query_content):
def query_client_keys(self, destination, query_content, timeout):
"""Query the device keys for a list of user ids hosted on a remote
server.
@ -327,12 +327,13 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=query_content,
timeout=timeout,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content):
def claim_client_keys(self, destination, query_content, timeout):
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
@ -363,6 +364,7 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=query_content,
timeout=timeout,
)
defer.returnValue(content)

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import logging
from twisted.internet import defer
from synapse.api import errors
import synapse.types
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
logger = logging.getLogger(__name__)
@ -30,7 +30,6 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
@ -40,7 +39,7 @@ class E2eKeysHandler(object):
)
@defer.inlineCallbacks
def query_devices(self, query_body):
def query_devices(self, query_body, timeout):
""" Handle a device key query from a client
{
@ -63,27 +62,50 @@ class E2eKeysHandler(object):
# separate users by domain.
# make a map from domain to user_id to device_ids
queries_by_domain = collections.defaultdict(dict)
local_query = {}
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
user = synapse.types.UserID.from_string(user_id)
queries_by_domain[user.domain][user_id] = device_ids
if self.is_mine_id(user_id):
local_query[user_id] = device_ids
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries
# TODO: do these in parallel
failures = {}
results = {}
for destination, destination_query in queries_by_domain.items():
if destination == self.server_name:
res = yield self.query_local_devices(destination_query)
else:
res = yield self.federation.query_client_keys(
destination, {"device_keys": destination_query}
)
res = res["device_keys"]
for user_id, keys in res.items():
if user_id in destination_query:
if local_query:
local_result = yield self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
defer.returnValue((200, {"device_keys": results}))
@defer.inlineCallbacks
def do_remote_query(destination):
destination_query = remote_queries[destination]
try:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
results[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries
]))
defer.returnValue((200, {
"device_keys": results, "failures": failures,
}))
@defer.inlineCallbacks
def query_local_devices(self, query):
@ -104,7 +126,7 @@ class E2eKeysHandler(object):
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
raise errors.SynapseError(400, "Not a user here")
raise SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))

View File

@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None,
long_retries=False):
long_retries=False, timeout=None):
""" Sends the specifed json data using PUT
Args:
@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
use as the request body.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
timeout=timeout,
)
if 200 <= response.code < 300:
@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=True):
def post_json(self, destination, path, data={}, long_retries=True,
timeout=None):
""" Sends the specifed json data using POST
Args:
@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=True,
timeout=timeout,
)
if 200 <= response.code < 300:

View File

@ -19,11 +19,12 @@ import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
import synapse.api.errors
import synapse.server
import synapse.types
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, parse_integer
)
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
device_id = requester.device_id
if device_id is None:
raise synapse.api.errors.SynapseError(
raise SynapseError(
400,
"To upload keys, you must pass device_id when authenticating"
)
@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body)
result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue(result)
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
requester = yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.e2e_keys_handler.query_devices(
{"device_keys": {user_id: device_ids}}
{"device_keys": {user_id: device_ids}},
timeout,
)
defer.returnValue(result)
@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
result = yield self.handle_request(
{"one_time_keys": {user_id: {device_id: algorithm}}}
{"one_time_keys": {user_id: {device_id: algorithm}}},
timeout,
)
defer.returnValue(result)
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.handle_request(body)
result = yield self.handle_request(body, timeout)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_request(self, body):
def handle_request(self, body, timeout):
local_query = []
remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
remote_queries.setdefault(user.domain, {})[user_id] = (
device_keys
)
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
key_id: json.loads(json_bytes)
}
for destination, device_keys in remote_queries.items():
@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
defer.returnValue((200, {"one_time_keys": json_result}))
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
defer.returnValue((200, {
"one_time_keys": json_result,
"failures": failures
}))
def register_servlets(hs, http_server):