E2E keys: Make federation query share code with client query

Refactor the e2e query handler to separate out the local query, and then make
the federation handler use it.
pull/972/head
Richard van der Hoff 2016-08-02 18:06:31 +01:00
parent 986615b0b2
commit 1efee2f52b
3 changed files with 92 additions and 47 deletions

View File

@ -348,27 +348,9 @@ class FederationServer(FederationBase):
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
@log_function @log_function
def on_query_client_keys(self, origin, content): def on_query_client_keys(self, origin, content):
query = [] return self.on_query_request("client_keys", content)
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function

View File

@ -367,10 +367,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
class FederationClientKeysQueryServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query" PATH = "/user/keys/query"
@defer.inlineCallbacks
def on_POST(self, origin, content, query): def on_POST(self, origin, content, query):
response = yield self.handler.on_query_client_keys(origin, content) return self.handler.on_query_client_keys(origin, content)
defer.returnValue((200, response))
class FederationClientKeysClaimServlet(BaseFederationServlet): class FederationClientKeysClaimServlet(BaseFederationServlet):

View File

@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import json import json
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api import errors
import synapse.types import synapse.types
from ._base import BaseHandler from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,39 +32,101 @@ class E2eKeysHandler(BaseHandler):
super(E2eKeysHandler, self).__init__(hs) super(E2eKeysHandler, self).__init__(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
self.federation.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_devices(self, query_body): def query_devices(self, query_body):
local_query = [] """ Handle a device key query from a client
remote_queries = {}
for user_id, device_ids in query_body.get("device_keys", {}).items(): {
"device_keys": {
"<user_id>": ["<device_id>"]
}
}
->
{
"device_keys": {
"<user_id>": {
"<device_id>": {
...
}
}
}
}
"""
device_keys_query = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
queries_by_domain = collections.defaultdict(dict)
for user_id, device_ids in device_keys_query.items():
user = synapse.types.UserID.from_string(user_id) user = synapse.types.UserID.from_string(user_id)
if self.is_mine(user): queries_by_domain[user.domain][user_id] = device_ids
if not device_ids:
local_query.append((user_id, None)) # do the queries
else: # TODO: do these in parallel
for device_id in device_ids: results = {}
local_query.append((user_id, device_id)) for destination, destination_query in queries_by_domain.items():
if destination == self.hs.hostname:
res = yield self.query_local_devices(destination_query)
else: else:
remote_queries.setdefault(user.domain, {})[user_id] = list( res = yield self.federation.query_client_keys(
device_ids destination, {"device_keys": destination_query}
) )
res = res["device_keys"]
for user_id, keys in res.items():
if user_id in destination_query:
results[user_id] = keys
defer.returnValue((200, {"device_keys": results}))
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None): map from user_id to a list
of devices to query (None for all devices)
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
local_query = []
for user_id, device_ids in query.items():
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
raise errors.SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(local_query) results = yield self.store.get_e2e_device_keys(local_query)
json_result = {} # un-jsonify the results
json_result = collections.defaultdict(dict)
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items(): for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[ json_result[user_id][device_id] = json.loads(json_bytes)
device_id] = json.loads(
json_bytes
)
for destination, device_keys in remote_queries.items(): defer.returnValue(json_result)
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys} @defer.inlineCallbacks
) def on_federation_query_client_keys(self, query_body):
for user_id, keys in remote_result["device_keys"].items(): """ Handle a device key query from a federated server
if user_id in device_keys: """
json_result[user_id] = keys device_keys_query = query_body.get("device_keys", {})
defer.returnValue((200, {"device_keys": json_result})) res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})