Add some more tests for Keyring

pull/2459/head
Richard van der Hoff 2017-09-20 01:32:42 +01:00
parent c5c24c239b
commit 72472456d8
1 changed files with 140 additions and 37 deletions

View File

@ -12,39 +12,72 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import signedjson import time
import signedjson.key
import signedjson.sign
from mock import Mock from mock import Mock
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.util import async from synapse.util import async, logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from tests import unittest, utils from tests import unittest, utils
from twisted.internet import defer from twisted.internet import defer
class MockPerspectiveServer(object):
def __init__(self):
self.server_name = "mock_server"
self.key = signedjson.key.generate_signing_key(0)
def get_verify_keys(self):
vk = signedjson.key.get_verify_key(self.key)
return {
"%s:%s" % (vk.alg, vk.version): vk,
}
def get_signed_key(self, server_name, verify_key):
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
res = {
"server_name": server_name,
"old_verify_keys": {},
"valid_until_ts": time.time() * 1000 + 3600,
"verify_keys": {
key_id: {
"key": signedjson.key.encode_verify_key_base64(verify_key)
}
}
}
signedjson.sign.sign_json(res, self.server_name, self.key)
return res
class KeyringTestCase(unittest.TestCase): class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock() self.http_client = Mock()
self.hs = yield utils.setup_test_homeserver( self.hs = yield utils.setup_test_homeserver(
handlers=None, handlers=None,
http_client=self.http_client, http_client=self.http_client,
) )
self.hs.config.perspectives = { self.hs.config.perspectives = {
"persp_server": {"k": "v"} self.mock_perspective_server.server_name:
self.mock_perspective_server.get_verify_keys()
} }
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "test_key", None),
expected
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_wait_for_previous_lookups(self): def test_wait_for_previous_lookups(self):
sentinel_context = LoggingContext.current_context() sentinel_context = LoggingContext.current_context()
kr = keyring.Keyring(self.hs) kr = keyring.Keyring(self.hs)
def check_context(_, expected):
self.assertEquals(
LoggingContext.current_context().test_key, expected
)
lookup_1_deferred = defer.Deferred() lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred() lookup_2_deferred = defer.Deferred()
@ -60,7 +93,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertTrue(wait_1_deferred.called) self.assertTrue(wait_1_deferred.called)
# ... so we should have preserved the LoggingContext. # ... so we should have preserved the LoggingContext.
self.assertIs(LoggingContext.current_context(), context_one) self.assertIs(LoggingContext.current_context(), context_one)
wait_1_deferred.addBoth(check_context, "one") wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two: with LoggingContext("two") as context_two:
context_two.test_key = "two" context_two.test_key = "two"
@ -74,7 +107,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertFalse(wait_2_deferred.called) self.assertFalse(wait_2_deferred.called)
# ... so we should have reset the LoggingContext. # ... so we should have reset the LoggingContext.
self.assertIs(LoggingContext.current_context(), sentinel_context) self.assertIs(LoggingContext.current_context(), sentinel_context)
wait_2_deferred.addBoth(check_context, "two") wait_2_deferred.addBoth(self.check_context, "two")
# let the first lookup complete (in the sentinel context) # let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None) lookup_1_deferred.callback(None)
@ -89,18 +122,40 @@ class KeyringTestCase(unittest.TestCase):
kr = keyring.Keyring(self.hs) kr = keyring.Keyring(self.hs)
json1 = {} json1 = {}
signedjson.sign.sign_json(json1, "server1", key1) signedjson.sign.sign_json(json1, "server10", key1)
self.http_client.post_json.return_value = defer.Deferred() persp_resp = {
"server_keys": [
self.mock_perspective_server.get_signed_key(
"server10",
signedjson.key.get_verify_key(key1)
),
]
}
persp_deferred = defer.Deferred()
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(
LoggingContext.current_context().test_key, "11",
)
with logcontext.PreserveLoggingContext():
yield persp_deferred
defer.returnValue(persp_resp)
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
context_11.test_key = "11"
# start off a first set of lookups # start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server( res_deferreds = kr.verify_json_objects_for_server(
[("server1", json1), [("server10", json1),
("server2", {}) ("server11", {})
] ]
) )
# the unsigned json should be rejected pretty quickly # the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try: try:
yield res_deferreds[1] yield res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure") self.assertFalse("unsigned json didn't cause a failure")
@ -108,19 +163,67 @@ class KeyringTestCase(unittest.TestCase):
pass pass
self.assertFalse(res_deferreds[0].called) self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
# wait a tick for it to send the request to the perspectives server # wait a tick for it to send the request to the perspectives server
# (it first tries the datastore) # (it first tries the datastore)
yield async.sleep(0.005) yield async.sleep(0.005)
self.http_client.post_json.assert_called_once() self.http_client.post_json.assert_called_once()
# a second request for a server with outstanding requests should self.assertIs(LoggingContext.current_context(), context_11)
# block rather than start a second call
context_12 = LoggingContext("12")
context_12.test_key = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
self.http_client.post_json.reset_mock() self.http_client.post_json.reset_mock()
self.http_client.post_json.return_value = defer.Deferred() self.http_client.post_json.return_value = defer.Deferred()
kr.verify_json_objects_for_server( res_deferreds_2 = kr.verify_json_objects_for_server(
[("server1", json1)], [("server10", json1)],
) )
yield async.sleep(0.005) yield async.sleep(0.005)
self.http_client.post_json.assert_not_called() self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
# complete the first request
with logcontext.PreserveLoggingContext():
persp_deferred.callback(persp_resp)
self.assertIs(LoggingContext.current_context(), context_11)
with logcontext.PreserveLoggingContext():
yield res_deferreds[0]
yield res_deferreds_2[0]
@defer.inlineCallbacks
def test_verify_json_for_server(self):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
yield self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000,
signedjson.key.get_verify_key(key1),
)
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
sentinel_context = LoggingContext.current_context()
with LoggingContext("one") as context_one:
context_one.test_key = "one"
defer = kr.verify_json_for_server("server9", {})
try:
yield defer
self.fail("should fail on unsigned json")
except SynapseError:
pass
self.assertIs(LoggingContext.current_context(), context_one)
defer = kr.verify_json_for_server("server9", json1)
self.assertFalse(defer.called)
self.assertIs(LoggingContext.current_context(), sentinel_context)
yield defer
self.assertIs(LoggingContext.current_context(), context_one)