Hook up the push rules to the notifier
parent
2223204eba
commit
ddf9e7b302
|
@ -647,8 +647,8 @@ class MessageHandler(BaseHandler):
|
|||
user_id, messages, is_peeking=is_peeking
|
||||
)
|
||||
|
||||
start_token = StreamToken(token[0], 0, 0, 0, 0)
|
||||
end_token = StreamToken(token[1], 0, 0, 0, 0)
|
||||
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
|
||||
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
|
|
|
@ -284,7 +284,7 @@ class Notifier(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
|
||||
from_token=StreamToken("s0", "0", "0", "0", "0")):
|
||||
from_token=StreamToken.START):
|
||||
"""Wait until the callback returns a non empty response or the
|
||||
timeout fires.
|
||||
"""
|
||||
|
|
|
@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
|
||||
"Unrecognised request: You probably wanted a trailing slash")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PushRuleRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request):
|
||||
spec = _rule_spec_from_path(request.postpath)
|
||||
|
@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
content = _parse_json(request)
|
||||
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if 'attr' in spec:
|
||||
yield self.set_rule_attr(requester.user.to_string(), spec, content)
|
||||
yield self.set_rule_attr(user_id, spec, content)
|
||||
self.notify_user(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
if spec['rule_id'].startswith('.'):
|
||||
|
@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
after = _namespaced_rule_id(spec, after[0])
|
||||
|
||||
try:
|
||||
yield self.hs.get_datastore().add_push_rule(
|
||||
user_id=requester.user.to_string(),
|
||||
yield self.store.add_push_rule(
|
||||
user_id=user_id,
|
||||
rule_id=_namespaced_rule_id_from_spec(spec),
|
||||
priority_class=priority_class,
|
||||
conditions=conditions,
|
||||
|
@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
before=before,
|
||||
after=after
|
||||
)
|
||||
self.notify_user(user_id)
|
||||
except InconsistentRuleException as e:
|
||||
raise SynapseError(400, e.message)
|
||||
except RuleNotFoundException as e:
|
||||
|
@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
spec = _rule_spec_from_path(request.postpath)
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
|
||||
try:
|
||||
yield self.hs.get_datastore().delete_push_rule(
|
||||
requester.user.to_string(), namespaced_rule_id
|
||||
yield self.store.delete_push_rule(
|
||||
user_id, namespaced_rule_id
|
||||
)
|
||||
self.notify_user(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
|
@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
# we build up the full structure and then decide which bits of it
|
||||
# to send which means doing unnecessary work sometimes but is
|
||||
# is probably not going to make a whole lot of difference
|
||||
rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
|
||||
user.to_string()
|
||||
)
|
||||
rawrules = yield self.store.get_push_rules_for_user(user_id)
|
||||
|
||||
ruleslist = []
|
||||
for rawrule in rawrules:
|
||||
|
@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
|
||||
|
||||
enabled_map = yield self.hs.get_datastore().\
|
||||
get_push_rules_enabled_for_user(user.to_string())
|
||||
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
|
||||
|
||||
for r in ruleslist:
|
||||
rulearray = None
|
||||
|
@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
pattern_type = c.pop("pattern_type", None)
|
||||
if pattern_type == "user_id":
|
||||
c["pattern"] = user.to_string()
|
||||
c["pattern"] = user_id
|
||||
elif pattern_type == "user_localpart":
|
||||
c["pattern"] = user.localpart
|
||||
c["pattern"] = requester.user.localpart
|
||||
|
||||
rulearray = rules['global'][template_name]
|
||||
|
||||
|
@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
def notify_user(self, user_id):
|
||||
stream_id = self.store.get_push_rules_stream_token()
|
||||
self.notifier.on_new_event(
|
||||
"push_rules_key", stream_id, users=[user_id]
|
||||
)
|
||||
|
||||
def set_rule_attr(self, user_id, spec, val):
|
||||
if spec['attr'] == 'enabled':
|
||||
if isinstance(val, dict) and "enabled" in val:
|
||||
|
@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
# bools directly, so let's not break them.
|
||||
raise SynapseError(400, "Value for 'enabled' must be boolean")
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
return self.hs.get_datastore().set_push_rule_enabled(
|
||||
return self.store.set_push_rule_enabled(
|
||||
user_id, namespaced_rule_id, val
|
||||
)
|
||||
elif spec['attr'] == 'actions':
|
||||
|
@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
if is_default_rule:
|
||||
if namespaced_rule_id not in BASE_RULE_IDS:
|
||||
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
|
||||
return self.hs.get_datastore().set_push_rule_actions(
|
||||
return self.store.set_push_rule_actions(
|
||||
user_id, namespaced_rule_id, actions, is_default_rule
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -38,9 +38,12 @@ class EventSources(object):
|
|||
name: cls(hs)
|
||||
for name, cls in EventSources.SOURCE_TYPES.items()
|
||||
}
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_token(self, direction='f'):
|
||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||
|
||||
token = StreamToken(
|
||||
room_key=(
|
||||
yield self.sources["room"].get_current_key(direction)
|
||||
|
@ -57,5 +60,6 @@ class EventSources(object):
|
|||
account_data_key=(
|
||||
yield self.sources["account_data"].get_current_key()
|
||||
),
|
||||
push_rules_key=push_rules_key,
|
||||
)
|
||||
defer.returnValue(token)
|
||||
|
|
|
@ -115,6 +115,7 @@ class StreamToken(
|
|||
"typing_key",
|
||||
"receipt_key",
|
||||
"account_data_key",
|
||||
"push_rules_key",
|
||||
))
|
||||
):
|
||||
_SEPARATOR = "_"
|
||||
|
@ -150,6 +151,7 @@ class StreamToken(
|
|||
or (int(other.typing_key) < int(self.typing_key))
|
||||
or (int(other.receipt_key) < int(self.receipt_key))
|
||||
or (int(other.account_data_key) < int(self.account_data_key))
|
||||
or (int(other.push_rules_key) < int(self.push_rules_key))
|
||||
)
|
||||
|
||||
def copy_and_advance(self, key, new_value):
|
||||
|
@ -174,6 +176,11 @@ class StreamToken(
|
|||
return StreamToken(**d)
|
||||
|
||||
|
||||
StreamToken.START = StreamToken(
|
||||
*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
|
||||
)
|
||||
|
||||
|
||||
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
||||
"""Tokens are positions between events. The token "s1" comes after event 1.
|
||||
|
||||
|
|
Loading…
Reference in New Issue