Split PushRulesStore

pull/2898/head
Erik Johnston 2018-02-21 10:39:27 +00:00
parent a2b25de68d
commit cbaad969f9
3 changed files with 61 additions and 45 deletions

View File

@ -15,29 +15,15 @@
from .events import SlavedEventStore from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.push_rule import PushRulesWorkerStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedPushRuleStore(SlavedEventStore): class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker( self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id", db_conn, "push_rules_stream", "stream_id",
) )
self.push_rules_stream_cache = StreamChangeCache( super(SlavedPushRuleStore, self).__init__(db_conn, hs)
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
def get_push_rules_stream_token(self): def get_push_rules_stream_token(self):
return ( return (
@ -45,6 +31,9 @@ class SlavedPushRuleStore(SlavedEventStore):
self._stream_id_gen.get_current_token(), self._stream_id_gen.get_current_token(),
) )
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def stream_positions(self): def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions() result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()

View File

@ -177,18 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill prefilled_cache=presence_cache_prefill
) )
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_current_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token() max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox", db_conn, "device_inbox",

View File

@ -15,10 +15,12 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from twisted.internet import defer from twisted.internet import defer
import abc
import logging import logging
import simplejson as json import simplejson as json
@ -48,7 +50,39 @@ def _load_rules(rawrules, enabled_map):
return rules return rules
class PushRuleStore(SQLBaseStore): class PushRulesWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self.get_max_push_rules_stream_id(),
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
Returns:
int
"""
raise NotImplementedError()
@cachedInlineCallbacks(max_entries=5000) @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
@ -89,6 +123,24 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results r['rule_id']: False if r['enabled'] == 0 else True for r in results
}) })
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
class PushRuleStore(PushRulesWorkerStore):
@cachedList(cached_method_name="get_push_rules_for_user", @cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True) list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids): def bulk_get_push_rules(self, user_ids):
@ -526,21 +578,8 @@ class PushRuleStore(SQLBaseStore):
room stream ordering it corresponds to.""" room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
def have_push_rules_changed_for_user(self, user_id, last_id): def get_max_push_rules_stream_id(self):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): return self.get_push_rules_stream_token()[0]
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
class RuleNotFoundException(Exception): class RuleNotFoundException(Exception):