Enforce validity period on server_keys for fed requests. (#5321)
When handling incoming federation requests, make sure that we have an up-to-date copy of the signing key. We do not yet enforce the validity period for event signatures.pull/5334/head
parent
fe2294ec8d
commit
fec2dcb1a5
|
@ -0,0 +1 @@
|
||||||
|
Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.
|
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from six import raise_from
|
from six import raise_from
|
||||||
|
@ -70,6 +71,9 @@ class VerifyKeyRequest(object):
|
||||||
|
|
||||||
json_object(dict): The JSON object to verify.
|
json_object(dict): The JSON object to verify.
|
||||||
|
|
||||||
|
minimum_valid_until_ts (int): time at which we require the signing key to
|
||||||
|
be valid. (0 implies we don't care)
|
||||||
|
|
||||||
deferred(Deferred[str, str, nacl.signing.VerifyKey]):
|
deferred(Deferred[str, str, nacl.signing.VerifyKey]):
|
||||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||||
a verify key has been fetched. The deferreds' callbacks are run with no
|
a verify key has been fetched. The deferreds' callbacks are run with no
|
||||||
|
@ -82,7 +86,8 @@ class VerifyKeyRequest(object):
|
||||||
server_name = attr.ib()
|
server_name = attr.ib()
|
||||||
key_ids = attr.ib()
|
key_ids = attr.ib()
|
||||||
json_object = attr.ib()
|
json_object = attr.ib()
|
||||||
deferred = attr.ib()
|
minimum_valid_until_ts = attr.ib()
|
||||||
|
deferred = attr.ib(default=attr.Factory(defer.Deferred))
|
||||||
|
|
||||||
|
|
||||||
class KeyLookupError(ValueError):
|
class KeyLookupError(ValueError):
|
||||||
|
@ -90,14 +95,16 @@ class KeyLookupError(ValueError):
|
||||||
|
|
||||||
|
|
||||||
class Keyring(object):
|
class Keyring(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs, key_fetchers=None):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
self._key_fetchers = (
|
if key_fetchers is None:
|
||||||
StoreKeyFetcher(hs),
|
key_fetchers = (
|
||||||
PerspectivesKeyFetcher(hs),
|
StoreKeyFetcher(hs),
|
||||||
ServerKeyFetcher(hs),
|
PerspectivesKeyFetcher(hs),
|
||||||
)
|
ServerKeyFetcher(hs),
|
||||||
|
)
|
||||||
|
self._key_fetchers = key_fetchers
|
||||||
|
|
||||||
# map from server name to Deferred. Has an entry for each server with
|
# map from server name to Deferred. Has an entry for each server with
|
||||||
# an ongoing key download; the Deferred completes once the download
|
# an ongoing key download; the Deferred completes once the download
|
||||||
|
@ -106,9 +113,25 @@ class Keyring(object):
|
||||||
# These are regular, logcontext-agnostic Deferreds.
|
# These are regular, logcontext-agnostic Deferreds.
|
||||||
self.key_downloads = {}
|
self.key_downloads = {}
|
||||||
|
|
||||||
def verify_json_for_server(self, server_name, json_object):
|
def verify_json_for_server(self, server_name, json_object, validity_time):
|
||||||
|
"""Verify that a JSON object has been signed by a given server
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name (str): name of the server which must have signed this object
|
||||||
|
|
||||||
|
json_object (dict): object to be checked
|
||||||
|
|
||||||
|
validity_time (int): timestamp at which we require the signing key to
|
||||||
|
be valid. (0 implies we don't care)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[None]: completes if the the object was correctly signed, otherwise
|
||||||
|
errbacks with an error
|
||||||
|
"""
|
||||||
|
req = server_name, json_object, validity_time
|
||||||
|
|
||||||
return logcontext.make_deferred_yieldable(
|
return logcontext.make_deferred_yieldable(
|
||||||
self.verify_json_objects_for_server([(server_name, json_object)])[0]
|
self.verify_json_objects_for_server((req,))[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
def verify_json_objects_for_server(self, server_and_json):
|
def verify_json_objects_for_server(self, server_and_json):
|
||||||
|
@ -116,10 +139,12 @@ class Keyring(object):
|
||||||
necessary.
|
necessary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_and_json (list): List of pairs of (server_name, json_object)
|
server_and_json (iterable[Tuple[str, dict, int]):
|
||||||
|
Iterable of triplets of (server_name, json_object, validity_time)
|
||||||
|
validity_time is a timestamp at which the signing key must be valid.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List<Deferred>: for each input pair, a deferred indicating success
|
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
||||||
or failure to verify each json object's signature for the given
|
or failure to verify each json object's signature for the given
|
||||||
server_name. The deferreds run their callbacks in the sentinel
|
server_name. The deferreds run their callbacks in the sentinel
|
||||||
logcontext.
|
logcontext.
|
||||||
|
@ -128,12 +153,12 @@ class Keyring(object):
|
||||||
verify_requests = []
|
verify_requests = []
|
||||||
handle = preserve_fn(_handle_key_deferred)
|
handle = preserve_fn(_handle_key_deferred)
|
||||||
|
|
||||||
def process(server_name, json_object):
|
def process(server_name, json_object, validity_time):
|
||||||
"""Process an entry in the request list
|
"""Process an entry in the request list
|
||||||
|
|
||||||
Given a (server_name, json_object) pair from the request list,
|
Given a (server_name, json_object, validity_time) triplet from the request
|
||||||
adds a key request to verify_requests, and returns a deferred which will
|
list, adds a key request to verify_requests, and returns a deferred which
|
||||||
complete or fail (in the sentinel context) when verification completes.
|
will complete or fail (in the sentinel context) when verification completes.
|
||||||
"""
|
"""
|
||||||
key_ids = signature_ids(json_object, server_name)
|
key_ids = signature_ids(json_object, server_name)
|
||||||
|
|
||||||
|
@ -148,7 +173,7 @@ class Keyring(object):
|
||||||
|
|
||||||
# add the key request to the queue, but don't start it off yet.
|
# add the key request to the queue, but don't start it off yet.
|
||||||
verify_request = VerifyKeyRequest(
|
verify_request = VerifyKeyRequest(
|
||||||
server_name, key_ids, json_object, defer.Deferred()
|
server_name, key_ids, json_object, validity_time
|
||||||
)
|
)
|
||||||
verify_requests.append(verify_request)
|
verify_requests.append(verify_request)
|
||||||
|
|
||||||
|
@ -160,8 +185,8 @@ class Keyring(object):
|
||||||
return handle(verify_request)
|
return handle(verify_request)
|
||||||
|
|
||||||
results = [
|
results = [
|
||||||
process(server_name, json_object)
|
process(server_name, json_object, validity_time)
|
||||||
for server_name, json_object in server_and_json
|
for server_name, json_object, validity_time in server_and_json
|
||||||
]
|
]
|
||||||
|
|
||||||
if verify_requests:
|
if verify_requests:
|
||||||
|
@ -298,8 +323,12 @@ class Keyring(object):
|
||||||
verify_request.deferred.errback(
|
verify_request.deferred.errback(
|
||||||
SynapseError(
|
SynapseError(
|
||||||
401,
|
401,
|
||||||
"No key for %s with id %s"
|
"No key for %s with ids in %s (min_validity %i)"
|
||||||
% (verify_request.server_name, verify_request.key_ids),
|
% (
|
||||||
|
verify_request.server_name,
|
||||||
|
verify_request.key_ids,
|
||||||
|
verify_request.minimum_valid_until_ts,
|
||||||
|
),
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -323,18 +352,28 @@ class Keyring(object):
|
||||||
Args:
|
Args:
|
||||||
fetcher (KeyFetcher): fetcher to use to fetch the keys
|
fetcher (KeyFetcher): fetcher to use to fetch the keys
|
||||||
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
|
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
|
||||||
Any successfully-completed requests will be reomved from the list.
|
Any successfully-completed requests will be removed from the list.
|
||||||
"""
|
"""
|
||||||
# dict[str, set(str)]: keys to fetch for each server
|
# dict[str, dict[str, int]]: keys to fetch.
|
||||||
missing_keys = {}
|
# server_name -> key_id -> min_valid_ts
|
||||||
|
missing_keys = defaultdict(dict)
|
||||||
|
|
||||||
for verify_request in remaining_requests:
|
for verify_request in remaining_requests:
|
||||||
# any completed requests should already have been removed
|
# any completed requests should already have been removed
|
||||||
assert not verify_request.deferred.called
|
assert not verify_request.deferred.called
|
||||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
keys_for_server = missing_keys[verify_request.server_name]
|
||||||
verify_request.key_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
results = yield fetcher.get_keys(missing_keys.items())
|
for key_id in verify_request.key_ids:
|
||||||
|
# If we have several requests for the same key, then we only need to
|
||||||
|
# request that key once, but we should do so with the greatest
|
||||||
|
# min_valid_until_ts of the requests, so that we can satisfy all of
|
||||||
|
# the requests.
|
||||||
|
keys_for_server[key_id] = max(
|
||||||
|
keys_for_server.get(key_id, -1),
|
||||||
|
verify_request.minimum_valid_until_ts
|
||||||
|
)
|
||||||
|
|
||||||
|
results = yield fetcher.get_keys(missing_keys)
|
||||||
|
|
||||||
completed = list()
|
completed = list()
|
||||||
for verify_request in remaining_requests:
|
for verify_request in remaining_requests:
|
||||||
|
@ -344,25 +383,34 @@ class Keyring(object):
|
||||||
# complete this VerifyKeyRequest.
|
# complete this VerifyKeyRequest.
|
||||||
result_keys = results.get(server_name, {})
|
result_keys = results.get(server_name, {})
|
||||||
for key_id in verify_request.key_ids:
|
for key_id in verify_request.key_ids:
|
||||||
key = result_keys.get(key_id)
|
fetch_key_result = result_keys.get(key_id)
|
||||||
if key:
|
if not fetch_key_result:
|
||||||
with PreserveLoggingContext():
|
# we didn't get a result for this key
|
||||||
verify_request.deferred.callback(
|
continue
|
||||||
(server_name, key_id, key.verify_key)
|
|
||||||
)
|
if (
|
||||||
completed.append(verify_request)
|
fetch_key_result.valid_until_ts
|
||||||
break
|
< verify_request.minimum_valid_until_ts
|
||||||
|
):
|
||||||
|
# key was not valid at this point
|
||||||
|
continue
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
verify_request.deferred.callback(
|
||||||
|
(server_name, key_id, fetch_key_result.verify_key)
|
||||||
|
)
|
||||||
|
completed.append(verify_request)
|
||||||
|
break
|
||||||
|
|
||||||
remaining_requests.difference_update(completed)
|
remaining_requests.difference_update(completed)
|
||||||
|
|
||||||
|
|
||||||
class KeyFetcher(object):
|
class KeyFetcher(object):
|
||||||
def get_keys(self, server_name_and_key_ids):
|
def get_keys(self, keys_to_fetch):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
|
keys_to_fetch (dict[str, dict[str, int]]):
|
||||||
list of (server_name, iterable[key_id]) tuples to fetch keys for
|
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
||||||
Note that the iterables may be iterated more than once.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
|
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
|
||||||
|
@ -378,13 +426,15 @@ class StoreKeyFetcher(KeyFetcher):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys(self, server_name_and_key_ids):
|
def get_keys(self, keys_to_fetch):
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher.get_keys"""
|
||||||
|
|
||||||
keys_to_fetch = (
|
keys_to_fetch = (
|
||||||
(server_name, key_id)
|
(server_name, key_id)
|
||||||
for server_name, key_ids in server_name_and_key_ids
|
for server_name, keys_for_server in keys_to_fetch.items()
|
||||||
for key_id in key_ids
|
for key_id in keys_for_server.keys()
|
||||||
)
|
)
|
||||||
|
|
||||||
res = yield self.store.get_server_verify_keys(keys_to_fetch)
|
res = yield self.store.get_server_verify_keys(keys_to_fetch)
|
||||||
keys = {}
|
keys = {}
|
||||||
for (server_name, key_id), key in res.items():
|
for (server_name, key_id), key in res.items():
|
||||||
|
@ -508,14 +558,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
self.perspective_servers = self.config.perspectives
|
self.perspective_servers = self.config.perspectives
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys(self, server_name_and_key_ids):
|
def get_keys(self, keys_to_fetch):
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher.get_keys"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_key(perspective_name, perspective_keys):
|
def get_key(perspective_name, perspective_keys):
|
||||||
try:
|
try:
|
||||||
result = yield self.get_server_verify_key_v2_indirect(
|
result = yield self.get_server_verify_key_v2_indirect(
|
||||||
server_name_and_key_ids, perspective_name, perspective_keys
|
keys_to_fetch, perspective_name, perspective_keys
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
except KeyLookupError as e:
|
except KeyLookupError as e:
|
||||||
|
@ -549,13 +599,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_key_v2_indirect(
|
def get_server_verify_key_v2_indirect(
|
||||||
self, server_names_and_key_ids, perspective_name, perspective_keys
|
self, keys_to_fetch, perspective_name, perspective_keys
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
|
keys_to_fetch (dict[str, dict[str, int]]):
|
||||||
list of (server_name, iterable[key_id]) tuples to fetch keys for
|
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
||||||
|
|
||||||
perspective_name (str): name of the notary server to query for the keys
|
perspective_name (str): name of the notary server to query for the keys
|
||||||
|
|
||||||
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
|
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
|
||||||
notary server
|
notary server
|
||||||
|
|
||||||
|
@ -569,12 +621,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
"Requesting keys %s from notary server %s",
|
"Requesting keys %s from notary server %s",
|
||||||
server_names_and_key_ids,
|
keys_to_fetch.items(),
|
||||||
perspective_name,
|
perspective_name,
|
||||||
)
|
)
|
||||||
# TODO(mark): Set the minimum_valid_until_ts to that needed by
|
|
||||||
# the events being validated or the current time if validating
|
|
||||||
# an incoming request.
|
|
||||||
try:
|
try:
|
||||||
query_response = yield self.client.post_json(
|
query_response = yield self.client.post_json(
|
||||||
destination=perspective_name,
|
destination=perspective_name,
|
||||||
|
@ -582,9 +632,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
data={
|
data={
|
||||||
u"server_keys": {
|
u"server_keys": {
|
||||||
server_name: {
|
server_name: {
|
||||||
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
|
key_id: {u"minimum_valid_until_ts": min_valid_ts}
|
||||||
|
for key_id, min_valid_ts in server_keys.items()
|
||||||
}
|
}
|
||||||
for server_name, key_ids in server_names_and_key_ids
|
for server_name, server_keys in keys_to_fetch.items()
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
long_retries=True,
|
long_retries=True,
|
||||||
|
@ -694,15 +745,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
self.client = hs.get_http_client()
|
self.client = hs.get_http_client()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys(self, server_name_and_key_ids):
|
def get_keys(self, keys_to_fetch):
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher.get_keys"""
|
||||||
|
# TODO make this more resilient
|
||||||
results = yield logcontext.make_deferred_yieldable(
|
results = yield logcontext.make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[
|
[
|
||||||
run_in_background(
|
run_in_background(
|
||||||
self.get_server_verify_key_v2_direct, server_name, key_ids
|
self.get_server_verify_key_v2_direct,
|
||||||
|
server_name,
|
||||||
|
server_keys.keys(),
|
||||||
)
|
)
|
||||||
for server_name, key_ids in server_name_and_key_ids
|
for server_name, server_keys in keys_to_fetch.items()
|
||||||
],
|
],
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
@ -721,6 +775,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
keys = {} # type: dict[str, FetchKeyResult]
|
keys = {} # type: dict[str, FetchKeyResult]
|
||||||
|
|
||||||
for requested_key_id in key_ids:
|
for requested_key_id in key_ids:
|
||||||
|
# we may have found this key as a side-effect of asking for another.
|
||||||
if requested_key_id in keys:
|
if requested_key_id in keys:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
|
||||||
]
|
]
|
||||||
|
|
||||||
more_deferreds = keyring.verify_json_objects_for_server([
|
more_deferreds = keyring.verify_json_objects_for_server([
|
||||||
(p.sender_domain, p.redacted_pdu_json)
|
(p.sender_domain, p.redacted_pdu_json, 0)
|
||||||
for p in pdus_to_check_sender
|
for p in pdus_to_check_sender
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@ -298,7 +298,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
|
||||||
]
|
]
|
||||||
|
|
||||||
more_deferreds = keyring.verify_json_objects_for_server([
|
more_deferreds = keyring.verify_json_objects_for_server([
|
||||||
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
|
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
|
||||||
for p in pdus_to_check_event_id
|
for p in pdus_to_check_event_id
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
|
@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
|
||||||
|
|
||||||
class Authenticator(object):
|
class Authenticator(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
self._clock = hs.get_clock()
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
@ -102,6 +103,7 @@ class Authenticator(object):
|
||||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def authenticate_request(self, request, content):
|
def authenticate_request(self, request, content):
|
||||||
|
now = self._clock.time_msec()
|
||||||
json_request = {
|
json_request = {
|
||||||
"method": request.method.decode('ascii'),
|
"method": request.method.decode('ascii'),
|
||||||
"uri": request.uri.decode('ascii'),
|
"uri": request.uri.decode('ascii'),
|
||||||
|
@ -138,7 +140,7 @@ class Authenticator(object):
|
||||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.keyring.verify_json_for_server(origin, json_request)
|
yield self.keyring.verify_json_for_server(origin, json_request, now)
|
||||||
|
|
||||||
logger.info("Request from %s", origin)
|
logger.info("Request from %s", origin)
|
||||||
request.authenticated_entity = origin
|
request.authenticated_entity = origin
|
||||||
|
|
|
@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
|
||||||
|
|
||||||
# TODO: We also want to check that *new* attestations that people give
|
# TODO: We also want to check that *new* attestations that people give
|
||||||
# us to store are valid for at least a little while.
|
# us to store are valid for at least a little while.
|
||||||
if valid_until_ms < self.clock.time_msec():
|
now = self.clock.time_msec()
|
||||||
|
if valid_until_ms < now:
|
||||||
raise SynapseError(400, "Attestation expired")
|
raise SynapseError(400, "Attestation expired")
|
||||||
|
|
||||||
yield self.keyring.verify_json_for_server(server_name, attestation)
|
yield self.keyring.verify_json_for_server(server_name, attestation, now)
|
||||||
|
|
||||||
def create_attestation(self, group_id, user_id):
|
def create_attestation(self, group_id, user_id):
|
||||||
"""Create an attestation for the group_id and user_id with default
|
"""Create an attestation for the group_id and user_id with default
|
||||||
|
|
|
@ -19,6 +19,7 @@ from mock import Mock
|
||||||
import canonicaljson
|
import canonicaljson
|
||||||
import signedjson.key
|
import signedjson.key
|
||||||
import signedjson.sign
|
import signedjson.sign
|
||||||
|
from signedjson.key import get_verify_key
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -137,7 +138,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
context_11.request = "11"
|
context_11.request = "11"
|
||||||
|
|
||||||
res_deferreds = kr.verify_json_objects_for_server(
|
res_deferreds = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1), ("server11", {})]
|
[("server10", json1, 0), ("server11", {}, 0)]
|
||||||
)
|
)
|
||||||
|
|
||||||
# the unsigned json should be rejected pretty quickly
|
# the unsigned json should be rejected pretty quickly
|
||||||
|
@ -174,7 +175,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
self.http_client.post_json.return_value = defer.Deferred()
|
self.http_client.post_json.return_value = defer.Deferred()
|
||||||
|
|
||||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1)]
|
[("server10", json1, 0)]
|
||||||
)
|
)
|
||||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||||
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
|
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
|
||||||
|
@ -197,31 +198,108 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
kr = keyring.Keyring(self.hs)
|
kr = keyring.Keyring(self.hs)
|
||||||
|
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
key1_id = "%s:%s" % (key1.alg, key1.version)
|
|
||||||
|
|
||||||
r = self.hs.datastore.store_server_verify_keys(
|
r = self.hs.datastore.store_server_verify_keys(
|
||||||
"server9",
|
"server9",
|
||||||
time.time() * 1000,
|
time.time() * 1000,
|
||||||
[
|
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
|
||||||
(
|
|
||||||
"server9",
|
|
||||||
key1_id,
|
|
||||||
FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
self.get_success(r)
|
self.get_success(r)
|
||||||
|
|
||||||
json1 = {}
|
json1 = {}
|
||||||
signedjson.sign.sign_json(json1, "server9", key1)
|
signedjson.sign.sign_json(json1, "server9", key1)
|
||||||
|
|
||||||
# should fail immediately on an unsigned object
|
# should fail immediately on an unsigned object
|
||||||
d = _verify_json_for_server(kr, "server9", {})
|
d = _verify_json_for_server(kr, "server9", {}, 0)
|
||||||
self.failureResultOf(d, SynapseError)
|
self.failureResultOf(d, SynapseError)
|
||||||
|
|
||||||
d = _verify_json_for_server(kr, "server9", json1)
|
# should suceed on a signed object
|
||||||
self.assertFalse(d.called)
|
d = _verify_json_for_server(kr, "server9", json1, 500)
|
||||||
|
# self.assertFalse(d.called)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
|
def test_verify_json_dedupes_key_requests(self):
|
||||||
|
"""Two requests for the same key should be deduped."""
|
||||||
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
|
def get_keys(keys_to_fetch):
|
||||||
|
# there should only be one request object (with the max validity)
|
||||||
|
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||||
|
|
||||||
|
return defer.succeed(
|
||||||
|
{
|
||||||
|
"server1": {
|
||||||
|
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_fetcher = keyring.KeyFetcher()
|
||||||
|
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
||||||
|
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
|
||||||
|
|
||||||
|
json1 = {}
|
||||||
|
signedjson.sign.sign_json(json1, "server1", key1)
|
||||||
|
|
||||||
|
# the first request should succeed; the second should fail because the key
|
||||||
|
# has expired
|
||||||
|
results = kr.verify_json_objects_for_server(
|
||||||
|
[("server1", json1, 500), ("server1", json1, 1500)]
|
||||||
|
)
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
self.get_success(results[0])
|
||||||
|
e = self.get_failure(results[1], SynapseError).value
|
||||||
|
self.assertEqual(e.errcode, "M_UNAUTHORIZED")
|
||||||
|
self.assertEqual(e.code, 401)
|
||||||
|
|
||||||
|
# there should have been a single call to the fetcher
|
||||||
|
mock_fetcher.get_keys.assert_called_once()
|
||||||
|
|
||||||
|
def test_verify_json_falls_back_to_other_fetchers(self):
|
||||||
|
"""If the first fetcher cannot provide a recent enough key, we fall back"""
|
||||||
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
|
def get_keys1(keys_to_fetch):
|
||||||
|
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||||
|
return defer.succeed(
|
||||||
|
{
|
||||||
|
"server1": {
|
||||||
|
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_keys2(keys_to_fetch):
|
||||||
|
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
||||||
|
return defer.succeed(
|
||||||
|
{
|
||||||
|
"server1": {
|
||||||
|
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_fetcher1 = keyring.KeyFetcher()
|
||||||
|
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
||||||
|
mock_fetcher2 = keyring.KeyFetcher()
|
||||||
|
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
|
||||||
|
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
|
||||||
|
|
||||||
|
json1 = {}
|
||||||
|
signedjson.sign.sign_json(json1, "server1", key1)
|
||||||
|
|
||||||
|
results = kr.verify_json_objects_for_server(
|
||||||
|
[("server1", json1, 1200), ("server1", json1, 1500)]
|
||||||
|
)
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
self.get_success(results[0])
|
||||||
|
e = self.get_failure(results[1], SynapseError).value
|
||||||
|
self.assertEqual(e.errcode, "M_UNAUTHORIZED")
|
||||||
|
self.assertEqual(e.code, 401)
|
||||||
|
|
||||||
|
# there should have been a single call to each fetcher
|
||||||
|
mock_fetcher1.get_keys.assert_called_once()
|
||||||
|
mock_fetcher2.get_keys.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
|
@ -260,8 +338,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.http_client.get_json.side_effect = get_json
|
self.http_client.get_json.side_effect = get_json
|
||||||
|
|
||||||
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
|
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
||||||
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
|
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
k = keys[SERVER_NAME][testverifykey_id]
|
||||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
|
@ -288,9 +366,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# change the server name: it should cause a rejection
|
# change the server name: it should cause a rejection
|
||||||
response["server_name"] = "OTHER_SERVER"
|
response["server_name"] = "OTHER_SERVER"
|
||||||
self.get_failure(
|
self.get_failure(fetcher.get_keys(keys_to_fetch), KeyLookupError)
|
||||||
fetcher.get_keys(server_name_and_key_ids), KeyLookupError
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
@ -342,8 +418,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.http_client.post_json.side_effect = post_json
|
self.http_client.post_json.side_effect = post_json
|
||||||
|
|
||||||
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
|
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
||||||
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
|
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
||||||
self.assertIn(SERVER_NAME, keys)
|
self.assertIn(SERVER_NAME, keys)
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
k = keys[SERVER_NAME][testverifykey_id]
|
||||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||||
|
@ -401,7 +477,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def get_key_from_perspectives(response):
|
def get_key_from_perspectives(response):
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs)
|
fetcher = PerspectivesKeyFetcher(self.hs)
|
||||||
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
|
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
||||||
|
|
||||||
def post_json(destination, path, data, **kwargs):
|
def post_json(destination, path, data, **kwargs):
|
||||||
self.assertEqual(destination, self.mock_perspective_server.server_name)
|
self.assertEqual(destination, self.mock_perspective_server.server_name)
|
||||||
|
@ -410,9 +486,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.http_client.post_json.side_effect = post_json
|
self.http_client.post_json.side_effect = post_json
|
||||||
|
|
||||||
return self.get_success(
|
return self.get_success(fetcher.get_keys(keys_to_fetch))
|
||||||
fetcher.get_keys(server_name_and_key_ids)
|
|
||||||
)
|
|
||||||
|
|
||||||
# start with a valid response so we can check we are testing the right thing
|
# start with a valid response so we can check we are testing the right thing
|
||||||
response = build_response()
|
response = build_response()
|
||||||
|
@ -435,6 +509,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
|
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
|
||||||
|
|
||||||
|
|
||||||
|
def get_key_id(key):
|
||||||
|
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
|
||||||
|
return "%s:%s" % (key.alg, key.version)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def run_in_context(f, *args, **kwargs):
|
def run_in_context(f, *args, **kwargs):
|
||||||
with LoggingContext("testctx") as ctx:
|
with LoggingContext("testctx") as ctx:
|
||||||
|
@ -445,14 +524,16 @@ def run_in_context(f, *args, **kwargs):
|
||||||
defer.returnValue(rv)
|
defer.returnValue(rv)
|
||||||
|
|
||||||
|
|
||||||
def _verify_json_for_server(keyring, server_name, json_object):
|
def _verify_json_for_server(keyring, server_name, json_object, validity_time):
|
||||||
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
|
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
|
||||||
with the patched defer.inlineCallbacks.
|
with the patched defer.inlineCallbacks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def v():
|
def v():
|
||||||
rv1 = yield keyring.verify_json_for_server(server_name, json_object)
|
rv1 = yield keyring.verify_json_for_server(
|
||||||
|
server_name, json_object, validity_time
|
||||||
|
)
|
||||||
defer.returnValue(rv1)
|
defer.returnValue(rv1)
|
||||||
|
|
||||||
return run_in_context(v)
|
return run_in_context(v)
|
||||||
|
|
Loading…
Reference in New Issue