Implement a batch API for verify_json_objects_for_server

erikj/persist_event_perf
Erik Johnston 2015-06-24 11:21:35 +01:00
parent f859e3ca37
commit a29319fefa
4 changed files with 319 additions and 182 deletions

View File

@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from OpenSSL import crypto from OpenSSL import crypto
from collections import namedtuple
import urllib import urllib
import hashlib import hashlib
import logging import logging
@ -38,6 +38,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -49,8 +52,20 @@ class Keyring(object):
self.key_downloads = {} self.key_downloads = {}
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
return self.verify_json_objects_for_server(
[(server_name, json_object)]
)[0]
def verify_json_objects_for_server(self, server_and_json):
server_to_key_groupings = {}
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name) logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
if not key_ids: if not key_ids:
@ -59,8 +74,22 @@ class Keyring(object):
"Not signed with a supported algorithm", "Not signed with a supported algorithm",
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
group = KeyGroup(server_name, group_id, key_ids)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
server_to_key_groupings.setdefault(server_name, []).append(group)
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
try: try:
verify_key = yield self.get_server_verify_key(server_name, key_ids) _, _, key_id, verify_key = yield deferred
except IOError as e: except IOError as e:
logger.warn( logger.warn(
"Got IOError when downloading keys for %s: %s %s", "Got IOError when downloading keys for %s: %s %s",
@ -72,7 +101,7 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
except Exception as e: except Exception as e:
logger.warn( logger.exception(
"Got Exception when downloading keys for %s: %s %s", "Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message), server_name, type(e).__name__, str(e.message),
) )
@ -82,6 +111,8 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
json_object = group_id_to_json[group.group_id]
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
except: except:
@ -93,79 +124,154 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
@defer.inlineCallbacks deferreds = self.get_server_verify_keys(
def get_server_verify_key(self, server_name, key_ids): group_id_to_group
"""Finds a verification key for the server with one of the key ids.
Trys to fetch the key from a trusted perspective server first.
Args:
server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached:
defer.returnValue(cached[0])
return
download = self.key_downloads.get(server_name)
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
download = ObservableDeferred(
download,
consumeErrors=True
) )
self.key_downloads[server_name] = download
@download.addBoth logger.info(
def callback(ret): "Deferred count: %d vs. %d",
del self.key_downloads[server_name] len(deferreds.items()),
return ret len(server_and_json)
)
r = yield download.observe() return [
defer.returnValue(r) handle_key_deferred(
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
]
def get_server_verify_keys(self, group_id_to_group):
merged_results = {}
fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
group_deferreds = {
group_id: defer.Deferred()
for group_id in group_id_to_group
}
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids): def do_iterations():
keys = None missing_keys = {
group.server_name: key_id
for group in group_id_to_group.values()
for key_id in group.key_ids
}
for fn in fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_deferreds.pop(group.group_id).callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_deferreds.pop(group.group_id).errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_deferreds.values():
deferred.errback(err)
group_deferreds.clear()
do_iterations().addErrback(on_err)
return group_deferreds
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
[
self.store.get_server_verify_keys(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(dict(zip(
[server_name for server_name, _ in server_name_and_key_ids],
res
)))
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys server_name_and_key_ids, perspective_name, perspective_keys
) )
defer.returnValue(result) defer.returnValue(result)
except Exception as e: except Exception as e:
logging.info( logger.info(
"Unable to getting key %r for %r from %r: %s %s", "Unable to get key from %r: %s %s",
key_ids, server_name, perspective_name, perspective_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
perspective_results = yield defer.gatherResults([ results = yield defer.gatherResults([
get_key(p_name, p_keys) get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items() for p_name, p_keys in self.perspective_servers.items()
]) ])
for results in perspective_results: union_of_keys = {}
if results is not None: for result in results:
keys = results for server_name, keys in results.items():
union_of_keys.setdefault(server_name, {}).update(keys)
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
server_name, server_name,
self.clock, self.clock,
self.store, self.store,
) )
with limiter: with limiter:
if not keys: keys = None
try: try:
keys = yield self.get_server_verify_key_v2_direct( keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids server_name, key_ids
) )
except Exception as e: except Exception as e:
logging.info( logger.info(
"Unable to getting key %r for %r directly: %s %s", "Unable to getting key %r for %r directly: %s %s",
key_ids, server_name, key_ids, server_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
@ -176,14 +282,27 @@ class Keyring(object):
server_name, key_ids server_name, key_ids
) )
for key_id in key_ids: keys = {server_name: keys}
if key_id in keys:
defer.returnValue(keys[key_id]) defer.returnValue(keys)
return
raise ValueError("No verification key found for given key ids") results = yield defer.gatherResults([
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
])
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_name, key_ids, def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name, perspective_name,
perspective_keys): perspective_keys):
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
@ -204,6 +323,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0 u"minimum_valid_until_ts": 0
} for key_id in key_ids } for key_id in key_ids
} }
for server_name, key_ids in server_names_and_key_ids
} }
}, },
) )
@ -243,12 +363,14 @@ class Keyring(object):
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
response_keys = yield self.process_v2_response( processed_response = yield self.process_v2_response(
server_name, perspective_name, response perspective_name, response
) )
keys.update(response_keys) for server_name, response_keys in processed_response:
keys.setdefault(server_name, {}).update(response_keys)
for server_name, response_keys in keys.items():
yield self.store_keys( yield self.store_keys(
server_name=server_name, server_name=server_name,
from_server=perspective_name, from_server=perspective_name,
@ -259,7 +381,6 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} keys = {}
for requested_key_id in key_ids: for requested_key_id in key_ids:
@ -295,25 +416,25 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints") raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name, from_server=server_name,
requested_id=requested_key_id, requested_ids=[requested_key_id],
response_json=response, response_json=response,
) )
keys.update(response_keys) keys.update(response_keys)
for server_name, verify_keys in keys.items():
yield self.store_keys( yield self.store_keys(
server_name=server_name, server_name=server_name,
from_server=server_name, from_server=server_name,
verify_keys=keys, verify_keys=verify_keys,
) )
defer.returnValue(keys) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, response_json, def process_v2_response(self, from_server, response_json,
requested_id=None): requested_ids=[]):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
@ -335,7 +456,9 @@ class Keyring(object):
verify_key.time_added = time_now_ms verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key old_verify_keys[key_id] = verify_key
for key_id in response_json["signatures"].get(server_name, {}): results = {}
for server_name, keys_dict in response_json["signatures"].items():
for key_id in keys_dict:
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise ValueError( raise ValueError(
"Key response must include verification keys for all" "Key response must include verification keys for all"
@ -357,9 +480,7 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"] ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set() updated_key_ids = set(requested_ids)
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys) updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys) updated_key_ids.update(old_verify_keys)
@ -376,9 +497,9 @@ class Keyring(object):
key_json_bytes=signed_key_json_bytes, key_json_bytes=signed_key_json_bytes,
) )
defer.returnValue(response_keys) results[server_name] = response_keys
raise ValueError("No verification key found for given key ids") defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids): def get_server_verify_key_v1_direct(self, server_name, key_ids):

View File

@ -99,35 +99,50 @@ class FederationBase(object):
defer.returnValue(signed_pdus) defer.returnValue(signed_pdus)
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu): def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct return self._check_sigs_and_hashes([pdu])[0]
def _check_sigs_and_hashes(self, pdus):
"""Throws a SynapseError if a PDU does not have the correct
signatures. signatures.
Returns: Returns:
FrozenEvent: Either the given event or it redacted if it failed the FrozenEvent: Either the given event or it redacted if it failed the
content hash check. content hash check.
""" """
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try: redacted_pdus = [
yield self.keyring.verify_json_for_server( prune_event(pdu)
pdu.origin, redacted_pdu_json for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
def callback(_, pdu, redacted):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
) )
except SynapseError: return redacted
return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
logger.warn( logger.warn(
"Signature check failed for %s", "Signature check failed for %s",
pdu.event_id, pdu.event_id,
) )
raise return failure
if not check_event_content_hash(pdu): for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
logger.warn( deferred.addCallbacks(
"Event content has been tampered, redacting.", callback, errback,
pdu.event_id, callbackArgs=[pdu, redacted],
errbackArgs=[pdu],
) )
defer.returnValue(redacted_event)
defer.returnValue(pdu) return deferreds

View File

@ -166,10 +166,7 @@ class FederationClient(FederationBase):
] ]
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults( pdus[:] = yield self._check_sigs_and_hashes(pdus)
[self._check_sigs_and_hash(pdu) for pdu in pdus],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(pdus) defer.returnValue(pdus)
@ -230,7 +227,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hashes([pdu])[0]
break break
@ -402,7 +399,7 @@ class FederationClient(FederationBase):
except CodeMessageException: except CodeMessageException:
raise raise
except Exception as e: except Exception as e:
logger.warn( logger.exception(
"Failed to send_join via %s: %s", "Failed to send_join via %s: %s",
destination, e.message destination, e.message
) )

View File

@ -101,7 +101,11 @@ class KeyStore(SQLBaseStore):
(list of VerifyKey): The verification keys. (list of VerifyKey): The verification keys.
""" """
keys = yield self.get_all_server_verify_keys(server_name) keys = yield self.get_all_server_verify_keys(server_name)
defer.returnValue([keys[k] for k in key_ids if k in keys]) defer.returnValue({
k: keys[k]
for k in key_ids
if k in keys and keys[k]
})
@defer.inlineCallbacks @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms, def store_server_verify_key(self, server_name, from_server, time_now_ms,