Merge branch 'develop' into client_v2_sync
						commit
						9c61556504
					
				| 
						 | 
				
			
			@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
 | 
			
		|||
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
 | 
			
		||||
from synapse.util.logutils import log_function
 | 
			
		||||
from synapse.util.async import run_on_reactor
 | 
			
		||||
from synapse.types import UserID
 | 
			
		||||
from synapse.types import UserID, ClientInfo
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -290,7 +290,9 @@ class Auth(object):
 | 
			
		|||
        Args:
 | 
			
		||||
            request - An HTTP request with an access_token query parameter.
 | 
			
		||||
        Returns:
 | 
			
		||||
            UserID : User ID object of the user making the request
 | 
			
		||||
            tuple : of UserID and device string:
 | 
			
		||||
                User ID object of the user making the request
 | 
			
		||||
                Client ID object of the client instance the user is using
 | 
			
		||||
        Raises:
 | 
			
		||||
            AuthError if no user by that token exists or the token is invalid.
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -299,6 +301,8 @@ class Auth(object):
 | 
			
		|||
            access_token = request.args["access_token"][0]
 | 
			
		||||
            user_info = yield self.get_user_by_token(access_token)
 | 
			
		||||
            user = user_info["user"]
 | 
			
		||||
            device_id = user_info["device_id"]
 | 
			
		||||
            token_id = user_info["token_id"]
 | 
			
		||||
 | 
			
		||||
            ip_addr = self.hs.get_ip_from_request(request)
 | 
			
		||||
            user_agent = request.requestHeaders.getRawHeaders(
 | 
			
		||||
| 
						 | 
				
			
			@ -314,7 +318,7 @@ class Auth(object):
 | 
			
		|||
                    user_agent=user_agent
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            defer.returnValue(user)
 | 
			
		||||
            defer.returnValue((user, ClientInfo(device_id, token_id)))
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            raise AuthError(403, "Missing access token.")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -339,6 +343,7 @@ class Auth(object):
 | 
			
		|||
                "admin": bool(ret.get("admin", False)),
 | 
			
		||||
                "device_id": ret.get("device_id"),
 | 
			
		||||
                "user": UserID.from_string(ret.get("name")),
 | 
			
		||||
                "token_id": ret.get("token_id", None),
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            defer.returnValue(user_info)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class Codes(object):
 | 
			
		||||
    UNRECOGNIZED = "M_UNRECOGNIZED"
 | 
			
		||||
    UNAUTHORIZED = "M_UNAUTHORIZED"
 | 
			
		||||
    FORBIDDEN = "M_FORBIDDEN"
 | 
			
		||||
    BAD_JSON = "M_BAD_JSON"
 | 
			
		||||
| 
						 | 
				
			
			@ -34,6 +35,7 @@ class Codes(object):
 | 
			
		|||
    LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
 | 
			
		||||
    CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
 | 
			
		||||
    CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
 | 
			
		||||
    MISSING_PARAM = "M_MISSING_PARAM",
 | 
			
		||||
    TOO_LARGE = "M_TOO_LARGE"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -81,6 +83,34 @@ class RegistrationError(SynapseError):
 | 
			
		|||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UnrecognizedRequestError(SynapseError):
 | 
			
		||||
    """An error indicating we don't understand the request you're trying to make"""
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        if "errcode" not in kwargs:
 | 
			
		||||
            kwargs["errcode"] = Codes.UNRECOGNIZED
 | 
			
		||||
        message = None
 | 
			
		||||
        if len(args) == 0:
 | 
			
		||||
            message = "Unrecognized request"
 | 
			
		||||
        else:
 | 
			
		||||
            message = args[0]
 | 
			
		||||
        super(UnrecognizedRequestError, self).__init__(
 | 
			
		||||
            400,
 | 
			
		||||
            message,
 | 
			
		||||
            **kwargs
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NotFoundError(SynapseError):
 | 
			
		||||
    """An error indicating we can't find the thing you asked for"""
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        if "errcode" not in kwargs:
 | 
			
		||||
            kwargs["errcode"] = Codes.NOT_FOUND
 | 
			
		||||
        super(NotFoundError, self).__init__(
 | 
			
		||||
            404,
 | 
			
		||||
            "Not found",
 | 
			
		||||
            **kwargs
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
class AuthError(SynapseError):
 | 
			
		||||
    """An error raised when there was a problem authorising an event."""
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -272,6 +272,8 @@ def setup():
 | 
			
		|||
        bind_port = None
 | 
			
		||||
    hs.start_listening(bind_port, config.unsecure_port)
 | 
			
		||||
 | 
			
		||||
    hs.get_pusherpool().start()
 | 
			
		||||
 | 
			
		||||
    if config.daemonize:
 | 
			
		||||
        print config.pid_file
 | 
			
		||||
        daemon = Daemonize(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -49,10 +49,11 @@ class EventStreamHandler(BaseHandler):
 | 
			
		|||
    @defer.inlineCallbacks
 | 
			
		||||
    @log_function
 | 
			
		||||
    def get_stream(self, auth_user_id, pagin_config, timeout=0,
 | 
			
		||||
                   as_client_event=True):
 | 
			
		||||
                   as_client_event=True, affect_presence=True):
 | 
			
		||||
        auth_user = UserID.from_string(auth_user_id)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if affect_presence:
 | 
			
		||||
                if auth_user not in self._streams_per_user:
 | 
			
		||||
                    self._streams_per_user[auth_user] = 0
 | 
			
		||||
                    if auth_user in self._stop_timer_per_user:
 | 
			
		||||
| 
						 | 
				
			
			@ -94,6 +95,7 @@ class EventStreamHandler(BaseHandler):
 | 
			
		|||
            defer.returnValue(chunk)
 | 
			
		||||
 | 
			
		||||
        finally:
 | 
			
		||||
            if affect_presence:
 | 
			
		||||
                self._streams_per_user[auth_user] -= 1
 | 
			
		||||
                if not self._streams_per_user[auth_user]:
 | 
			
		||||
                    del self._streams_per_user[auth_user]
 | 
			
		||||
| 
						 | 
				
			
			@ -107,7 +109,7 @@ class EventStreamHandler(BaseHandler):
 | 
			
		|||
 | 
			
		||||
                        self._stop_timer_per_user.pop(auth_user, None)
 | 
			
		||||
 | 
			
		||||
                    yield self.distributor.fire(
 | 
			
		||||
                        return self.distributor.fire(
 | 
			
		||||
                            "stopped_user_eventstream", auth_user
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -114,7 +114,8 @@ class MessageHandler(BaseHandler):
 | 
			
		|||
        defer.returnValue(chunk)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def create_and_send_event(self, event_dict, ratelimit=True):
 | 
			
		||||
    def create_and_send_event(self, event_dict, ratelimit=True,
 | 
			
		||||
                              client=None, txn_id=None):
 | 
			
		||||
        """ Given a dict from a client, create and handle a new event.
 | 
			
		||||
 | 
			
		||||
        Creates an FrozenEvent object, filling out auth_events, prev_events,
 | 
			
		||||
| 
						 | 
				
			
			@ -148,6 +149,15 @@ class MessageHandler(BaseHandler):
 | 
			
		|||
                    builder.content
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        if client is not None:
 | 
			
		||||
            if client.token_id is not None:
 | 
			
		||||
                builder.internal_metadata.token_id = client.token_id
 | 
			
		||||
            if client.device_id is not None:
 | 
			
		||||
                builder.internal_metadata.device_id = client.device_id
 | 
			
		||||
 | 
			
		||||
        if txn_id is not None:
 | 
			
		||||
            builder.internal_metadata.txn_id = txn_id
 | 
			
		||||
 | 
			
		||||
        event, context = yield self._create_new_client_event(
 | 
			
		||||
            builder=builder,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -87,6 +87,10 @@ class PresenceHandler(BaseHandler):
 | 
			
		|||
            "changed_presencelike_data", self.changed_presencelike_data
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # outbound signal from the presence module to advertise when a user's
 | 
			
		||||
        # presence has changed
 | 
			
		||||
        distributor.declare("user_presence_changed")
 | 
			
		||||
 | 
			
		||||
        self.distributor = distributor
 | 
			
		||||
 | 
			
		||||
        self.federation = hs.get_replication_layer()
 | 
			
		||||
| 
						 | 
				
			
			@ -604,6 +608,7 @@ class PresenceHandler(BaseHandler):
 | 
			
		|||
            room_ids=room_ids,
 | 
			
		||||
            statuscache=statuscache,
 | 
			
		||||
        )
 | 
			
		||||
        yield self.distributor.fire("user_presence_changed", user, statuscache)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _push_presence_remote(self, user, destination, state=None):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -163,7 +163,7 @@ class RegistrationHandler(BaseHandler):
 | 
			
		|||
        # each request
 | 
			
		||||
        httpCli = SimpleHttpClient(self.hs)
 | 
			
		||||
        # XXX: make this configurable!
 | 
			
		||||
        trustedIdServers = ['matrix.org:8090']
 | 
			
		||||
        trustedIdServers = ['matrix.org:8090', 'matrix.org']
 | 
			
		||||
        if not creds['idServer'] in trustedIdServers:
 | 
			
		||||
            logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
 | 
			
		||||
                        'credentials', creds['idServer'])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -62,6 +62,25 @@ class SimpleHttpClient(object):
 | 
			
		|||
 | 
			
		||||
        defer.returnValue(json.loads(body))
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def post_json_get_json(self, uri, post_json):
 | 
			
		||||
        json_str = json.dumps(post_json)
 | 
			
		||||
 | 
			
		||||
        logger.info("HTTP POST %s -> %s", json_str, uri)
 | 
			
		||||
 | 
			
		||||
        response = yield self.agent.request(
 | 
			
		||||
            "POST",
 | 
			
		||||
            uri.encode("ascii"),
 | 
			
		||||
            headers=Headers({
 | 
			
		||||
                "Content-Type": ["application/json"]
 | 
			
		||||
            }),
 | 
			
		||||
            bodyProducer=FileBodyProducer(StringIO(json_str))
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        body = yield readBody(response)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(json.loads(body))
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_json(self, uri, args={}):
 | 
			
		||||
        """ Get's some json from the given host and path
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,7 +16,7 @@
 | 
			
		|||
 | 
			
		||||
from synapse.http.agent_name import AGENT_NAME
 | 
			
		||||
from synapse.api.errors import (
 | 
			
		||||
    cs_exception, SynapseError, CodeMessageException
 | 
			
		||||
    cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
 | 
			
		||||
)
 | 
			
		||||
from synapse.util.logcontext import LoggingContext
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -139,11 +139,7 @@ class JsonResource(HttpServer, resource.Resource):
 | 
			
		|||
                    return
 | 
			
		||||
 | 
			
		||||
            # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
 | 
			
		||||
            self._send_response(
 | 
			
		||||
                request,
 | 
			
		||||
                400,
 | 
			
		||||
                {"error": "Unrecognized request"}
 | 
			
		||||
            )
 | 
			
		||||
            raise UnrecognizedRequestError()
 | 
			
		||||
        except CodeMessageException as e:
 | 
			
		||||
            if isinstance(e, SynapseError):
 | 
			
		||||
                logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,319 @@
 | 
			
		|||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright 2015 OpenMarket Ltd
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from synapse.streams.config import PaginationConfig
 | 
			
		||||
from synapse.types import StreamToken
 | 
			
		||||
 | 
			
		||||
import synapse.util.async
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import fnmatch
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Pusher(object):
 | 
			
		||||
    INITIAL_BACKOFF = 1000
 | 
			
		||||
    MAX_BACKOFF = 60 * 60 * 1000
 | 
			
		||||
    GIVE_UP_AFTER = 24 * 60 * 60 * 1000
 | 
			
		||||
    DEFAULT_ACTIONS = ['notify']
 | 
			
		||||
 | 
			
		||||
    def __init__(self, _hs, instance_handle, user_name, app_id,
 | 
			
		||||
                 app_display_name, device_display_name, pushkey, pushkey_ts,
 | 
			
		||||
                 data, last_token, last_success, failing_since):
 | 
			
		||||
        self.hs = _hs
 | 
			
		||||
        self.evStreamHandler = self.hs.get_handlers().event_stream_handler
 | 
			
		||||
        self.store = self.hs.get_datastore()
 | 
			
		||||
        self.clock = self.hs.get_clock()
 | 
			
		||||
        self.instance_handle = instance_handle
 | 
			
		||||
        self.user_name = user_name
 | 
			
		||||
        self.app_id = app_id
 | 
			
		||||
        self.app_display_name = app_display_name
 | 
			
		||||
        self.device_display_name = device_display_name
 | 
			
		||||
        self.pushkey = pushkey
 | 
			
		||||
        self.pushkey_ts = pushkey_ts
 | 
			
		||||
        self.data = data
 | 
			
		||||
        self.last_token = last_token
 | 
			
		||||
        self.last_success = last_success  # not actually used
 | 
			
		||||
        self.backoff_delay = Pusher.INITIAL_BACKOFF
 | 
			
		||||
        self.failing_since = failing_since
 | 
			
		||||
        self.alive = True
 | 
			
		||||
 | 
			
		||||
        # The last value of last_active_time that we saw
 | 
			
		||||
        self.last_last_active_time = 0
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _actions_for_event(self, ev):
 | 
			
		||||
        """
 | 
			
		||||
        This should take into account notification settings that the user
 | 
			
		||||
        has configured both globally and per-room when we have the ability
 | 
			
		||||
        to do such things.
 | 
			
		||||
        """
 | 
			
		||||
        if ev['user_id'] == self.user_name:
 | 
			
		||||
            # let's assume you probably know about messages you sent yourself
 | 
			
		||||
            defer.returnValue(['dont_notify'])
 | 
			
		||||
 | 
			
		||||
        if ev['type'] == 'm.room.member':
 | 
			
		||||
            if ev['state_key'] != self.user_name:
 | 
			
		||||
                defer.returnValue(['dont_notify'])
 | 
			
		||||
 | 
			
		||||
        rules = yield self.store.get_push_rules_for_user_name(self.user_name)
 | 
			
		||||
 | 
			
		||||
        for r in rules:
 | 
			
		||||
            matches = True
 | 
			
		||||
 | 
			
		||||
            conditions = json.loads(r['conditions'])
 | 
			
		||||
            actions = json.loads(r['actions'])
 | 
			
		||||
 | 
			
		||||
            for c in conditions:
 | 
			
		||||
                matches &= self._event_fulfills_condition(ev, c)
 | 
			
		||||
            # ignore rules with no actions (we have an explict 'dont_notify'
 | 
			
		||||
            if len(actions) == 0:
 | 
			
		||||
                logger.warn(
 | 
			
		||||
                    "Ignoring rule id %s with no actions for user %s" %
 | 
			
		||||
                    (r['rule_id'], r['user_name'])
 | 
			
		||||
                )
 | 
			
		||||
                continue
 | 
			
		||||
            if matches:
 | 
			
		||||
                defer.returnValue(actions)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(Pusher.DEFAULT_ACTIONS)
 | 
			
		||||
 | 
			
		||||
    def _event_fulfills_condition(self, ev, condition):
 | 
			
		||||
        if condition['kind'] == 'event_match':
 | 
			
		||||
            if 'pattern' not in condition:
 | 
			
		||||
                logger.warn("event_match condition with no pattern")
 | 
			
		||||
                return False
 | 
			
		||||
            pat = condition['pattern']
 | 
			
		||||
 | 
			
		||||
            val = _value_for_dotted_key(condition['key'], ev)
 | 
			
		||||
            if fnmatch.fnmatch(val, pat):
 | 
			
		||||
                return True
 | 
			
		||||
            return False
 | 
			
		||||
        elif condition['kind'] == 'device':
 | 
			
		||||
            if 'instance_handle' not in condition:
 | 
			
		||||
                return True
 | 
			
		||||
            return condition['instance_handle'] == self.instance_handle
 | 
			
		||||
        else:
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_context_for_event(self, ev):
 | 
			
		||||
        name_aliases = yield self.store.get_room_name_and_aliases(
 | 
			
		||||
            ev['room_id']
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ctx = {'aliases': name_aliases[1]}
 | 
			
		||||
        if name_aliases[0] is not None:
 | 
			
		||||
            ctx['name'] = name_aliases[0]
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(ctx)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def start(self):
 | 
			
		||||
        if not self.last_token:
 | 
			
		||||
            # First-time setup: get a token to start from (we can't
 | 
			
		||||
            # just start from no token, ie. 'now'
 | 
			
		||||
            # because we need the result to be reproduceable in case
 | 
			
		||||
            # we fail to dispatch the push)
 | 
			
		||||
            config = PaginationConfig(from_token=None, limit='1')
 | 
			
		||||
            chunk = yield self.evStreamHandler.get_stream(
 | 
			
		||||
                self.user_name, config, timeout=0)
 | 
			
		||||
            self.last_token = chunk['end']
 | 
			
		||||
            self.store.update_pusher_last_token(
 | 
			
		||||
                self.user_name, self.pushkey, self.last_token)
 | 
			
		||||
            logger.info("Pusher %s for user %s starting from token %s",
 | 
			
		||||
                        self.pushkey, self.user_name, self.last_token)
 | 
			
		||||
 | 
			
		||||
        while self.alive:
 | 
			
		||||
            from_tok = StreamToken.from_string(self.last_token)
 | 
			
		||||
            config = PaginationConfig(from_token=from_tok, limit='1')
 | 
			
		||||
            chunk = yield self.evStreamHandler.get_stream(
 | 
			
		||||
                self.user_name, config,
 | 
			
		||||
                timeout=100*365*24*60*60*1000, affect_presence=False
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # limiting to 1 may get 1 event plus 1 presence event, so
 | 
			
		||||
            # pick out the actual event
 | 
			
		||||
            single_event = None
 | 
			
		||||
            for c in chunk['chunk']:
 | 
			
		||||
                if 'event_id' in c:  # Hmmm...
 | 
			
		||||
                    single_event = c
 | 
			
		||||
                    break
 | 
			
		||||
            if not single_event:
 | 
			
		||||
                self.last_token = chunk['end']
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            if not self.alive:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            processed = False
 | 
			
		||||
            actions = yield self._actions_for_event(single_event)
 | 
			
		||||
            tweaks = _tweaks_for_actions(actions)
 | 
			
		||||
 | 
			
		||||
            if len(actions) == 0:
 | 
			
		||||
                logger.warn("Empty actions! Using default action.")
 | 
			
		||||
                actions = Pusher.DEFAULT_ACTIONS
 | 
			
		||||
            if 'notify' not in actions and 'dont_notify' not in actions:
 | 
			
		||||
                logger.warn("Neither notify nor dont_notify in actions: adding default")
 | 
			
		||||
                actions.extend(Pusher.DEFAULT_ACTIONS)
 | 
			
		||||
            if 'dont_notify' in actions:
 | 
			
		||||
                logger.debug(
 | 
			
		||||
                    "%s for %s: dont_notify",
 | 
			
		||||
                    single_event['event_id'], self.user_name
 | 
			
		||||
                )
 | 
			
		||||
                processed = True
 | 
			
		||||
            else:
 | 
			
		||||
                rejected = yield self.dispatch_push(single_event, tweaks)
 | 
			
		||||
                if isinstance(rejected, list) or isinstance(rejected, tuple):
 | 
			
		||||
                    processed = True
 | 
			
		||||
                    for pk in rejected:
 | 
			
		||||
                        if pk != self.pushkey:
 | 
			
		||||
                            # for sanity, we only remove the pushkey if it
 | 
			
		||||
                            # was the one we actually sent...
 | 
			
		||||
                            logger.warn(
 | 
			
		||||
                                ("Ignoring rejected pushkey %s because we "
 | 
			
		||||
                                "didn't send it"), pk
 | 
			
		||||
                            )
 | 
			
		||||
                        else:
 | 
			
		||||
                            logger.info(
 | 
			
		||||
                                "Pushkey %s was rejected: removing",
 | 
			
		||||
                                pk
 | 
			
		||||
                            )
 | 
			
		||||
                            yield self.hs.get_pusherpool().remove_pusher(
 | 
			
		||||
                                self.app_id, pk
 | 
			
		||||
                            )
 | 
			
		||||
 | 
			
		||||
            if not self.alive:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            if processed:
 | 
			
		||||
                self.backoff_delay = Pusher.INITIAL_BACKOFF
 | 
			
		||||
                self.last_token = chunk['end']
 | 
			
		||||
                self.store.update_pusher_last_token_and_success(
 | 
			
		||||
                    self.user_name,
 | 
			
		||||
                    self.pushkey,
 | 
			
		||||
                    self.last_token,
 | 
			
		||||
                    self.clock.time_msec()
 | 
			
		||||
                )
 | 
			
		||||
                if self.failing_since:
 | 
			
		||||
                    self.failing_since = None
 | 
			
		||||
                    self.store.update_pusher_failing_since(
 | 
			
		||||
                        self.user_name,
 | 
			
		||||
                        self.pushkey,
 | 
			
		||||
                        self.failing_since)
 | 
			
		||||
            else:
 | 
			
		||||
                if not self.failing_since:
 | 
			
		||||
                    self.failing_since = self.clock.time_msec()
 | 
			
		||||
                    self.store.update_pusher_failing_since(
 | 
			
		||||
                        self.user_name,
 | 
			
		||||
                        self.pushkey,
 | 
			
		||||
                        self.failing_since
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                if (self.failing_since and
 | 
			
		||||
                   self.failing_since <
 | 
			
		||||
                   self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
 | 
			
		||||
                    # we really only give up so that if the URL gets
 | 
			
		||||
                    # fixed, we don't suddenly deliver a load
 | 
			
		||||
                    # of old notifications.
 | 
			
		||||
                    logger.warn("Giving up on a notification to user %s, "
 | 
			
		||||
                                "pushkey %s",
 | 
			
		||||
                                self.user_name, self.pushkey
 | 
			
		||||
                    )
 | 
			
		||||
                    self.backoff_delay = Pusher.INITIAL_BACKOFF
 | 
			
		||||
                    self.last_token = chunk['end']
 | 
			
		||||
                    self.store.update_pusher_last_token(
 | 
			
		||||
                        self.user_name,
 | 
			
		||||
                        self.pushkey,
 | 
			
		||||
                        self.last_token
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                    self.failing_since = None
 | 
			
		||||
                    self.store.update_pusher_failing_since(
 | 
			
		||||
                        self.user_name,
 | 
			
		||||
                        self.pushkey,
 | 
			
		||||
                        self.failing_since
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.warn("Failed to dispatch push for user %s "
 | 
			
		||||
                                "(failing for %dms)."
 | 
			
		||||
                                "Trying again in %dms",
 | 
			
		||||
                                self.user_name,
 | 
			
		||||
                                self.clock.time_msec() - self.failing_since,
 | 
			
		||||
                                self.backoff_delay
 | 
			
		||||
                    )
 | 
			
		||||
                    yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
 | 
			
		||||
                    self.backoff_delay *= 2
 | 
			
		||||
                    if self.backoff_delay > Pusher.MAX_BACKOFF:
 | 
			
		||||
                        self.backoff_delay = Pusher.MAX_BACKOFF
 | 
			
		||||
 | 
			
		||||
    def stop(self):
 | 
			
		||||
        self.alive = False
 | 
			
		||||
 | 
			
		||||
    def dispatch_push(self, p, tweaks):
 | 
			
		||||
        """
 | 
			
		||||
        Overridden by implementing classes to actually deliver the notification
 | 
			
		||||
        Args:
 | 
			
		||||
            p: The event to notify for as a single event from the event stream
 | 
			
		||||
        Returns: If the notification was delivered, an array containing any
 | 
			
		||||
                 pushkeys that were rejected by the push gateway.
 | 
			
		||||
                 False if the notification could not be delivered (ie.
 | 
			
		||||
                 should be retried).
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def reset_badge_count(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def presence_changed(self, state):
 | 
			
		||||
        """
 | 
			
		||||
        We clear badge counts whenever a user's last_active time is bumped
 | 
			
		||||
        This is by no means perfect but I think it's the best we can do
 | 
			
		||||
        without read receipts.
 | 
			
		||||
        """
 | 
			
		||||
        if 'last_active' in state.state:
 | 
			
		||||
            last_active = state.state['last_active']
 | 
			
		||||
            if last_active > self.last_last_active_time:
 | 
			
		||||
                logger.info("Resetting badge count for %s", self.user_name)
 | 
			
		||||
                self.reset_badge_count()
 | 
			
		||||
                self.last_last_active_time = last_active
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _value_for_dotted_key(dotted_key, event):
 | 
			
		||||
    parts = dotted_key.split(".")
 | 
			
		||||
    val = event
 | 
			
		||||
    while len(parts) > 0:
 | 
			
		||||
        if parts[0] not in val:
 | 
			
		||||
            return None
 | 
			
		||||
        val = val[parts[0]]
 | 
			
		||||
        parts = parts[1:]
 | 
			
		||||
    return val
 | 
			
		||||
 | 
			
		||||
def _tweaks_for_actions(actions):
 | 
			
		||||
    tweaks = {}
 | 
			
		||||
    for a in actions:
 | 
			
		||||
        if not isinstance(a, dict):
 | 
			
		||||
            continue
 | 
			
		||||
        if 'set_sound' in a:
 | 
			
		||||
            tweaks['sound'] = a['set_sound']
 | 
			
		||||
    return tweaks
 | 
			
		||||
 | 
			
		||||
class PusherConfigException(Exception):
 | 
			
		||||
    def __init__(self, msg):
 | 
			
		||||
        super(PusherConfigException, self).__init__(msg)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,145 @@
 | 
			
		|||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright 2015 OpenMarket Ltd
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from synapse.push import Pusher, PusherConfigException
 | 
			
		||||
from synapse.http.client import SimpleHttpClient
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HttpPusher(Pusher):
 | 
			
		||||
    def __init__(self, _hs, instance_handle, user_name, app_id,
 | 
			
		||||
                 app_display_name, device_display_name, pushkey, pushkey_ts,
 | 
			
		||||
                 data, last_token, last_success, failing_since):
 | 
			
		||||
        super(HttpPusher, self).__init__(
 | 
			
		||||
            _hs,
 | 
			
		||||
            instance_handle,
 | 
			
		||||
            user_name,
 | 
			
		||||
            app_id,
 | 
			
		||||
            app_display_name,
 | 
			
		||||
            device_display_name,
 | 
			
		||||
            pushkey,
 | 
			
		||||
            pushkey_ts,
 | 
			
		||||
            data,
 | 
			
		||||
            last_token,
 | 
			
		||||
            last_success,
 | 
			
		||||
            failing_since
 | 
			
		||||
        )
 | 
			
		||||
        if 'url' not in data:
 | 
			
		||||
            raise PusherConfigException(
 | 
			
		||||
                "'url' required in data for HTTP pusher"
 | 
			
		||||
            )
 | 
			
		||||
        self.url = data['url']
 | 
			
		||||
        self.httpCli = SimpleHttpClient(self.hs)
 | 
			
		||||
        self.data_minus_url = {}
 | 
			
		||||
        self.data_minus_url.update(self.data)
 | 
			
		||||
        del self.data_minus_url['url']
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _build_notification_dict(self, event, tweaks):
 | 
			
		||||
        # we probably do not want to push for every presence update
 | 
			
		||||
        # (we may want to be able to set up notifications when specific
 | 
			
		||||
        # people sign in, but we'd want to only deliver the pertinent ones)
 | 
			
		||||
        # Actually, presence events will not get this far now because we
 | 
			
		||||
        # need to filter them out in the main Pusher code.
 | 
			
		||||
        if 'event_id' not in event:
 | 
			
		||||
            defer.returnValue(None)
 | 
			
		||||
 | 
			
		||||
        ctx = yield self.get_context_for_event(event)
 | 
			
		||||
 | 
			
		||||
        d = {
 | 
			
		||||
            'notification': {
 | 
			
		||||
                'id': event['event_id'],
 | 
			
		||||
                'type': event['type'],
 | 
			
		||||
                'from': event['user_id'],
 | 
			
		||||
                # we may have to fetch this over federation and we
 | 
			
		||||
                # can't trust it anyway: is it worth it?
 | 
			
		||||
                #'from_display_name': 'Steve Stevington'
 | 
			
		||||
                'counts': { #-- we don't mark messages as read yet so
 | 
			
		||||
                # we have no way of knowing
 | 
			
		||||
                    # Just set the badge to 1 until we have read receipts
 | 
			
		||||
                    'unread': 1,
 | 
			
		||||
                #    'missed_calls': 2
 | 
			
		||||
                },
 | 
			
		||||
                'devices': [
 | 
			
		||||
                    {
 | 
			
		||||
                        'app_id': self.app_id,
 | 
			
		||||
                        'pushkey': self.pushkey,
 | 
			
		||||
                        'pushkey_ts': long(self.pushkey_ts / 1000),
 | 
			
		||||
                        'data': self.data_minus_url,
 | 
			
		||||
                        'tweaks': tweaks
 | 
			
		||||
                    }
 | 
			
		||||
                ]
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        if event['type'] == 'm.room.member':
 | 
			
		||||
            d['notification']['membership'] = event['content']['membership']
 | 
			
		||||
 | 
			
		||||
        if len(ctx['aliases']):
 | 
			
		||||
            d['notification']['room_alias'] = ctx['aliases'][0]
 | 
			
		||||
        if 'name' in ctx:
 | 
			
		||||
            d['notification']['room_name'] = ctx['name']
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(d)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def dispatch_push(self, event, tweaks):
 | 
			
		||||
        notification_dict = yield self._build_notification_dict(event, tweaks)
 | 
			
		||||
        if not notification_dict:
 | 
			
		||||
            defer.returnValue([])
 | 
			
		||||
        try:
 | 
			
		||||
            resp = yield self.httpCli.post_json_get_json(self.url, notification_dict)
 | 
			
		||||
        except:
 | 
			
		||||
            logger.exception("Failed to push %s ", self.url)
 | 
			
		||||
            defer.returnValue(False)
 | 
			
		||||
        rejected = []
 | 
			
		||||
        if 'rejected' in resp:
 | 
			
		||||
            rejected = resp['rejected']
 | 
			
		||||
        defer.returnValue(rejected)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def reset_badge_count(self):
 | 
			
		||||
        d = {
 | 
			
		||||
            'notification': {
 | 
			
		||||
                'id': '',
 | 
			
		||||
                'type': None,
 | 
			
		||||
                'from': '',
 | 
			
		||||
                'counts': {
 | 
			
		||||
                    'unread': 0,
 | 
			
		||||
                    'missed_calls': 0
 | 
			
		||||
                },
 | 
			
		||||
                'devices': [
 | 
			
		||||
                    {
 | 
			
		||||
                        'app_id': self.app_id,
 | 
			
		||||
                        'pushkey': self.pushkey,
 | 
			
		||||
                        'pushkey_ts': long(self.pushkey_ts / 1000),
 | 
			
		||||
                        'data': self.data_minus_url,
 | 
			
		||||
                    }
 | 
			
		||||
                ]
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        try:
 | 
			
		||||
            resp = yield self.httpCli.post_json_get_json(self.url, d)
 | 
			
		||||
        except:
 | 
			
		||||
            logger.exception("Failed to push %s ", self.url)
 | 
			
		||||
            defer.returnValue(False)
 | 
			
		||||
        rejected = []
 | 
			
		||||
        if 'rejected' in resp:
 | 
			
		||||
            rejected = resp['rejected']
 | 
			
		||||
        defer.returnValue(rejected)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,152 @@
 | 
			
		|||
#!/usr/bin/env python
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright 2015 OpenMarket Ltd
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from httppusher import HttpPusher
 | 
			
		||||
from synapse.push import PusherConfigException
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PusherPool:
 | 
			
		||||
    def __init__(self, _hs):
 | 
			
		||||
        self.hs = _hs
 | 
			
		||||
        self.store = self.hs.get_datastore()
 | 
			
		||||
        self.pushers = {}
 | 
			
		||||
        self.last_pusher_started = -1
 | 
			
		||||
 | 
			
		||||
        distributor = self.hs.get_distributor()
 | 
			
		||||
        distributor.observe(
 | 
			
		||||
            "user_presence_changed", self.user_presence_changed
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def user_presence_changed(self, user, state):
 | 
			
		||||
        user_name = user.to_string()
 | 
			
		||||
 | 
			
		||||
        # until we have read receipts, pushers use this to reset a user's
 | 
			
		||||
        # badge counters to zero
 | 
			
		||||
        for p in self.pushers.values():
 | 
			
		||||
            if p.user_name == user_name:
 | 
			
		||||
                yield p.presence_changed(state)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def start(self):
 | 
			
		||||
        pushers = yield self.store.get_all_pushers()
 | 
			
		||||
        for p in pushers:
 | 
			
		||||
            p['data'] = json.loads(p['data'])
 | 
			
		||||
        self._start_pushers(pushers)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def add_pusher(self, user_name, instance_handle, kind, app_id,
 | 
			
		||||
                   app_display_name, device_display_name, pushkey, lang, data):
 | 
			
		||||
        # we try to create the pusher just to validate the config: it
 | 
			
		||||
        # will then get pulled out of the database,
 | 
			
		||||
        # recreated, added and started: this means we have only one
 | 
			
		||||
        # code path adding pushers.
 | 
			
		||||
        self._create_pusher({
 | 
			
		||||
            "user_name": user_name,
 | 
			
		||||
            "kind": kind,
 | 
			
		||||
            "instance_handle": instance_handle,
 | 
			
		||||
            "app_id": app_id,
 | 
			
		||||
            "app_display_name": app_display_name,
 | 
			
		||||
            "device_display_name": device_display_name,
 | 
			
		||||
            "pushkey": pushkey,
 | 
			
		||||
            "pushkey_ts": self.hs.get_clock().time_msec(),
 | 
			
		||||
            "lang": lang,
 | 
			
		||||
            "data": data,
 | 
			
		||||
            "last_token": None,
 | 
			
		||||
            "last_success": None,
 | 
			
		||||
            "failing_since": None
 | 
			
		||||
        })
 | 
			
		||||
        yield self._add_pusher_to_store(
 | 
			
		||||
            user_name, instance_handle, kind, app_id,
 | 
			
		||||
            app_display_name, device_display_name,
 | 
			
		||||
            pushkey, lang, data
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _add_pusher_to_store(self, user_name, instance_handle, kind, app_id,
 | 
			
		||||
                             app_display_name, device_display_name,
 | 
			
		||||
                             pushkey, lang, data):
 | 
			
		||||
        yield self.store.add_pusher(
 | 
			
		||||
            user_name=user_name,
 | 
			
		||||
            instance_handle=instance_handle,
 | 
			
		||||
            kind=kind,
 | 
			
		||||
            app_id=app_id,
 | 
			
		||||
            app_display_name=app_display_name,
 | 
			
		||||
            device_display_name=device_display_name,
 | 
			
		||||
            pushkey=pushkey,
 | 
			
		||||
            pushkey_ts=self.hs.get_clock().time_msec(),
 | 
			
		||||
            lang=lang,
 | 
			
		||||
            data=json.dumps(data)
 | 
			
		||||
        )
 | 
			
		||||
        self._refresh_pusher((app_id, pushkey))
 | 
			
		||||
 | 
			
		||||
    def _create_pusher(self, pusherdict):
 | 
			
		||||
        if pusherdict['kind'] == 'http':
 | 
			
		||||
            return HttpPusher(
 | 
			
		||||
                self.hs,
 | 
			
		||||
                instance_handle=pusherdict['instance_handle'],
 | 
			
		||||
                user_name=pusherdict['user_name'],
 | 
			
		||||
                app_id=pusherdict['app_id'],
 | 
			
		||||
                app_display_name=pusherdict['app_display_name'],
 | 
			
		||||
                device_display_name=pusherdict['device_display_name'],
 | 
			
		||||
                pushkey=pusherdict['pushkey'],
 | 
			
		||||
                pushkey_ts=pusherdict['pushkey_ts'],
 | 
			
		||||
                data=pusherdict['data'],
 | 
			
		||||
                last_token=pusherdict['last_token'],
 | 
			
		||||
                last_success=pusherdict['last_success'],
 | 
			
		||||
                failing_since=pusherdict['failing_since']
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            raise PusherConfigException(
 | 
			
		||||
                "Unknown pusher type '%s' for user %s" %
 | 
			
		||||
                (pusherdict['kind'], pusherdict['user_name'])
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _refresh_pusher(self, app_id_pushkey):
 | 
			
		||||
        p = yield self.store.get_pushers_by_app_id_and_pushkey(
 | 
			
		||||
            app_id_pushkey
 | 
			
		||||
        )
 | 
			
		||||
        p['data'] = json.loads(p['data'])
 | 
			
		||||
 | 
			
		||||
        self._start_pushers([p])
 | 
			
		||||
 | 
			
		||||
    def _start_pushers(self, pushers):
 | 
			
		||||
        logger.info("Starting %d pushers", len(pushers))
 | 
			
		||||
        for pusherdict in pushers:
 | 
			
		||||
            p = self._create_pusher(pusherdict)
 | 
			
		||||
            if p:
 | 
			
		||||
                fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
 | 
			
		||||
                if fullid in self.pushers:
 | 
			
		||||
                    self.pushers[fullid].stop()
 | 
			
		||||
                self.pushers[fullid] = p
 | 
			
		||||
                p.start()
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def remove_pusher(self, app_id, pushkey):
 | 
			
		||||
        fullid = "%s:%s" % (app_id, pushkey)
 | 
			
		||||
        if fullid in self.pushers:
 | 
			
		||||
            logger.info("Stopping pusher %s", fullid)
 | 
			
		||||
            self.pushers[fullid].stop()
 | 
			
		||||
            del self.pushers[fullid]
 | 
			
		||||
        yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey)
 | 
			
		||||
| 
						 | 
				
			
			@ -13,10 +13,9 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from . import (
 | 
			
		||||
    room, events, register, login, profile, presence, initial_sync, directory,
 | 
			
		||||
    voip, admin,
 | 
			
		||||
    voip, admin, pusher, push_rule
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from synapse.http.server import JsonResource
 | 
			
		||||
| 
						 | 
				
			
			@ -41,3 +40,5 @@ class ClientV1RestResource(JsonResource):
 | 
			
		|||
        directory.register_servlets(hs, client_resource)
 | 
			
		||||
        voip.register_servlets(hs, client_resource)
 | 
			
		||||
        admin.register_servlets(hs, client_resource)
 | 
			
		||||
        pusher.register_servlets(hs, client_resource)
 | 
			
		||||
        push_rule.register_servlets(hs, client_resource)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
 | 
			
		|||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, user_id):
 | 
			
		||||
        target_user = UserID.from_string(user_id)
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        is_admin = yield self.auth.is_server_admin(auth_user)
 | 
			
		||||
 | 
			
		||||
        if not is_admin and target_user != auth_user:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,7 +45,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_PUT(self, request, room_alias):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        content = _parse_json(request)
 | 
			
		||||
        if not "room_id" in content:
 | 
			
		||||
| 
						 | 
				
			
			@ -85,7 +85,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_DELETE(self, request, room_alias):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        is_admin = yield self.auth.is_server_admin(user)
 | 
			
		||||
        if not is_admin:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        try:
 | 
			
		||||
            handler = self.handlers.event_stream_handler
 | 
			
		||||
            pagin_config = PaginationConfig.from_request(request)
 | 
			
		||||
| 
						 | 
				
			
			@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, event_id):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        handler = self.handlers.event_handler
 | 
			
		||||
        event = yield handler.get_event(auth_user, event_id)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        with_feedback = "feedback" in request.args
 | 
			
		||||
        as_client_event = "raw" not in request.args
 | 
			
		||||
        pagination_config = PaginationConfig.from_request(request)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, user_id):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user = UserID.from_string(user_id)
 | 
			
		||||
 | 
			
		||||
        state = yield self.handlers.presence_handler.get_state(
 | 
			
		||||
| 
						 | 
				
			
			@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_PUT(self, request, user_id):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user = UserID.from_string(user_id)
 | 
			
		||||
 | 
			
		||||
        state = {}
 | 
			
		||||
| 
						 | 
				
			
			@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, user_id):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user = UserID.from_string(user_id)
 | 
			
		||||
 | 
			
		||||
        if not self.hs.is_mine(user):
 | 
			
		||||
| 
						 | 
				
			
			@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request, user_id):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user = UserID.from_string(user_id)
 | 
			
		||||
 | 
			
		||||
        if not self.hs.is_mine(user):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_PUT(self, request, user_id):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user = UserID.from_string(user_id)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_PUT(self, request, user_id):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user = UserID.from_string(user_id)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,386 @@
 | 
			
		|||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright 2014 OpenMarket Ltd
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError, NotFoundError, \
 | 
			
		||||
    StoreError
 | 
			
		||||
from .base import ClientV1RestServlet, client_path_pattern
 | 
			
		||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PushRuleRestServlet(ClientV1RestServlet):
 | 
			
		||||
    PATTERN = client_path_pattern("/pushrules/.*$")
 | 
			
		||||
    PRIORITY_CLASS_MAP = {
 | 
			
		||||
        'underride': 0,
 | 
			
		||||
        'sender': 1,
 | 
			
		||||
        'room': 2,
 | 
			
		||||
        'content': 3,
 | 
			
		||||
        'override': 4
 | 
			
		||||
    }
 | 
			
		||||
    PRIORITY_CLASS_INVERSE_MAP = {v: k for k,v in PRIORITY_CLASS_MAP.items()}
 | 
			
		||||
    SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
 | 
			
		||||
        "Unrecognised request: You probably wanted a trailing slash")
 | 
			
		||||
 | 
			
		||||
    def rule_spec_from_path(self, path):
 | 
			
		||||
        if len(path) < 2:
 | 
			
		||||
            raise UnrecognizedRequestError()
 | 
			
		||||
        if path[0] != 'pushrules':
 | 
			
		||||
            raise UnrecognizedRequestError()
 | 
			
		||||
 | 
			
		||||
        scope = path[1]
 | 
			
		||||
        path = path[2:]
 | 
			
		||||
        if scope not in ['global', 'device']:
 | 
			
		||||
            raise UnrecognizedRequestError()
 | 
			
		||||
 | 
			
		||||
        device = None
 | 
			
		||||
        if scope == 'device':
 | 
			
		||||
            if len(path) == 0:
 | 
			
		||||
                raise UnrecognizedRequestError()
 | 
			
		||||
            device = path[0]
 | 
			
		||||
            path = path[1:]
 | 
			
		||||
 | 
			
		||||
        if len(path) == 0:
 | 
			
		||||
            raise UnrecognizedRequestError()
 | 
			
		||||
 | 
			
		||||
        template = path[0]
 | 
			
		||||
        path = path[1:]
 | 
			
		||||
 | 
			
		||||
        if len(path) == 0:
 | 
			
		||||
            raise UnrecognizedRequestError()
 | 
			
		||||
 | 
			
		||||
        rule_id = path[0]
 | 
			
		||||
 | 
			
		||||
        spec = {
 | 
			
		||||
            'scope': scope,
 | 
			
		||||
            'template': template,
 | 
			
		||||
            'rule_id': rule_id
 | 
			
		||||
        }
 | 
			
		||||
        if device:
 | 
			
		||||
            spec['device'] = device
 | 
			
		||||
        return spec
 | 
			
		||||
 | 
			
		||||
    def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None):
 | 
			
		||||
        if rule_template in ['override', 'underride']:
 | 
			
		||||
            if 'conditions' not in req_obj:
 | 
			
		||||
                raise InvalidRuleException("Missing 'conditions'")
 | 
			
		||||
            conditions = req_obj['conditions']
 | 
			
		||||
            for c in conditions:
 | 
			
		||||
                if 'kind' not in c:
 | 
			
		||||
                    raise InvalidRuleException("Condition without 'kind'")
 | 
			
		||||
        elif rule_template == 'room':
 | 
			
		||||
            conditions = [{
 | 
			
		||||
                'kind': 'event_match',
 | 
			
		||||
                'key': 'room_id',
 | 
			
		||||
                'pattern': rule_id
 | 
			
		||||
            }]
 | 
			
		||||
        elif rule_template == 'sender':
 | 
			
		||||
            conditions = [{
 | 
			
		||||
                'kind': 'event_match',
 | 
			
		||||
                'key': 'user_id',
 | 
			
		||||
                'pattern': rule_id
 | 
			
		||||
            }]
 | 
			
		||||
        elif rule_template == 'content':
 | 
			
		||||
            if 'pattern' not in req_obj:
 | 
			
		||||
                raise InvalidRuleException("Content rule missing 'pattern'")
 | 
			
		||||
            pat = req_obj['pattern']
 | 
			
		||||
            if pat.strip("*?[]") == pat:
 | 
			
		||||
                # no special glob characters so we assume the user means
 | 
			
		||||
                # 'contains this string' rather than 'is this string'
 | 
			
		||||
                pat = "*%s*" % (pat,)
 | 
			
		||||
            conditions = [{
 | 
			
		||||
                'kind': 'event_match',
 | 
			
		||||
                'key': 'content.body',
 | 
			
		||||
                'pattern': pat
 | 
			
		||||
            }]
 | 
			
		||||
        else:
 | 
			
		||||
            raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
 | 
			
		||||
 | 
			
		||||
        if device:
 | 
			
		||||
            conditions.append({
 | 
			
		||||
                'kind': 'device',
 | 
			
		||||
                'instance_handle': device
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
        if 'actions' not in req_obj:
 | 
			
		||||
            raise InvalidRuleException("No actions found")
 | 
			
		||||
        actions = req_obj['actions']
 | 
			
		||||
 | 
			
		||||
        for a in actions:
 | 
			
		||||
            if a in ['notify', 'dont_notify', 'coalesce']:
 | 
			
		||||
                pass
 | 
			
		||||
            elif isinstance(a, dict) and 'set_sound' in a:
 | 
			
		||||
                pass
 | 
			
		||||
            else:
 | 
			
		||||
                raise InvalidRuleException("Unrecognised action")
 | 
			
		||||
 | 
			
		||||
        return conditions, actions
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_PUT(self, request):
 | 
			
		||||
        spec = self.rule_spec_from_path(request.postpath)
 | 
			
		||||
        try:
 | 
			
		||||
            priority_class = _priority_class_from_spec(spec)
 | 
			
		||||
        except InvalidRuleException as e:
 | 
			
		||||
            raise SynapseError(400, e.message)
 | 
			
		||||
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        content = _parse_json(request)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            (conditions, actions) = self.rule_tuple_from_request_object(
 | 
			
		||||
                spec['template'],
 | 
			
		||||
                spec['rule_id'],
 | 
			
		||||
                content,
 | 
			
		||||
                device=spec['device'] if 'device' in spec else None
 | 
			
		||||
            )
 | 
			
		||||
        except InvalidRuleException as e:
 | 
			
		||||
            raise SynapseError(400, e.message)
 | 
			
		||||
 | 
			
		||||
        before = request.args.get("before", None)
 | 
			
		||||
        if before and len(before):
 | 
			
		||||
            before = before[0]
 | 
			
		||||
        after = request.args.get("after", None)
 | 
			
		||||
        if after and len(after):
 | 
			
		||||
            after = after[0]
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            yield self.hs.get_datastore().add_push_rule(
 | 
			
		||||
                user_name=user.to_string(),
 | 
			
		||||
                rule_id=spec['rule_id'],
 | 
			
		||||
                priority_class=priority_class,
 | 
			
		||||
                conditions=conditions,
 | 
			
		||||
                actions=actions,
 | 
			
		||||
                before=before,
 | 
			
		||||
                after=after
 | 
			
		||||
            )
 | 
			
		||||
        except InconsistentRuleException as e:
 | 
			
		||||
            raise SynapseError(400, e.message)
 | 
			
		||||
        except RuleNotFoundException as e:
 | 
			
		||||
            raise SynapseError(400, e.message)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((200, {}))
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_DELETE(self, request):
 | 
			
		||||
        spec = self.rule_spec_from_path(request.postpath)
 | 
			
		||||
        try:
 | 
			
		||||
            priority_class = _priority_class_from_spec(spec)
 | 
			
		||||
        except InvalidRuleException as e:
 | 
			
		||||
            raise SynapseError(400, e.message)
 | 
			
		||||
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        if 'device' in spec:
 | 
			
		||||
            rules = yield self.hs.get_datastore().get_push_rules_for_user_name(
 | 
			
		||||
                user.to_string()
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            for r in rules:
 | 
			
		||||
                conditions = json.loads(r['conditions'])
 | 
			
		||||
                ih = _instance_handle_from_conditions(conditions)
 | 
			
		||||
                if ih == spec['device'] and r['priority_class'] == priority_class:
 | 
			
		||||
                    yield self.hs.get_datastore().delete_push_rule(
 | 
			
		||||
                        user.to_string(), spec['rule_id']
 | 
			
		||||
                    )
 | 
			
		||||
                    defer.returnValue((200, {}))
 | 
			
		||||
            raise NotFoundError()
 | 
			
		||||
        else:
 | 
			
		||||
            try:
 | 
			
		||||
                yield self.hs.get_datastore().delete_push_rule(
 | 
			
		||||
                    user.to_string(), spec['rule_id'],
 | 
			
		||||
                    priority_class=priority_class
 | 
			
		||||
                )
 | 
			
		||||
                defer.returnValue((200, {}))
 | 
			
		||||
            except StoreError as e:
 | 
			
		||||
                if e.code == 404:
 | 
			
		||||
                    raise NotFoundError()
 | 
			
		||||
                else:
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        # we build up the full structure and then decide which bits of it
 | 
			
		||||
        # to send which means doing unnecessary work sometimes but is
 | 
			
		||||
        # is probably not going to make a whole lot of difference
 | 
			
		||||
        rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(user.to_string())
 | 
			
		||||
 | 
			
		||||
        rules = {'global': {}, 'device': {}}
 | 
			
		||||
 | 
			
		||||
        rules['global'] = _add_empty_priority_class_arrays(rules['global'])
 | 
			
		||||
 | 
			
		||||
        for r in rawrules:
 | 
			
		||||
            rulearray = None
 | 
			
		||||
 | 
			
		||||
            r["conditions"] = json.loads(r["conditions"])
 | 
			
		||||
            r["actions"] = json.loads(r["actions"])
 | 
			
		||||
 | 
			
		||||
            template_name = _priority_class_to_template_name(r['priority_class'])
 | 
			
		||||
 | 
			
		||||
            if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
 | 
			
		||||
                # per-device rule
 | 
			
		||||
                instance_handle = _instance_handle_from_conditions(r["conditions"])
 | 
			
		||||
                r = _strip_device_condition(r)
 | 
			
		||||
                if not instance_handle:
 | 
			
		||||
                    continue
 | 
			
		||||
                if instance_handle not in rules['device']:
 | 
			
		||||
                    rules['device'][instance_handle] = {}
 | 
			
		||||
                    rules['device'][instance_handle] = (
 | 
			
		||||
                        _add_empty_priority_class_arrays(
 | 
			
		||||
                            rules['device'][instance_handle]
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                rulearray = rules['device'][instance_handle][template_name]
 | 
			
		||||
            else:
 | 
			
		||||
                rulearray = rules['global'][template_name]
 | 
			
		||||
 | 
			
		||||
            template_rule = _rule_to_template(r)
 | 
			
		||||
            if template_rule:
 | 
			
		||||
                rulearray.append(template_rule)
 | 
			
		||||
 | 
			
		||||
        path = request.postpath[1:]
 | 
			
		||||
 | 
			
		||||
        if path == []:
 | 
			
		||||
            # we're a reference impl: pedantry is our job.
 | 
			
		||||
            raise UnrecognizedRequestError(PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR)
 | 
			
		||||
 | 
			
		||||
        if path[0] == '':
 | 
			
		||||
            defer.returnValue((200, rules))
 | 
			
		||||
        elif path[0] == 'global':
 | 
			
		||||
            path = path[1:]
 | 
			
		||||
            result = _filter_ruleset_with_path(rules['global'], path)
 | 
			
		||||
            defer.returnValue((200, result))
 | 
			
		||||
        elif path[0] == 'device':
 | 
			
		||||
            path = path[1:]
 | 
			
		||||
            if path == []:
 | 
			
		||||
                raise UnrecognizedRequestError(PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR)
 | 
			
		||||
            if path[0] == '':
 | 
			
		||||
                defer.returnValue((200, rules['device']))
 | 
			
		||||
 | 
			
		||||
            instance_handle = path[0]
 | 
			
		||||
            path = path[1:]
 | 
			
		||||
            if instance_handle not in rules['device']:
 | 
			
		||||
                ret = {}
 | 
			
		||||
                ret = _add_empty_priority_class_arrays(ret)
 | 
			
		||||
                defer.returnValue((200, ret))
 | 
			
		||||
            ruleset = rules['device'][instance_handle]
 | 
			
		||||
            result = _filter_ruleset_with_path(ruleset, path)
 | 
			
		||||
            defer.returnValue((200, result))
 | 
			
		||||
        else:
 | 
			
		||||
            raise UnrecognizedRequestError()
 | 
			
		||||
 | 
			
		||||
    def on_OPTIONS(self, _):
 | 
			
		||||
        return 200, {}
 | 
			
		||||
 | 
			
		||||
def _add_empty_priority_class_arrays(d):
 | 
			
		||||
    for pc in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
 | 
			
		||||
        d[pc] = []
 | 
			
		||||
    return d
 | 
			
		||||
 | 
			
		||||
def _instance_handle_from_conditions(conditions):
 | 
			
		||||
    """
 | 
			
		||||
    Given a list of conditions, return the instance handle of the
 | 
			
		||||
    device rule if there is one
 | 
			
		||||
    """
 | 
			
		||||
    for c in conditions:
 | 
			
		||||
        if c['kind'] == 'device':
 | 
			
		||||
            return c['instance_handle']
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
def _filter_ruleset_with_path(ruleset, path):
 | 
			
		||||
    if path == []:
 | 
			
		||||
        raise UnrecognizedRequestError(PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR)
 | 
			
		||||
 | 
			
		||||
    if path[0] == '':
 | 
			
		||||
        return ruleset
 | 
			
		||||
    template_kind = path[0]
 | 
			
		||||
    if template_kind not in ruleset:
 | 
			
		||||
        raise UnrecognizedRequestError()
 | 
			
		||||
    path = path[1:]
 | 
			
		||||
    if path == []:
 | 
			
		||||
        raise UnrecognizedRequestError(PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR)
 | 
			
		||||
    if path[0] == '':
 | 
			
		||||
        return ruleset[template_kind]
 | 
			
		||||
    rule_id = path[0]
 | 
			
		||||
    for r in ruleset[template_kind]:
 | 
			
		||||
        if r['rule_id'] == rule_id:
 | 
			
		||||
            return r
 | 
			
		||||
    raise NotFoundError
 | 
			
		||||
 | 
			
		||||
def _priority_class_from_spec(spec):
 | 
			
		||||
    if spec['template'] not in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
 | 
			
		||||
        raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
 | 
			
		||||
    pc = PushRuleRestServlet.PRIORITY_CLASS_MAP[spec['template']]
 | 
			
		||||
 | 
			
		||||
    if spec['scope'] == 'device':
 | 
			
		||||
        pc += len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
 | 
			
		||||
 | 
			
		||||
    return pc
 | 
			
		||||
 | 
			
		||||
def _priority_class_to_template_name(pc):
 | 
			
		||||
    if pc > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
 | 
			
		||||
        # per-device
 | 
			
		||||
        prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
 | 
			
		||||
        return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
 | 
			
		||||
    else:
 | 
			
		||||
        return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[pc]
 | 
			
		||||
 | 
			
		||||
def _rule_to_template(rule):
 | 
			
		||||
    template_name = _priority_class_to_template_name(rule['priority_class'])
 | 
			
		||||
    if template_name in ['override', 'underride']:
 | 
			
		||||
        return {k: rule[k] for k in ["rule_id", "conditions", "actions"]}
 | 
			
		||||
    elif template_name in ["sender", "room"]:
 | 
			
		||||
        return {k: rule[k] for k in ["rule_id", "actions"]}
 | 
			
		||||
    elif template_name == 'content':
 | 
			
		||||
        if len(rule["conditions"]) != 1:
 | 
			
		||||
            return None
 | 
			
		||||
        thecond = rule["conditions"][0]
 | 
			
		||||
        if "pattern" not in thecond:
 | 
			
		||||
            return None
 | 
			
		||||
        ret = {k: rule[k] for k in ["rule_id", "actions"]}
 | 
			
		||||
        ret["pattern"] = thecond["pattern"]
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
def _strip_device_condition(rule):
 | 
			
		||||
    for i,c in enumerate(rule['conditions']):
 | 
			
		||||
        if c['kind'] == 'device':
 | 
			
		||||
            del rule['conditions'][i]
 | 
			
		||||
    return rule
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InvalidRuleException(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# XXX: C+ped from rest/room.py - surely this should be common?
 | 
			
		||||
def _parse_json(request):
 | 
			
		||||
    try:
 | 
			
		||||
        content = json.loads(request.content.read())
 | 
			
		||||
        if type(content) != dict:
 | 
			
		||||
            raise SynapseError(400, "Content must be a JSON object.",
 | 
			
		||||
                               errcode=Codes.NOT_JSON)
 | 
			
		||||
        return content
 | 
			
		||||
    except ValueError:
 | 
			
		||||
        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_servlets(hs, http_server):
 | 
			
		||||
    PushRuleRestServlet(hs).register(http_server)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,80 @@
 | 
			
		|||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright 2014 OpenMarket Ltd
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from synapse.api.errors import SynapseError, Codes
 | 
			
		||||
from synapse.push import PusherConfigException
 | 
			
		||||
from .base import ClientV1RestServlet, client_path_pattern
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PusherRestServlet(ClientV1RestServlet):
 | 
			
		||||
    PATTERN = client_path_pattern("/pushers/set$")
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        content = _parse_json(request)
 | 
			
		||||
 | 
			
		||||
        reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name',
 | 
			
		||||
                'device_display_name', 'pushkey', 'lang', 'data']
 | 
			
		||||
        missing = []
 | 
			
		||||
        for i in reqd:
 | 
			
		||||
            if i not in content:
 | 
			
		||||
                missing.append(i)
 | 
			
		||||
        if len(missing):
 | 
			
		||||
            raise SynapseError(400, "Missing parameters: "+','.join(missing),
 | 
			
		||||
                               errcode=Codes.MISSING_PARAM)
 | 
			
		||||
 | 
			
		||||
        pusher_pool = self.hs.get_pusherpool()
 | 
			
		||||
        try:
 | 
			
		||||
            yield pusher_pool.add_pusher(
 | 
			
		||||
                user_name=user.to_string(),
 | 
			
		||||
                instance_handle=content['instance_handle'],
 | 
			
		||||
                kind=content['kind'],
 | 
			
		||||
                app_id=content['app_id'],
 | 
			
		||||
                app_display_name=content['app_display_name'],
 | 
			
		||||
                device_display_name=content['device_display_name'],
 | 
			
		||||
                pushkey=content['pushkey'],
 | 
			
		||||
                lang=content['lang'],
 | 
			
		||||
                data=content['data']
 | 
			
		||||
            )
 | 
			
		||||
        except PusherConfigException as pce:
 | 
			
		||||
            raise SynapseError(400, "Config Error: "+pce.message,
 | 
			
		||||
                               errcode=Codes.MISSING_PARAM)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((200, {}))
 | 
			
		||||
 | 
			
		||||
    def on_OPTIONS(self, _):
 | 
			
		||||
        return 200, {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# XXX: C+ped from rest/room.py - surely this should be common?
 | 
			
		||||
def _parse_json(request):
 | 
			
		||||
    try:
 | 
			
		||||
        content = json.loads(request.content.read())
 | 
			
		||||
        if type(content) != dict:
 | 
			
		||||
            raise SynapseError(400, "Content must be a JSON object.",
 | 
			
		||||
                               errcode=Codes.NOT_JSON)
 | 
			
		||||
        return content
 | 
			
		||||
    except ValueError:
 | 
			
		||||
        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_servlets(hs, http_server):
 | 
			
		||||
    PusherRestServlet(hs).register(http_server)
 | 
			
		||||
| 
						 | 
				
			
			@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        room_config = self.get_room_config(request)
 | 
			
		||||
        info = yield self.make_room(room_config, auth_user, None)
 | 
			
		||||
| 
						 | 
				
			
			@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, room_id, event_type, state_key):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        msg_handler = self.handlers.message_handler
 | 
			
		||||
        data = yield msg_handler.get_room_data(
 | 
			
		||||
| 
						 | 
				
			
			@ -142,8 +142,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
        defer.returnValue((200, data.get_dict()["content"]))
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_PUT(self, request, room_id, event_type, state_key):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
    def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        content = _parse_json(request)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -158,7 +158,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
            event_dict["state_key"] = state_key
 | 
			
		||||
 | 
			
		||||
        msg_handler = self.handlers.message_handler
 | 
			
		||||
        yield msg_handler.create_and_send_event(event_dict)
 | 
			
		||||
        yield msg_handler.create_and_send_event(
 | 
			
		||||
            event_dict, client=client, txn_id=txn_id,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((200, {}))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -172,8 +174,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
        register_txn_path(self, PATTERN, http_server, with_get=True)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request, room_id, event_type):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
    def on_POST(self, request, room_id, event_type, txn_id=None):
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        content = _parse_json(request)
 | 
			
		||||
 | 
			
		||||
        msg_handler = self.handlers.message_handler
 | 
			
		||||
| 
						 | 
				
			
			@ -183,7 +185,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
                "content": content,
 | 
			
		||||
                "room_id": room_id,
 | 
			
		||||
                "sender": user.to_string(),
 | 
			
		||||
            }
 | 
			
		||||
            },
 | 
			
		||||
            client=client,
 | 
			
		||||
            txn_id=txn_id,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((200, {"event_id": event.event_id}))
 | 
			
		||||
| 
						 | 
				
			
			@ -200,7 +204,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
        except KeyError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        response = yield self.on_POST(request, room_id, event_type)
 | 
			
		||||
        response = yield self.on_POST(request, room_id, event_type, txn_id)
 | 
			
		||||
 | 
			
		||||
        self.txns.store_client_transaction(request, txn_id, response)
 | 
			
		||||
        defer.returnValue(response)
 | 
			
		||||
| 
						 | 
				
			
			@ -215,8 +219,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
 | 
			
		|||
        register_txn_path(self, PATTERN, http_server)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request, room_identifier):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
    def on_POST(self, request, room_identifier, txn_id=None):
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        # the identifier could be a room alias or a room id. Try one then the
 | 
			
		||||
        # other if it fails to parse, without swallowing other valid
 | 
			
		||||
| 
						 | 
				
			
			@ -245,7 +249,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
 | 
			
		|||
                    "room_id": identifier.to_string(),
 | 
			
		||||
                    "sender": user.to_string(),
 | 
			
		||||
                    "state_key": user.to_string(),
 | 
			
		||||
                }
 | 
			
		||||
                },
 | 
			
		||||
                client=client,
 | 
			
		||||
                txn_id=txn_id,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            defer.returnValue((200, {"room_id": identifier.to_string()}))
 | 
			
		||||
| 
						 | 
				
			
			@ -259,7 +265,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
 | 
			
		|||
        except KeyError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        response = yield self.on_POST(request, room_identifier)
 | 
			
		||||
        response = yield self.on_POST(request, room_identifier, txn_id)
 | 
			
		||||
 | 
			
		||||
        self.txns.store_client_transaction(request, txn_id, response)
 | 
			
		||||
        defer.returnValue(response)
 | 
			
		||||
| 
						 | 
				
			
			@ -283,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
 | 
			
		|||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, room_id):
 | 
			
		||||
        # TODO support Pagination stream API (limit/tokens)
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        handler = self.handlers.room_member_handler
 | 
			
		||||
        members = yield handler.get_room_members_as_pagination_chunk(
 | 
			
		||||
            room_id=room_id,
 | 
			
		||||
| 
						 | 
				
			
			@ -311,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, room_id):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        pagination_config = PaginationConfig.from_request(
 | 
			
		||||
            request, default_limit=10,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			@ -335,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, room_id):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        handler = self.handlers.message_handler
 | 
			
		||||
        # Get all the current state for this room
 | 
			
		||||
        events = yield handler.get_state_events(
 | 
			
		||||
| 
						 | 
				
			
			@ -351,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request, room_id):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        pagination_config = PaginationConfig.from_request(request)
 | 
			
		||||
        content = yield self.handlers.message_handler.room_initial_sync(
 | 
			
		||||
            room_id=room_id,
 | 
			
		||||
| 
						 | 
				
			
			@ -395,8 +401,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
 | 
			
		|||
        register_txn_path(self, PATTERN, http_server)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request, room_id, membership_action):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
    def on_POST(self, request, room_id, membership_action, txn_id=None):
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        content = _parse_json(request)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -418,7 +424,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
 | 
			
		|||
                "room_id": room_id,
 | 
			
		||||
                "sender": user.to_string(),
 | 
			
		||||
                "state_key": state_key,
 | 
			
		||||
            }
 | 
			
		||||
            },
 | 
			
		||||
            client=client,
 | 
			
		||||
            txn_id=txn_id,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((200, {}))
 | 
			
		||||
| 
						 | 
				
			
			@ -432,7 +440,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
 | 
			
		|||
        except KeyError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        response = yield self.on_POST(request, room_id, membership_action)
 | 
			
		||||
        response = yield self.on_POST(
 | 
			
		||||
            request, room_id, membership_action, txn_id
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.txns.store_client_transaction(request, txn_id, response)
 | 
			
		||||
        defer.returnValue(response)
 | 
			
		||||
| 
						 | 
				
			
			@ -444,8 +454,8 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
        register_txn_path(self, PATTERN, http_server)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_POST(self, request, room_id, event_id):
 | 
			
		||||
        user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
    def on_POST(self, request, room_id, event_id, txn_id=None):
 | 
			
		||||
        user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        content = _parse_json(request)
 | 
			
		||||
 | 
			
		||||
        msg_handler = self.handlers.message_handler
 | 
			
		||||
| 
						 | 
				
			
			@ -456,7 +466,9 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
                "room_id": room_id,
 | 
			
		||||
                "sender": user.to_string(),
 | 
			
		||||
                "redacts": event_id,
 | 
			
		||||
            }
 | 
			
		||||
            },
 | 
			
		||||
            client=client,
 | 
			
		||||
            txn_id=txn_id,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((200, {"event_id": event.event_id}))
 | 
			
		||||
| 
						 | 
				
			
			@ -470,7 +482,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
 | 
			
		|||
        except KeyError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        response = yield self.on_POST(request, room_id, event_id)
 | 
			
		||||
        response = yield self.on_POST(request, room_id, event_id, txn_id)
 | 
			
		||||
 | 
			
		||||
        self.txns.store_client_transaction(request, txn_id, response)
 | 
			
		||||
        defer.returnValue(response)
 | 
			
		||||
| 
						 | 
				
			
			@ -483,7 +495,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_PUT(self, request, room_id, user_id):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        room_id = urllib.unquote(room_id)
 | 
			
		||||
        target_user = UserID.from_string(urllib.unquote(user_id))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
 | 
			
		|||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def on_GET(self, request):
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        turnUris = self.hs.config.turn_uris
 | 
			
		||||
        turnSecret = self.hs.config.turn_shared_secret
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
 | 
			
		|||
    @defer.inlineCallbacks
 | 
			
		||||
    def map_request_to_name(self, request):
 | 
			
		||||
        # auth the user
 | 
			
		||||
        auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
        auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
 | 
			
		||||
        # namespace all file uploads on the user
 | 
			
		||||
        prefix = base64.urlsafe_b64encode(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,7 +42,7 @@ class UploadResource(BaseMediaResource):
 | 
			
		|||
    @defer.inlineCallbacks
 | 
			
		||||
    def _async_render_POST(self, request):
 | 
			
		||||
        try:
 | 
			
		||||
            auth_user = yield self.auth.get_user_by_req(request)
 | 
			
		||||
            auth_user, client = yield self.auth.get_user_by_req(request)
 | 
			
		||||
            # TODO: The checks here are a bit late. The content will have
 | 
			
		||||
            # already been uploaded to a tmp file at this point
 | 
			
		||||
            content_length = request.getHeader("Content-Length")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,6 +31,7 @@ from synapse.util.lockutils import LockManager
 | 
			
		|||
from synapse.streams.events import EventSources
 | 
			
		||||
from synapse.api.ratelimiting import Ratelimiter
 | 
			
		||||
from synapse.crypto.keyring import Keyring
 | 
			
		||||
from synapse.push.pusherpool import PusherPool
 | 
			
		||||
from synapse.events.builder import EventBuilderFactory
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -78,6 +79,7 @@ class BaseHomeServer(object):
 | 
			
		|||
        'event_sources',
 | 
			
		||||
        'ratelimiter',
 | 
			
		||||
        'keyring',
 | 
			
		||||
        'pusherpool',
 | 
			
		||||
        'event_builder_factory',
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -197,3 +199,6 @@ class HomeServer(BaseHomeServer):
 | 
			
		|||
            clock=self.get_clock(),
 | 
			
		||||
            hostname=self.hostname,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def build_pusherpool(self):
 | 
			
		||||
        return PusherPool(self)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,6 +29,8 @@ from .stream import StreamStore
 | 
			
		|||
from .transactions import TransactionStore
 | 
			
		||||
from .keys import KeyStore
 | 
			
		||||
from .event_federation import EventFederationStore
 | 
			
		||||
from .pusher import PusherStore
 | 
			
		||||
from .push_rule import PushRuleStore
 | 
			
		||||
from .media_repository import MediaRepositoryStore
 | 
			
		||||
 | 
			
		||||
from .state import StateStore
 | 
			
		||||
| 
						 | 
				
			
			@ -60,13 +62,14 @@ SCHEMAS = [
 | 
			
		|||
    "state",
 | 
			
		||||
    "event_edges",
 | 
			
		||||
    "event_signatures",
 | 
			
		||||
    "pusher",
 | 
			
		||||
    "media_repository",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Remember to update this number every time an incompatible change is made to
 | 
			
		||||
# database schema files, so the users will be informed on server restarts.
 | 
			
		||||
SCHEMA_VERSION = 11
 | 
			
		||||
SCHEMA_VERSION = 12
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _RollbackButIsFineException(Exception):
 | 
			
		||||
| 
						 | 
				
			
			@ -82,6 +85,8 @@ class DataStore(RoomMemberStore, RoomStore,
 | 
			
		|||
                DirectoryStore, KeyStore, StateStore, SignatureStore,
 | 
			
		||||
                EventFederationStore,
 | 
			
		||||
                MediaRepositoryStore,
 | 
			
		||||
                PusherStore,
 | 
			
		||||
                PushRuleStore
 | 
			
		||||
                ):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs):
 | 
			
		||||
| 
						 | 
				
			
			@ -381,6 +386,41 @@ class DataStore(RoomMemberStore, RoomStore,
 | 
			
		|||
        events = yield self._parse_events(results)
 | 
			
		||||
        defer.returnValue(events)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_room_name_and_aliases(self, room_id):
 | 
			
		||||
        del_sql = (
 | 
			
		||||
            "SELECT event_id FROM redactions WHERE redacts = e.event_id "
 | 
			
		||||
            "LIMIT 1"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
 | 
			
		||||
            "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
 | 
			
		||||
            "INNER JOIN state_events as s ON e.event_id = s.event_id "
 | 
			
		||||
            "WHERE c.room_id = ? "
 | 
			
		||||
        ) % {
 | 
			
		||||
            "redacted": del_sql,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
 | 
			
		||||
        sql += " OR s.type = 'm.room.aliases')"
 | 
			
		||||
        args = (room_id,)
 | 
			
		||||
 | 
			
		||||
        results = yield self._execute_and_decode(sql, *args)
 | 
			
		||||
 | 
			
		||||
        events = yield self._parse_events(results)
 | 
			
		||||
 | 
			
		||||
        name = None
 | 
			
		||||
        aliases = []
 | 
			
		||||
 | 
			
		||||
        for e in events:
 | 
			
		||||
            if e.type == 'm.room.name':
 | 
			
		||||
                name = e.content['name']
 | 
			
		||||
            elif e.type == 'm.room.aliases':
 | 
			
		||||
                aliases.extend(e.content['aliases'])
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((name, aliases))
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _get_min_token(self):
 | 
			
		||||
        row = yield self._execute(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -193,6 +193,50 @@ class SQLBaseStore(object):
 | 
			
		|||
        txn.execute(sql, values.values())
 | 
			
		||||
        return txn.lastrowid
 | 
			
		||||
 | 
			
		||||
    def _simple_upsert(self, table, keyvalues, values):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            table (str): The table to upsert into
 | 
			
		||||
            keyvalues (dict): The unique key tables and their new values
 | 
			
		||||
            values (dict): The nonunique columns and their new values
 | 
			
		||||
        Returns: A deferred
 | 
			
		||||
        """
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
            "_simple_upsert",
 | 
			
		||||
            self._simple_upsert_txn, table, keyvalues, values
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _simple_upsert_txn(self, txn, table, keyvalues, values):
 | 
			
		||||
        # Try to update
 | 
			
		||||
        sql = "UPDATE %s SET %s WHERE %s" % (
 | 
			
		||||
            table,
 | 
			
		||||
            ", ".join("%s = ?" % (k,) for k in values),
 | 
			
		||||
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
 | 
			
		||||
        )
 | 
			
		||||
        sqlargs = values.values() + keyvalues.values()
 | 
			
		||||
        logger.debug(
 | 
			
		||||
            "[SQL] %s Args=%s",
 | 
			
		||||
            sql, sqlargs,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        txn.execute(sql, sqlargs)
 | 
			
		||||
        if txn.rowcount == 0:
 | 
			
		||||
            # We didn't update and rows so insert a new one
 | 
			
		||||
            allvalues = {}
 | 
			
		||||
            allvalues.update(keyvalues)
 | 
			
		||||
            allvalues.update(values)
 | 
			
		||||
 | 
			
		||||
            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
 | 
			
		||||
                table,
 | 
			
		||||
                ", ".join(k for k in allvalues),
 | 
			
		||||
                ", ".join("?" for _ in allvalues)
 | 
			
		||||
            )
 | 
			
		||||
            logger.debug(
 | 
			
		||||
                "[SQL] %s Args=%s",
 | 
			
		||||
                sql, keyvalues.values(),
 | 
			
		||||
            )
 | 
			
		||||
            txn.execute(sql, allvalues.values())
 | 
			
		||||
 | 
			
		||||
    def _simple_select_one(self, table, keyvalues, retcols,
 | 
			
		||||
                           allow_none=False):
 | 
			
		||||
        """Executes a SELECT query on the named table, which is expected to
 | 
			
		||||
| 
						 | 
				
			
			@ -344,8 +388,8 @@ class SQLBaseStore(object):
 | 
			
		|||
        if updatevalues:
 | 
			
		||||
            update_sql = "UPDATE %s SET %s WHERE %s" % (
 | 
			
		||||
                table,
 | 
			
		||||
                ", ".join("%s = ?" % (k) for k in updatevalues),
 | 
			
		||||
                " AND ".join("%s = ?" % (k) for k in keyvalues)
 | 
			
		||||
                ", ".join("%s = ?" % (k,) for k in updatevalues),
 | 
			
		||||
                " AND ".join("%s = ?" % (k,) for k in keyvalues)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        def func(txn):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,209 @@
 | 
			
		|||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright 2014 OpenMarket Ltd
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import collections
 | 
			
		||||
 | 
			
		||||
from ._base import SQLBaseStore, Table
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import copy
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PushRuleStore(SQLBaseStore):
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_push_rules_for_user_name(self, user_name):
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT "+",".join(PushRuleTable.fields)+" "
 | 
			
		||||
            "FROM "+PushRuleTable.table_name+" "
 | 
			
		||||
            "WHERE user_name = ? "
 | 
			
		||||
            "ORDER BY priority_class DESC, priority DESC"
 | 
			
		||||
        )
 | 
			
		||||
        rows = yield self._execute(None, sql, user_name)
 | 
			
		||||
 | 
			
		||||
        dicts = []
 | 
			
		||||
        for r in rows:
 | 
			
		||||
            d = {}
 | 
			
		||||
            for i, f in enumerate(PushRuleTable.fields):
 | 
			
		||||
                d[f] = r[i]
 | 
			
		||||
            dicts.append(d)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(dicts)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def add_push_rule(self, before, after, **kwargs):
 | 
			
		||||
        vals = copy.copy(kwargs)
 | 
			
		||||
        if 'conditions' in vals:
 | 
			
		||||
            vals['conditions'] = json.dumps(vals['conditions'])
 | 
			
		||||
        if 'actions' in vals:
 | 
			
		||||
            vals['actions'] = json.dumps(vals['actions'])
 | 
			
		||||
        # we could check the rest of the keys are valid column names
 | 
			
		||||
        # but sqlite will do that anyway so I think it's just pointless.
 | 
			
		||||
        if 'id' in vals:
 | 
			
		||||
            del vals['id']
 | 
			
		||||
 | 
			
		||||
        if before or after:
 | 
			
		||||
            ret = yield self.runInteraction(
 | 
			
		||||
                "_add_push_rule_relative_txn",
 | 
			
		||||
                self._add_push_rule_relative_txn,
 | 
			
		||||
                before=before,
 | 
			
		||||
                after=after,
 | 
			
		||||
                **vals
 | 
			
		||||
            )
 | 
			
		||||
            defer.returnValue(ret)
 | 
			
		||||
        else:
 | 
			
		||||
            ret = yield self.runInteraction(
 | 
			
		||||
                "_add_push_rule_highest_priority_txn",
 | 
			
		||||
                self._add_push_rule_highest_priority_txn,
 | 
			
		||||
                **vals
 | 
			
		||||
            )
 | 
			
		||||
            defer.returnValue(ret)
 | 
			
		||||
 | 
			
		||||
    def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
 | 
			
		||||
        after = None
 | 
			
		||||
        relative_to_rule = None
 | 
			
		||||
        if 'after' in kwargs and kwargs['after']:
 | 
			
		||||
            after = kwargs['after']
 | 
			
		||||
            relative_to_rule = after
 | 
			
		||||
        if 'before' in kwargs and kwargs['before']:
 | 
			
		||||
            relative_to_rule = kwargs['before']
 | 
			
		||||
 | 
			
		||||
        # get the priority of the rule we're inserting after/before
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT priority_class, priority FROM ? "
 | 
			
		||||
            "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
 | 
			
		||||
        )
 | 
			
		||||
        txn.execute(sql, (user_name, relative_to_rule))
 | 
			
		||||
        res = txn.fetchall()
 | 
			
		||||
        if not res:
 | 
			
		||||
            raise RuleNotFoundException("before/after rule not found: %s" % (relative_to_rule))
 | 
			
		||||
        priority_class, base_rule_priority = res[0]
 | 
			
		||||
 | 
			
		||||
        if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
 | 
			
		||||
            raise InconsistentRuleException(
 | 
			
		||||
                "Given priority class does not match class of relative rule"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        new_rule = copy.copy(kwargs)
 | 
			
		||||
        if 'before' in new_rule:
 | 
			
		||||
            del new_rule['before']
 | 
			
		||||
        if 'after' in new_rule:
 | 
			
		||||
            del new_rule['after']
 | 
			
		||||
        new_rule['priority_class'] = priority_class
 | 
			
		||||
        new_rule['user_name'] = user_name
 | 
			
		||||
 | 
			
		||||
        # check if the priority before/after is free
 | 
			
		||||
        new_rule_priority = base_rule_priority
 | 
			
		||||
        if after:
 | 
			
		||||
            new_rule_priority -= 1
 | 
			
		||||
        else:
 | 
			
		||||
            new_rule_priority += 1
 | 
			
		||||
 | 
			
		||||
        new_rule['priority'] = new_rule_priority
 | 
			
		||||
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT COUNT(*) FROM "+PushRuleTable.table_name+
 | 
			
		||||
            " WHERE user_name = ? AND priority_class = ? AND priority = ?"
 | 
			
		||||
        )
 | 
			
		||||
        txn.execute(sql, (user_name, priority_class, new_rule_priority))
 | 
			
		||||
        res = txn.fetchall()
 | 
			
		||||
        num_conflicting = res[0][0]
 | 
			
		||||
 | 
			
		||||
        # if there are conflicting rules, bump everything
 | 
			
		||||
        if num_conflicting:
 | 
			
		||||
            sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority "
 | 
			
		||||
            if after:
 | 
			
		||||
                sql += "-1"
 | 
			
		||||
            else:
 | 
			
		||||
                sql += "+1"
 | 
			
		||||
            sql += " WHERE user_name = ? AND priority_class = ? AND priority "
 | 
			
		||||
            if after:
 | 
			
		||||
                sql += "<= ?"
 | 
			
		||||
            else:
 | 
			
		||||
                sql += ">= ?"
 | 
			
		||||
 | 
			
		||||
            txn.execute(sql, (user_name, priority_class, new_rule_priority))
 | 
			
		||||
 | 
			
		||||
        # now insert the new rule
 | 
			
		||||
        sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
 | 
			
		||||
        sql += ",".join(new_rule.keys())+") VALUES ("
 | 
			
		||||
        sql += ", ".join(["?" for _ in new_rule.keys()])+")"
 | 
			
		||||
 | 
			
		||||
        txn.execute(sql, new_rule.values())
 | 
			
		||||
 | 
			
		||||
    def _add_push_rule_highest_priority_txn(self, txn, user_name, priority_class, **kwargs):
 | 
			
		||||
        # find the highest priority rule in that class
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT COUNT(*), MAX(priority) FROM "+PushRuleTable.table_name+
 | 
			
		||||
            " WHERE user_name = ? and priority_class = ?"
 | 
			
		||||
        )
 | 
			
		||||
        txn.execute(sql, (user_name, priority_class))
 | 
			
		||||
        res = txn.fetchall()
 | 
			
		||||
        (how_many, highest_prio) = res[0]
 | 
			
		||||
 | 
			
		||||
        new_prio = 0
 | 
			
		||||
        if how_many > 0:
 | 
			
		||||
            new_prio = highest_prio + 1
 | 
			
		||||
 | 
			
		||||
        # and insert the new rule
 | 
			
		||||
        new_rule = copy.copy(kwargs)
 | 
			
		||||
        if 'id' in new_rule:
 | 
			
		||||
            del new_rule['id']
 | 
			
		||||
        new_rule['user_name'] = user_name
 | 
			
		||||
        new_rule['priority_class'] = priority_class
 | 
			
		||||
        new_rule['priority'] = new_prio
 | 
			
		||||
 | 
			
		||||
        sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
 | 
			
		||||
        sql += ",".join(new_rule.keys())+") VALUES ("
 | 
			
		||||
        sql += ", ".join(["?" for _ in new_rule.keys()])+")"
 | 
			
		||||
 | 
			
		||||
        txn.execute(sql, new_rule.values())
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def delete_push_rule(self, user_name, rule_id):
 | 
			
		||||
        yield self._simple_delete_one(
 | 
			
		||||
            PushRuleTable.table_name,
 | 
			
		||||
            {
 | 
			
		||||
                'user_name': user_name,
 | 
			
		||||
                'rule_id': rule_id
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RuleNotFoundException(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InconsistentRuleException(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PushRuleTable(Table):
 | 
			
		||||
    table_name = "push_rules"
 | 
			
		||||
 | 
			
		||||
    fields = [
 | 
			
		||||
        "id",
 | 
			
		||||
        "user_name",
 | 
			
		||||
        "rule_id",
 | 
			
		||||
        "priority_class",
 | 
			
		||||
        "priority",
 | 
			
		||||
        "conditions",
 | 
			
		||||
        "actions",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    EntryType = collections.namedtuple("PushRuleEntry", fields)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,173 @@
 | 
			
		|||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright 2014 OpenMarket Ltd
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import collections
 | 
			
		||||
 | 
			
		||||
from ._base import SQLBaseStore, Table
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from synapse.api.errors import StoreError
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PusherStore(SQLBaseStore):
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey):
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT id, user_name, kind, instance_handle, app_id,"
 | 
			
		||||
            "app_display_name, device_display_name, pushkey, ts, data, "
 | 
			
		||||
            "last_token, last_success, failing_since "
 | 
			
		||||
            "FROM pushers "
 | 
			
		||||
            "WHERE app_id = ? AND pushkey = ?"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        rows = yield self._execute(
 | 
			
		||||
            None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ret = [
 | 
			
		||||
            {
 | 
			
		||||
                "id": r[0],
 | 
			
		||||
                "user_name": r[1],
 | 
			
		||||
                "kind": r[2],
 | 
			
		||||
                "instance_handle": r[3],
 | 
			
		||||
                "app_id": r[4],
 | 
			
		||||
                "app_display_name": r[5],
 | 
			
		||||
                "device_display_name": r[6],
 | 
			
		||||
                "pushkey": r[7],
 | 
			
		||||
                "pushkey_ts": r[8],
 | 
			
		||||
                "data": r[9],
 | 
			
		||||
                "last_token": r[10],
 | 
			
		||||
                "last_success": r[11],
 | 
			
		||||
                "failing_since": r[12]
 | 
			
		||||
            }
 | 
			
		||||
            for r in rows
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(ret[0])
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_all_pushers(self):
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT id, user_name, kind, instance_handle, app_id,"
 | 
			
		||||
            "app_display_name, device_display_name, pushkey, ts, data, "
 | 
			
		||||
            "last_token, last_success, failing_since "
 | 
			
		||||
            "FROM pushers"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        rows = yield self._execute(None, sql)
 | 
			
		||||
 | 
			
		||||
        ret = [
 | 
			
		||||
            {
 | 
			
		||||
                "id": r[0],
 | 
			
		||||
                "user_name": r[1],
 | 
			
		||||
                "kind": r[2],
 | 
			
		||||
                "instance_handle": r[3],
 | 
			
		||||
                "app_id": r[4],
 | 
			
		||||
                "app_display_name": r[5],
 | 
			
		||||
                "device_display_name": r[6],
 | 
			
		||||
                "pushkey": r[7],
 | 
			
		||||
                "pushkey_ts": r[8],
 | 
			
		||||
                "data": r[9],
 | 
			
		||||
                "last_token": r[10],
 | 
			
		||||
                "last_success": r[11],
 | 
			
		||||
                "failing_since": r[12]
 | 
			
		||||
            }
 | 
			
		||||
            for r in rows
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(ret)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def add_pusher(self, user_name, instance_handle, kind, app_id,
 | 
			
		||||
                   app_display_name, device_display_name,
 | 
			
		||||
                   pushkey, pushkey_ts, lang, data):
 | 
			
		||||
        try:
 | 
			
		||||
            yield self._simple_upsert(
 | 
			
		||||
                PushersTable.table_name,
 | 
			
		||||
                dict(
 | 
			
		||||
                    app_id=app_id,
 | 
			
		||||
                    pushkey=pushkey,
 | 
			
		||||
                ),
 | 
			
		||||
                dict(
 | 
			
		||||
                    user_name=user_name,
 | 
			
		||||
                    kind=kind,
 | 
			
		||||
                    instance_handle=instance_handle,
 | 
			
		||||
                    app_display_name=app_display_name,
 | 
			
		||||
                    device_display_name=device_display_name,
 | 
			
		||||
                    ts=pushkey_ts,
 | 
			
		||||
                    lang=lang,
 | 
			
		||||
                    data=data
 | 
			
		||||
                ))
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error("create_pusher with failed: %s", e)
 | 
			
		||||
            raise StoreError(500, "Problem creating pusher.")
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
 | 
			
		||||
        yield self._simple_delete_one(
 | 
			
		||||
            PushersTable.table_name,
 | 
			
		||||
            dict(app_id=app_id, pushkey=pushkey)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def update_pusher_last_token(self, user_name, pushkey, last_token):
 | 
			
		||||
        yield self._simple_update_one(
 | 
			
		||||
            PushersTable.table_name,
 | 
			
		||||
            {'user_name': user_name, 'pushkey': pushkey},
 | 
			
		||||
            {'last_token': last_token}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def update_pusher_last_token_and_success(self, user_name, pushkey,
 | 
			
		||||
                                             last_token, last_success):
 | 
			
		||||
        yield self._simple_update_one(
 | 
			
		||||
            PushersTable.table_name,
 | 
			
		||||
            {'user_name': user_name, 'pushkey': pushkey},
 | 
			
		||||
            {'last_token': last_token, 'last_success': last_success}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def update_pusher_failing_since(self, user_name, pushkey, failing_since):
 | 
			
		||||
        yield self._simple_update_one(
 | 
			
		||||
            PushersTable.table_name,
 | 
			
		||||
            {'user_name': user_name, 'pushkey': pushkey},
 | 
			
		||||
            {'failing_since': failing_since}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PushersTable(Table):
 | 
			
		||||
    table_name = "pushers"
 | 
			
		||||
 | 
			
		||||
    fields = [
 | 
			
		||||
        "id",
 | 
			
		||||
        "user_name",
 | 
			
		||||
        "kind",
 | 
			
		||||
        "instance_handle",
 | 
			
		||||
        "app_id",
 | 
			
		||||
        "app_display_name",
 | 
			
		||||
        "device_display_name",
 | 
			
		||||
        "pushkey",
 | 
			
		||||
        "pushkey_ts",
 | 
			
		||||
        "data",
 | 
			
		||||
        "last_token",
 | 
			
		||||
        "last_success",
 | 
			
		||||
        "failing_since"
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    EntryType = collections.namedtuple("PusherEntry", fields)
 | 
			
		||||
| 
						 | 
				
			
			@ -122,7 +122,8 @@ class RegistrationStore(SQLBaseStore):
 | 
			
		|||
 | 
			
		||||
    def _query_for_auth(self, txn, token):
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT users.name, users.admin, access_tokens.device_id"
 | 
			
		||||
            "SELECT users.name, users.admin,"
 | 
			
		||||
            " access_tokens.device_id, access_tokens.id as token_id"
 | 
			
		||||
            " FROM users"
 | 
			
		||||
            " INNER JOIN access_tokens on users.id = access_tokens.user_id"
 | 
			
		||||
            " WHERE token = ?"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,46 @@
 | 
			
		|||
/* Copyright 2014 OpenMarket Ltd
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
-- Push notification endpoints that users have configured
 | 
			
		||||
CREATE TABLE IF NOT EXISTS pushers (
 | 
			
		||||
  id INTEGER PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
  user_name TEXT NOT NULL,
 | 
			
		||||
  instance_handle varchar(32) NOT NULL,
 | 
			
		||||
  kind varchar(8) NOT NULL,
 | 
			
		||||
  app_id varchar(64) NOT NULL,
 | 
			
		||||
  app_display_name varchar(64) NOT NULL,
 | 
			
		||||
  device_display_name varchar(128) NOT NULL,
 | 
			
		||||
  pushkey blob NOT NULL,
 | 
			
		||||
  ts BIGINT NOT NULL,
 | 
			
		||||
  lang varchar(8),
 | 
			
		||||
  data blob,
 | 
			
		||||
  last_token TEXT,
 | 
			
		||||
  last_success BIGINT,
 | 
			
		||||
  failing_since BIGINT,
 | 
			
		||||
  FOREIGN KEY(user_name) REFERENCES users(name),
 | 
			
		||||
  UNIQUE (app_id, pushkey)
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
CREATE TABLE IF NOT EXISTS push_rules (
 | 
			
		||||
  id INTEGER PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
  user_name TEXT NOT NULL,
 | 
			
		||||
  rule_id TEXT NOT NULL,
 | 
			
		||||
  priority_class TINYINT NOT NULL,
 | 
			
		||||
  priority INTEGER NOT NULL DEFAULT 0,
 | 
			
		||||
  conditions TEXT NOT NULL,
 | 
			
		||||
  actions TEXT NOT NULL,
 | 
			
		||||
  UNIQUE(user_name, rule_id)
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,46 @@
 | 
			
		|||
/* Copyright 2014 OpenMarket Ltd
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
-- Push notification endpoints that users have configured
 | 
			
		||||
CREATE TABLE IF NOT EXISTS pushers (
 | 
			
		||||
  id INTEGER PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
  user_name TEXT NOT NULL,
 | 
			
		||||
  instance_handle varchar(32) NOT NULL,
 | 
			
		||||
  kind varchar(8) NOT NULL,
 | 
			
		||||
  app_id varchar(64) NOT NULL,
 | 
			
		||||
  app_display_name varchar(64) NOT NULL,
 | 
			
		||||
  device_display_name varchar(128) NOT NULL,
 | 
			
		||||
  pushkey blob NOT NULL,
 | 
			
		||||
  ts BIGINT NOT NULL,
 | 
			
		||||
  lang varchar(8),
 | 
			
		||||
  data blob,
 | 
			
		||||
  last_token TEXT,
 | 
			
		||||
  last_success BIGINT,
 | 
			
		||||
  failing_since BIGINT,
 | 
			
		||||
  FOREIGN KEY(user_name) REFERENCES users(name),
 | 
			
		||||
  UNIQUE (app_id, pushkey)
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
CREATE TABLE IF NOT EXISTS push_rules (
 | 
			
		||||
  id INTEGER PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
  user_name TEXT NOT NULL,
 | 
			
		||||
  rule_id TEXT NOT NULL,
 | 
			
		||||
  priority_class TINYINT NOT NULL,
 | 
			
		||||
  priority INTEGER NOT NULL DEFAULT 0,
 | 
			
		||||
  conditions TEXT NOT NULL,
 | 
			
		||||
  actions TEXT NOT NULL,
 | 
			
		||||
  UNIQUE(user_name, rule_id)
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);
 | 
			
		||||
| 
						 | 
				
			
			@ -119,3 +119,6 @@ class StreamToken(
 | 
			
		|||
        d = self._asdict()
 | 
			
		||||
        d[key] = new_value
 | 
			
		||||
        return StreamToken(**d)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -75,6 +75,7 @@ class PresenceStateTestCase(unittest.TestCase):
 | 
			
		|||
                "user": UserID.from_string(myid),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
| 
						 | 
				
			
			@ -165,6 +166,7 @@ class PresenceListTestCase(unittest.TestCase):
 | 
			
		|||
                "user": UserID.from_string(myid),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        hs.handlers.room_member_handler = Mock(
 | 
			
		||||
| 
						 | 
				
			
			@ -282,7 +284,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
 | 
			
		|||
        hs.get_clock().time_msec.return_value = 1000000
 | 
			
		||||
 | 
			
		||||
        def _get_user_by_req(req=None):
 | 
			
		||||
            return UserID.from_string(myid)
 | 
			
		||||
            return (UserID.from_string(myid), "")
 | 
			
		||||
 | 
			
		||||
        hs.get_auth().get_user_by_req = _get_user_by_req
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -58,7 +58,7 @@ class ProfileTestCase(unittest.TestCase):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
        def _get_user_by_req(request=None):
 | 
			
		||||
            return UserID.from_string(myid)
 | 
			
		||||
            return (UserID.from_string(myid), "")
 | 
			
		||||
 | 
			
		||||
        hs.get_auth().get_user_by_req = _get_user_by_req
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -70,6 +70,7 @@ class RoomPermissionsTestCase(RestTestCase):
 | 
			
		|||
                "user": UserID.from_string(self.auth_user_id),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -466,6 +467,7 @@ class RoomsMemberListTestCase(RestTestCase):
 | 
			
		|||
                "user": UserID.from_string(self.auth_user_id),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -555,6 +557,7 @@ class RoomsCreateTestCase(RestTestCase):
 | 
			
		|||
                "user": UserID.from_string(self.auth_user_id),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -657,6 +660,7 @@ class RoomTopicTestCase(RestTestCase):
 | 
			
		|||
                "user": UserID.from_string(self.auth_user_id),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
| 
						 | 
				
			
			@ -773,6 +777,7 @@ class RoomMemberStateTestCase(RestTestCase):
 | 
			
		|||
                "user": UserID.from_string(self.auth_user_id),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -909,6 +914,7 @@ class RoomMessagesTestCase(RestTestCase):
 | 
			
		|||
                "user": UserID.from_string(self.auth_user_id),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1013,6 +1019,7 @@ class RoomInitialSyncTestCase(RestTestCase):
 | 
			
		|||
                "user": UserID.from_string(self.auth_user_id),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -73,6 +73,7 @@ class RoomTypingTestCase(RestTestCase):
 | 
			
		|||
                "user": UserID.from_string(self.auth_user_id),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
                "token_id": 1,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,3 +13,48 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from tests import unittest
 | 
			
		||||
 | 
			
		||||
from mock import Mock
 | 
			
		||||
 | 
			
		||||
from ....utils import MockHttpResource, MockKey
 | 
			
		||||
 | 
			
		||||
from synapse.server import HomeServer
 | 
			
		||||
from synapse.types import UserID
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class V2AlphaRestTestCase(unittest.TestCase):
 | 
			
		||||
    # Consumer must define
 | 
			
		||||
    #   USER_ID = <some string>
 | 
			
		||||
    #   TO_REGISTER = [<list of REST servlets to register>]
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
 | 
			
		||||
 | 
			
		||||
        mock_config = Mock()
 | 
			
		||||
        mock_config.signing_key = [MockKey()]
 | 
			
		||||
 | 
			
		||||
        hs = HomeServer("test",
 | 
			
		||||
            db_pool=None,
 | 
			
		||||
            datastore=Mock(spec=[
 | 
			
		||||
                "insert_client_ip",
 | 
			
		||||
            ]),
 | 
			
		||||
            http_client=None,
 | 
			
		||||
            resource_for_client=self.mock_resource,
 | 
			
		||||
            resource_for_federation=self.mock_resource,
 | 
			
		||||
            config=mock_config,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        def _get_user_by_token(token=None):
 | 
			
		||||
            return {
 | 
			
		||||
                "user": UserID.from_string(self.USER_ID),
 | 
			
		||||
                "admin": False,
 | 
			
		||||
                "device_id": None,
 | 
			
		||||
            }
 | 
			
		||||
        hs.get_auth().get_user_by_token = _get_user_by_token
 | 
			
		||||
 | 
			
		||||
        for r in self.TO_REGISTER:
 | 
			
		||||
            r.register_servlets(hs, self.mock_resource)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -53,7 +53,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertEquals(
 | 
			
		||||
            {"admin": 0, "device_id": None, "name": self.user_id},
 | 
			
		||||
            {"admin": 0,
 | 
			
		||||
             "device_id": None,
 | 
			
		||||
             "name": self.user_id,
 | 
			
		||||
             "token_id": 1},
 | 
			
		||||
            (yield self.store.get_user_by_token(self.tokens[0]))
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -63,7 +66,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
 | 
			
		|||
        yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
 | 
			
		||||
 | 
			
		||||
        self.assertEquals(
 | 
			
		||||
            {"admin": 0, "device_id": None, "name": self.user_id},
 | 
			
		||||
            {"admin": 0,
 | 
			
		||||
             "device_id": None,
 | 
			
		||||
             "name": self.user_id,
 | 
			
		||||
             "token_id": 2},
 | 
			
		||||
            (yield self.store.get_user_by_token(self.tokens[1]))
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue