Merge pull request #209 from matrix-org/erikj/cached_keyword_args
Add support for using keyword arguments with cached functionspull/214/head
						commit
						8049c9a71e
					
				| 
						 | 
					@ -27,6 +27,7 @@ from twisted.internet import defer
 | 
				
			||||||
from collections import namedtuple, OrderedDict
 | 
					from collections import namedtuple, OrderedDict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import functools
 | 
					import functools
 | 
				
			||||||
 | 
					import inspect
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import threading
 | 
					import threading
 | 
				
			||||||
| 
						 | 
					@ -141,13 +142,28 @@ class CacheDescriptor(object):
 | 
				
			||||||
    which can be used to insert values into the cache specifically, without
 | 
					    which can be used to insert values into the cache specifically, without
 | 
				
			||||||
    calling the calculation function.
 | 
					    calling the calculation function.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, orig, max_entries=1000, num_args=1, lru=True):
 | 
					    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
 | 
				
			||||||
 | 
					                 inlineCallbacks=False):
 | 
				
			||||||
        self.orig = orig
 | 
					        self.orig = orig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if inlineCallbacks:
 | 
				
			||||||
 | 
					            self.function_to_call = defer.inlineCallbacks(orig)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.function_to_call = orig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.max_entries = max_entries
 | 
					        self.max_entries = max_entries
 | 
				
			||||||
        self.num_args = num_args
 | 
					        self.num_args = num_args
 | 
				
			||||||
        self.lru = lru
 | 
					        self.lru = lru
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(self.arg_names) < self.num_args:
 | 
				
			||||||
 | 
					            raise Exception(
 | 
				
			||||||
 | 
					                "Not enough explicit positional arguments to key off of for %r."
 | 
				
			||||||
 | 
					                " (@cached cannot key off of *args or **kwars)"
 | 
				
			||||||
 | 
					                % (orig.__name__,)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __get__(self, obj, objtype=None):
 | 
					    def __get__(self, obj, objtype=None):
 | 
				
			||||||
        cache = Cache(
 | 
					        cache = Cache(
 | 
				
			||||||
            name=self.orig.__name__,
 | 
					            name=self.orig.__name__,
 | 
				
			||||||
| 
						 | 
					@ -158,11 +174,13 @@ class CacheDescriptor(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @functools.wraps(self.orig)
 | 
					        @functools.wraps(self.orig)
 | 
				
			||||||
        @defer.inlineCallbacks
 | 
					        @defer.inlineCallbacks
 | 
				
			||||||
        def wrapped(*keyargs):
 | 
					        def wrapped(*args, **kwargs):
 | 
				
			||||||
 | 
					            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
 | 
				
			||||||
 | 
					            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                cached_result = cache.get(*keyargs[:self.num_args])
 | 
					                cached_result = cache.get(*keyargs)
 | 
				
			||||||
                if DEBUG_CACHES:
 | 
					                if DEBUG_CACHES:
 | 
				
			||||||
                    actual_result = yield self.orig(obj, *keyargs)
 | 
					                    actual_result = yield self.function_to_call(obj, *args, **kwargs)
 | 
				
			||||||
                    if actual_result != cached_result:
 | 
					                    if actual_result != cached_result:
 | 
				
			||||||
                        logger.error(
 | 
					                        logger.error(
 | 
				
			||||||
                            "Stale cache entry %s%r: cached: %r, actual %r",
 | 
					                            "Stale cache entry %s%r: cached: %r, actual %r",
 | 
				
			||||||
| 
						 | 
					@ -177,9 +195,9 @@ class CacheDescriptor(object):
 | 
				
			||||||
                # while the SELECT is executing (SYN-369)
 | 
					                # while the SELECT is executing (SYN-369)
 | 
				
			||||||
                sequence = cache.sequence
 | 
					                sequence = cache.sequence
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ret = yield self.orig(obj, *keyargs)
 | 
					                ret = yield self.function_to_call(obj, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
 | 
					                cache.update(sequence, *(keyargs + [ret]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                defer.returnValue(ret)
 | 
					                defer.returnValue(ret)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -201,6 +219,16 @@ def cached(max_entries=1000, num_args=1, lru=True):
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
 | 
				
			||||||
 | 
					    return lambda orig: CacheDescriptor(
 | 
				
			||||||
 | 
					        orig,
 | 
				
			||||||
 | 
					        max_entries=max_entries,
 | 
				
			||||||
 | 
					        num_args=num_args,
 | 
				
			||||||
 | 
					        lru=lru,
 | 
				
			||||||
 | 
					        inlineCallbacks=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LoggingTransaction(object):
 | 
					class LoggingTransaction(object):
 | 
				
			||||||
    """An object that almost-transparently proxies for the 'txn' object
 | 
					    """An object that almost-transparently proxies for the 'txn' object
 | 
				
			||||||
    passed to the constructor. Adds logging and metrics to the .execute()
 | 
					    passed to the constructor. Adds logging and metrics to the .execute()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,7 @@
 | 
				
			||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from _base import SQLBaseStore, cached
 | 
					from _base import SQLBaseStore, cachedInlineCallbacks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from twisted.internet import defer
 | 
					from twisted.internet import defer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
 | 
				
			||||||
            desc="store_server_certificate",
 | 
					            desc="store_server_certificate",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @cached()
 | 
					    @cachedInlineCallbacks()
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					 | 
				
			||||||
    def get_all_server_verify_keys(self, server_name):
 | 
					    def get_all_server_verify_keys(self, server_name):
 | 
				
			||||||
        rows = yield self._simple_select_list(
 | 
					        rows = yield self._simple_select_list(
 | 
				
			||||||
            table="server_signature_keys",
 | 
					            table="server_signature_keys",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,7 @@
 | 
				
			||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ._base import SQLBaseStore, cached
 | 
					from ._base import SQLBaseStore, cachedInlineCallbacks
 | 
				
			||||||
from twisted.internet import defer
 | 
					from twisted.internet import defer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
| 
						 | 
					@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PushRuleStore(SQLBaseStore):
 | 
					class PushRuleStore(SQLBaseStore):
 | 
				
			||||||
    @cached()
 | 
					    @cachedInlineCallbacks()
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					 | 
				
			||||||
    def get_push_rules_for_user(self, user_name):
 | 
					    def get_push_rules_for_user(self, user_name):
 | 
				
			||||||
        rows = yield self._simple_select_list(
 | 
					        rows = yield self._simple_select_list(
 | 
				
			||||||
            table=PushRuleTable.table_name,
 | 
					            table=PushRuleTable.table_name,
 | 
				
			||||||
| 
						 | 
					@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        defer.returnValue(rows)
 | 
					        defer.returnValue(rows)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @cached()
 | 
					    @cachedInlineCallbacks()
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					 | 
				
			||||||
    def get_push_rules_enabled_for_user(self, user_name):
 | 
					    def get_push_rules_enabled_for_user(self, user_name):
 | 
				
			||||||
        results = yield self._simple_select_list(
 | 
					        results = yield self._simple_select_list(
 | 
				
			||||||
            table=PushRuleEnableTable.table_name,
 | 
					            table=PushRuleEnableTable.table_name,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,7 @@
 | 
				
			||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ._base import SQLBaseStore, cached
 | 
					from ._base import SQLBaseStore, cachedInlineCallbacks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from twisted.internet import defer
 | 
					from twisted.internet import defer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
 | 
				
			||||||
    def get_max_receipt_stream_id(self):
 | 
					    def get_max_receipt_stream_id(self):
 | 
				
			||||||
        return self._receipts_id_gen.get_max_token(self)
 | 
					        return self._receipts_id_gen.get_max_token(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @cached
 | 
					    @cachedInlineCallbacks()
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					 | 
				
			||||||
    def get_graph_receipts_for_room(self, room_id):
 | 
					    def get_graph_receipts_for_room(self, room_id):
 | 
				
			||||||
        """Get receipts for sending to remote servers.
 | 
					        """Get receipts for sending to remote servers.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,7 +17,7 @@ from twisted.internet import defer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from synapse.api.errors import StoreError
 | 
					from synapse.api.errors import StoreError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ._base import SQLBaseStore, cached
 | 
					from ._base import SQLBaseStore, cachedInlineCallbacks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import collections
 | 
					import collections
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
| 
						 | 
					@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @cached()
 | 
					    @cachedInlineCallbacks()
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					 | 
				
			||||||
    def get_room_name_and_aliases(self, room_id):
 | 
					    def get_room_name_and_aliases(self, room_id):
 | 
				
			||||||
        def f(txn):
 | 
					        def f(txn):
 | 
				
			||||||
            sql = (
 | 
					            sql = (
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,7 @@
 | 
				
			||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ._base import SQLBaseStore, cached
 | 
					from ._base import SQLBaseStore, cached, cachedInlineCallbacks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from twisted.internet import defer
 | 
					from twisted.internet import defer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -189,8 +189,7 @@ class StateStore(SQLBaseStore):
 | 
				
			||||||
        events = yield self._get_events(event_ids, get_prev_content=False)
 | 
					        events = yield self._get_events(event_ids, get_prev_content=False)
 | 
				
			||||||
        defer.returnValue(events)
 | 
					        defer.returnValue(events)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @cached(num_args=3)
 | 
					    @cachedInlineCallbacks(num_args=3)
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					 | 
				
			||||||
    def get_current_state_for_key(self, room_id, event_type, state_key):
 | 
					    def get_current_state_for_key(self, room_id, event_type, state_key):
 | 
				
			||||||
        def f(txn):
 | 
					        def f(txn):
 | 
				
			||||||
            sql = (
 | 
					            sql = (
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue