Fix caching of remote servers' signature keys
The `@cached` decorator on `KeyStore._get_server_verify_key` was missing its `num_args` parameter, which meant that it was returning the wrong key for any server which had more than one recorded key. By way of a fix, change the default for `num_args` to be *all* arguments. To implement that, factor out a common base class for `CacheDescriptor` and `CacheListDescriptor`.pull/2042/head
							parent
							
								
									37a187bfab
								
							
						
					
					
						commit
						95f21c7a66
					
				|  | @ -189,7 +189,55 @@ class Cache(object): | |||
|         self.cache.clear() | ||||
| 
 | ||||
| 
 | ||||
| class CacheDescriptor(object): | ||||
| class _CacheDescriptorBase(object): | ||||
|     def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): | ||||
|         self.orig = orig | ||||
| 
 | ||||
|         if inlineCallbacks: | ||||
|             self.function_to_call = defer.inlineCallbacks(orig) | ||||
|         else: | ||||
|             self.function_to_call = orig | ||||
| 
 | ||||
|         arg_spec = inspect.getargspec(orig) | ||||
|         all_args = arg_spec.args | ||||
| 
 | ||||
|         if "cache_context" in all_args: | ||||
|             if not cache_context: | ||||
|                 raise ValueError( | ||||
|                     "Cannot have a 'cache_context' arg without setting" | ||||
|                     " cache_context=True" | ||||
|                 ) | ||||
|         elif cache_context: | ||||
|             raise ValueError( | ||||
|                 "Cannot have cache_context=True without having an arg" | ||||
|                 " named `cache_context`" | ||||
|             ) | ||||
| 
 | ||||
|         if num_args is None: | ||||
|             num_args = len(all_args) - 1 | ||||
|             if cache_context: | ||||
|                 num_args -= 1 | ||||
| 
 | ||||
|         if len(all_args) < num_args + 1: | ||||
|             raise Exception( | ||||
|                 "Not enough explicit positional arguments to key off for %r: " | ||||
|                 "got %i args, but wanted %i. (@cached cannot key off *args or " | ||||
|                 "**kwargs)" | ||||
|                 % (orig.__name__, len(all_args), num_args) | ||||
|             ) | ||||
| 
 | ||||
|         self.num_args = num_args | ||||
|         self.arg_names = all_args[1:num_args + 1] | ||||
| 
 | ||||
|         if "cache_context" in self.arg_names: | ||||
|             raise Exception( | ||||
|                 "cache_context arg cannot be included among the cache keys" | ||||
|             ) | ||||
| 
 | ||||
|         self.add_cache_context = cache_context | ||||
| 
 | ||||
| 
 | ||||
| class CacheDescriptor(_CacheDescriptorBase): | ||||
|     """ A method decorator that applies a memoizing cache around the function. | ||||
| 
 | ||||
|     This caches deferreds, rather than the results themselves. Deferreds that | ||||
|  | @ -217,52 +265,24 @@ class CacheDescriptor(object): | |||
|             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) | ||||
|             defer.returnValue(r1 + r2) | ||||
| 
 | ||||
|     Args: | ||||
|         num_args (int): number of positional arguments (excluding ``self`` and | ||||
|             ``cache_context``) to use as cache keys. Defaults to all named | ||||
|             args of the function. | ||||
|     """ | ||||
|     def __init__(self, orig, max_entries=1000, num_args=1, tree=False, | ||||
|     def __init__(self, orig, max_entries=1000, num_args=None, tree=False, | ||||
|                  inlineCallbacks=False, cache_context=False, iterable=False): | ||||
| 
 | ||||
|         super(CacheDescriptor, self).__init__( | ||||
|             orig, num_args=num_args, inlineCallbacks=inlineCallbacks, | ||||
|             cache_context=cache_context) | ||||
| 
 | ||||
|         max_entries = int(max_entries * CACHE_SIZE_FACTOR) | ||||
| 
 | ||||
|         self.orig = orig | ||||
| 
 | ||||
|         if inlineCallbacks: | ||||
|             self.function_to_call = defer.inlineCallbacks(orig) | ||||
|         else: | ||||
|             self.function_to_call = orig | ||||
| 
 | ||||
|         self.max_entries = max_entries | ||||
|         self.num_args = num_args | ||||
|         self.tree = tree | ||||
| 
 | ||||
|         self.iterable = iterable | ||||
| 
 | ||||
|         all_args = inspect.getargspec(orig) | ||||
|         self.arg_names = all_args.args[1:num_args + 1] | ||||
| 
 | ||||
|         if "cache_context" in all_args.args: | ||||
|             if not cache_context: | ||||
|                 raise ValueError( | ||||
|                     "Cannot have a 'cache_context' arg without setting" | ||||
|                     " cache_context=True" | ||||
|                 ) | ||||
|             try: | ||||
|                 self.arg_names.remove("cache_context") | ||||
|             except ValueError: | ||||
|                 pass | ||||
|         elif cache_context: | ||||
|             raise ValueError( | ||||
|                 "Cannot have cache_context=True without having an arg" | ||||
|                 " named `cache_context`" | ||||
|             ) | ||||
| 
 | ||||
|         self.add_cache_context = cache_context | ||||
| 
 | ||||
|         if len(self.arg_names) < self.num_args: | ||||
|             raise Exception( | ||||
|                 "Not enough explicit positional arguments to key off of for %r." | ||||
|                 " (@cached cannot key off of *args or **kwargs)" | ||||
|                 % (orig.__name__,) | ||||
|             ) | ||||
| 
 | ||||
|     def __get__(self, obj, objtype=None): | ||||
|         cache = Cache( | ||||
|             name=self.orig.__name__, | ||||
|  | @ -338,48 +358,36 @@ class CacheDescriptor(object): | |||
|         return wrapped | ||||
| 
 | ||||
| 
 | ||||
| class CacheListDescriptor(object): | ||||
| class CacheListDescriptor(_CacheDescriptorBase): | ||||
|     """Wraps an existing cache to support bulk fetching of keys. | ||||
| 
 | ||||
|     Given a list of keys it looks in the cache to find any hits, then passes | ||||
|     the list of missing keys to the wrapped fucntion. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, orig, cached_method_name, list_name, num_args=1, | ||||
|     def __init__(self, orig, cached_method_name, list_name, num_args=None, | ||||
|                  inlineCallbacks=False): | ||||
|         """ | ||||
|         Args: | ||||
|             orig (function) | ||||
|             method_name (str); The name of the chached method. | ||||
|             cached_method_name (str): The name of the chached method. | ||||
|             list_name (str): Name of the argument which is the bulk lookup list | ||||
|             num_args (int) | ||||
|             num_args (int): number of positional arguments (excluding ``self``, | ||||
|                 but including list_name) to use as cache keys. Defaults to all | ||||
|                 named args of the function. | ||||
|             inlineCallbacks (bool): Whether orig is a generator that should | ||||
|                 be wrapped by defer.inlineCallbacks | ||||
|         """ | ||||
|         self.orig = orig | ||||
|         super(CacheListDescriptor, self).__init__( | ||||
|             orig, num_args=num_args, inlineCallbacks=inlineCallbacks) | ||||
| 
 | ||||
|         if inlineCallbacks: | ||||
|             self.function_to_call = defer.inlineCallbacks(orig) | ||||
|         else: | ||||
|             self.function_to_call = orig | ||||
| 
 | ||||
|         self.num_args = num_args | ||||
|         self.list_name = list_name | ||||
| 
 | ||||
|         self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] | ||||
|         self.list_pos = self.arg_names.index(self.list_name) | ||||
| 
 | ||||
|         self.cached_method_name = cached_method_name | ||||
| 
 | ||||
|         self.sentinel = object() | ||||
| 
 | ||||
|         if len(self.arg_names) < self.num_args: | ||||
|             raise Exception( | ||||
|                 "Not enough explicit positional arguments to key off of for %r." | ||||
|                 " (@cached cannot key off of *args or **kwars)" | ||||
|                 % (orig.__name__,) | ||||
|             ) | ||||
| 
 | ||||
|         if self.list_name not in self.arg_names: | ||||
|             raise Exception( | ||||
|                 "Couldn't see arguments %r for %r." | ||||
|  | @ -487,7 +495,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): | |||
|         self.cache.invalidate(self.key) | ||||
| 
 | ||||
| 
 | ||||
| def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, | ||||
| def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, | ||||
|            iterable=False): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
|  | @ -499,8 +507,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, | |||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, | ||||
|                           iterable=False): | ||||
| def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, | ||||
|                           cache_context=False, iterable=False): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
|         max_entries=max_entries, | ||||
|  | @ -512,7 +520,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex | |||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): | ||||
| def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): | ||||
|     """Creates a descriptor that wraps a function in a `CacheListDescriptor`. | ||||
| 
 | ||||
|     Used to do batch lookups for an already created cache. A single argument | ||||
|  | @ -525,7 +533,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False) | |||
|         cache (Cache): The underlying cache to use. | ||||
|         list_name (str): The name of the argument that is the list to use to | ||||
|             do batch lookups in the cache. | ||||
|         num_args (int): Number of arguments to use as the key in the cache. | ||||
|         num_args (int): Number of arguments to use as the key in the cache | ||||
|             (including list_name). Defaults to all named parameters. | ||||
|         inlineCallbacks (bool): Should the function be wrapped in an | ||||
|             `defer.inlineCallbacks`? | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,53 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import signedjson.key | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import tests.unittest | ||||
| import tests.utils | ||||
| 
 | ||||
| 
 | ||||
| class KeyStoreTestCase(tests.unittest.TestCase): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(KeyStoreTestCase, self).__init__(*args, **kwargs) | ||||
|         self.store = None  # type: synapse.storage.keys.KeyStore | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def setUp(self): | ||||
|         hs = yield tests.utils.setup_test_homeserver() | ||||
|         self.store = hs.get_datastore() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_get_server_verify_keys(self): | ||||
|         key1 = signedjson.key.decode_verify_key_base64( | ||||
|             "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" | ||||
|         ) | ||||
|         key2 = signedjson.key.decode_verify_key_base64( | ||||
|             "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" | ||||
|         ) | ||||
|         yield self.store.store_server_verify_key( | ||||
|             "server1", "from_server", 0, key1 | ||||
|         ) | ||||
|         yield self.store.store_server_verify_key( | ||||
|             "server1", "from_server", 0, key2 | ||||
|         ) | ||||
| 
 | ||||
|         res = yield self.store.get_server_verify_keys( | ||||
|             "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]) | ||||
| 
 | ||||
|         self.assertEqual(len(res.keys()), 2) | ||||
|         self.assertEqual(res["ed25519:key1"].version, "key1") | ||||
|         self.assertEqual(res["ed25519:key2"].version, "key2") | ||||
|  | @ -0,0 +1,14 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | @ -0,0 +1,86 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2016 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import mock | ||||
| from twisted.internet import defer | ||||
| from synapse.util.caches import descriptors | ||||
| from tests import unittest | ||||
| 
 | ||||
| 
 | ||||
| class DescriptorTestCase(unittest.TestCase): | ||||
|     @defer.inlineCallbacks | ||||
|     def test_cache(self): | ||||
|         class Cls(object): | ||||
|             def __init__(self): | ||||
|                 self.mock = mock.Mock() | ||||
| 
 | ||||
|             @descriptors.cached() | ||||
|             def fn(self, arg1, arg2): | ||||
|                 return self.mock(arg1, arg2) | ||||
| 
 | ||||
|         obj = Cls() | ||||
| 
 | ||||
|         obj.mock.return_value = 'fish' | ||||
|         r = yield obj.fn(1, 2) | ||||
|         self.assertEqual(r, 'fish') | ||||
|         obj.mock.assert_called_once_with(1, 2) | ||||
|         obj.mock.reset_mock() | ||||
| 
 | ||||
|         # a call with different params should call the mock again | ||||
|         obj.mock.return_value = 'chips' | ||||
|         r = yield obj.fn(1, 3) | ||||
|         self.assertEqual(r, 'chips') | ||||
|         obj.mock.assert_called_once_with(1, 3) | ||||
|         obj.mock.reset_mock() | ||||
| 
 | ||||
|         # the two values should now be cached | ||||
|         r = yield obj.fn(1, 2) | ||||
|         self.assertEqual(r, 'fish') | ||||
|         r = yield obj.fn(1, 3) | ||||
|         self.assertEqual(r, 'chips') | ||||
|         obj.mock.assert_not_called() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_cache_num_args(self): | ||||
|         """Only the first num_args arguments should matter to the cache""" | ||||
| 
 | ||||
|         class Cls(object): | ||||
|             def __init__(self): | ||||
|                 self.mock = mock.Mock() | ||||
| 
 | ||||
|             @descriptors.cached(num_args=1) | ||||
|             def fn(self, arg1, arg2): | ||||
|                 return self.mock(arg1, arg2) | ||||
| 
 | ||||
|         obj = Cls() | ||||
|         obj.mock.return_value = 'fish' | ||||
|         r = yield obj.fn(1, 2) | ||||
|         self.assertEqual(r, 'fish') | ||||
|         obj.mock.assert_called_once_with(1, 2) | ||||
|         obj.mock.reset_mock() | ||||
| 
 | ||||
|         # a call with different params should call the mock again | ||||
|         obj.mock.return_value = 'chips' | ||||
|         r = yield obj.fn(2, 3) | ||||
|         self.assertEqual(r, 'chips') | ||||
|         obj.mock.assert_called_once_with(2, 3) | ||||
|         obj.mock.reset_mock() | ||||
| 
 | ||||
|         # the two values should now be cached; we should be able to vary | ||||
|         # the second argument and still get the cached result. | ||||
|         r = yield obj.fn(1, 4) | ||||
|         self.assertEqual(r, 'fish') | ||||
|         r = yield obj.fn(2, 5) | ||||
|         self.assertEqual(r, 'chips') | ||||
|         obj.mock.assert_not_called() | ||||
		Loading…
	
		Reference in New Issue
	
	 Richard van der Hoff
						Richard van der Hoff