Merge branch 'develop' into markjh/split_pusher

markjh/split_pusher
Mark Haines 2016-04-14 17:00:40 +01:00
commit c214d3e36e
18 changed files with 240 additions and 130 deletions

View File

@ -36,7 +36,7 @@ then either responds with updates immediately if it already has updates or it
waits until the timeout for more updates. If the timeout expires and nothing
happened then the server returns an empty response.
-However until the /sync API this replication API is returning synapse specific
+However unlike the /sync API this replication API is returning synapse specific
data rather than trying to implement a matrix specification. The replication
results are returned as arrays of rows where the rows are mostly lifted
directly from the database. This avoids unnecessary JSON parsing on the server
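The long-poll behaviour described above is easy to picture as a client loop. A minimal sketch — the endpoint path, parameter names and response fields here are illustrative assumptions, not the actual wire format documented in this file:

    import requests

    def poll_replication(base_url, token, timeout_ms=30000):
        # long-poll: the server replies immediately if it has updates,
        # otherwise it holds the request open until the timeout expires
        while True:
            resp = requests.get(
                base_url + "/replication",  # hypothetical endpoint path
                params={"from": token, "timeout": timeout_ms},
            )
            resp.raise_for_status()
            body = resp.json()  # empty if the timeout expired
            for row in body.get("rows", []):  # rows lifted straight from the DB
                print(row)
            token = body.get("position", token)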

View File

@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64
import logging
@ -44,6 +45,7 @@ class Auth(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
@ -66,9 +68,9 @@ class Auth(object):
Returns:
True if the auth checks pass.
"""
-        self.check_size_limits(event)
+        with Measure(self.clock, "auth.check"):
+            self.check_size_limits(event)
try:
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if auth_events is None:
@ -127,13 +129,6 @@ class Auth(object):
self.check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
        except AuthError as e:
-            logger.info(
-                "Event auth check failed on event %s with msg: %s",
-                event, e.msg
-            )
-            logger.info("Denying! %s", event)
            raise
def check_size_limits(self, event):
def too_big(field):

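The Measure context manager used above times the block it wraps and accounts it to a named metric, so the size checks are now included in the "auth.check" measurement. A minimal sketch of the idea, assuming only a clock with a time_msec() method like the one hs.get_clock() returns; this is an illustrative analogue, not synapse's actual implementation:

    import logging

    logger = logging.getLogger(__name__)

    class Measure(object):
        """Times a block and reports it under a name (illustrative only)."""

        def __init__(self, clock, name):
            self.clock = clock
            self.name = name

        def __enter__(self):
            self.start = self.clock.time_msec()

        def __exit__(self, exc_type, exc_val, exc_tb):
            logger.debug(
                "%s took %dms", self.name, self.clock.time_msec() - self.start
            )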
View File

@ -13,10 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config
+from ._base import Config, ConfigError
from collections import namedtuple
import sys
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
)
MISSING_LXML = (
"""Missing lxml library. This is required for URL preview API.
Install by running:
pip install lxml
Requires libxslt1-dev system package.
"""
)
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
@ -62,18 +76,32 @@ class ContentRepositoryConfig(Config):
self.thumbnail_requirements = parse_thumbnail_requirements(
config["thumbnail_sizes"]
)
-        self.url_preview_enabled = config["url_preview_enabled"]
+        self.url_preview_enabled = config.get("url_preview_enabled", False)

         if self.url_preview_enabled:
             try:
-                from netaddr import IPSet
-                if "url_preview_ip_range_blacklist" in config:
-                    self.url_preview_ip_range_blacklist = IPSet(
-                        config["url_preview_ip_range_blacklist"]
-                    )
-                if "url_preview_url_blacklist" in config:
-                    self.url_preview_url_blacklist = config["url_preview_url_blacklist"]
+                import lxml
+                lxml  # To stop unused lint.
             except ImportError:
-                sys.stderr.write("\nmissing netaddr dep - disabling preview_url API\n")
+                raise ConfigError(MISSING_LXML)
+
+            try:
+                from netaddr import IPSet
+            except ImportError:
+                raise ConfigError(MISSING_NETADDR)
+
+            if "url_preview_ip_range_blacklist" in config:
+                self.url_preview_ip_range_blacklist = IPSet(
+                    config["url_preview_ip_range_blacklist"]
+                )
+            else:
+                raise ConfigError(
+                    "For security, you must specify an explicit target IP address "
+                    "blacklist in url_preview_ip_range_blacklist for url previewing "
+                    "to work"
+                )
+
+            if "url_preview_url_blacklist" in config:
+                self.url_preview_url_blacklist = config["url_preview_url_blacklist"]
def default_config(self, **kwargs):
media_store = self.default_path("media_store")

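The net effect of this hunk is that a homeserver with url_preview_enabled set but missing dependencies or a missing IP blacklist now refuses to start with a clear ConfigError, instead of writing to stderr and carrying on. The check pattern in isolation — a sketch assuming a bare ConfigError exception like the one imported above:

    class ConfigError(Exception):
        """Startup catches this, prints the message and exits non-zero."""

    def check_url_preview_config(config):
        if not config.get("url_preview_enabled", False):
            return  # feature off: don't demand its optional dependencies
        try:
            import lxml  # noqa: F401 -- only the importability matters here
        except ImportError:
            raise ConfigError("Missing lxml library. Install it with: pip install lxml")
        if "url_preview_ip_range_blacklist" not in config:
            raise ConfigError(
                "url_preview_ip_range_blacklist must be set for URL previewing"
            )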
View File

@ -316,7 +316,11 @@ class BaseHandler(object):
if ratelimit:
self.ratelimit(requester)
-        self.auth.check(event, auth_events=context.current_state)
+        try:
+            self.auth.check(event, auth_events=context.current_state)
+        except AuthError as err:
+            logger.warn("Denying new event %r because %s", event, err)
+            raise err
yield self.maybe_kick_guest_users(event, context.current_state.values())

View File

@ -681,9 +681,13 @@ class FederationHandler(BaseHandler):
"state_key": user_id,
})
-        event, context = yield self._create_new_client_event(
-            builder=builder,
-        )
+        try:
+            event, context = yield self._create_new_client_event(
+                builder=builder,
+            )
+        except AuthError as e:
+            logger.warn("Failed to create join %r because %s", event, e)
+            raise e
self.auth.check(event, auth_events=context.current_state)
@ -915,7 +919,11 @@ class FederationHandler(BaseHandler):
builder=builder,
)
-        self.auth.check(event, auth_events=context.current_state)
+        try:
+            self.auth.check(event, auth_events=context.current_state)
+        except AuthError as e:
+            logger.warn("Failed to create new leave %r because %s", event, e)
+            raise e
defer.returnValue(event)
@ -1512,8 +1520,9 @@ class FederationHandler(BaseHandler):
try:
self.auth.check(event, auth_events=auth_events)
-        except AuthError:
-            raise
+        except AuthError as e:
+            logger.warn("Failed auth resolution for %r because %s", event, e)
+            raise e
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
@ -1689,7 +1698,12 @@ class FederationHandler(BaseHandler):
event_dict, event, context
)
-        self.auth.check(event, context.current_state)
+        try:
+            self.auth.check(event, context.current_state)
+        except AuthError as e:
+            logger.warn("Denying new third party invite %r because %s", event, e)
+            raise e
yield self._check_signature(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context)
@ -1714,7 +1728,11 @@ class FederationHandler(BaseHandler):
event_dict, event, context
)
-        self.auth.check(event, auth_events=context.current_state)
+        try:
+            self.auth.check(event, auth_events=context.current_state)
+        except AuthError as e:
+            logger.warn("Denying third party invite %r because %s", event, e)
+            raise e
yield self._check_signature(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event)

View File

@ -358,8 +358,6 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def handle_room(room_id):
-            aliases = yield self.store.get_aliases_for_room(room_id)
            # We pull each bit of state out individually to avoid pulling the
            # full state into memory. Due to how the caching works this should
            # be fairly quick, even if not originally in the cache.
@ -374,6 +372,14 @@ class RoomListHandler(BaseHandler):
defer.returnValue(None)
result = {"room_id": room_id}
+            joined_users = yield self.store.get_users_in_room(room_id)
+            if len(joined_users) == 0:
+                return
+            result["num_joined_members"] = len(joined_users)
+            aliases = yield self.store.get_aliases_for_room(room_id)
+            if aliases:
+                result["aliases"] = aliases
@ -413,9 +419,6 @@ class RoomListHandler(BaseHandler):
if avatar_url:
result["avatar_url"] = avatar_url
-            joined_users = yield self.store.get_users_in_room(room_id)
-            result["num_joined_members"] = len(joined_users)
results.append(result)
yield concurrently_execute(handle_room, room_ids, 10)

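concurrently_execute(handle_room, room_ids, 10) runs handle_room over every room with at most ten calls in flight at once, which keeps a large public room list from fanning out into thousands of simultaneous storage lookups. A rough asyncio analogue of those semantics (synapse's own helper is Twisted-based; this sketch only illustrates the shared-iterator pattern):

    import asyncio

    async def concurrently_execute(func, args, limit):
        # `limit` workers pull from one shared iterator until it is drained
        it = iter(args)

        async def worker():
            for arg in it:
                await func(arg)

        await asyncio.gather(*(worker() for _ in range(limit)))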
View File

@ -17,6 +17,8 @@ from twisted.internet import defer
from .bulk_push_rule_evaluator import evaluator_for_event
+from synapse.util.metrics import Measure
import logging
logger = logging.getLogger(__name__)
@ -25,6 +27,7 @@ logger = logging.getLogger(__name__)
class ActionGenerator:
def __init__(self, hs):
self.hs = hs
+        self.clock = hs.get_clock()
self.store = hs.get_datastore()
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
@ -35,14 +38,15 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context, handler):
-        bulk_evaluator = yield evaluator_for_event(
-            event, self.hs, self.store
-        )
+        with Measure(self.clock, "handle_push_actions_for_event"):
+            bulk_evaluator = yield evaluator_for_event(
+                event, self.hs, self.store
+            )

-        actions_by_user = yield bulk_evaluator.action_for_event_by_user(
-            event, handler, context.current_state
-        )
+            actions_by_user = yield bulk_evaluator.action_for_event_by_user(
+                event, handler, context.current_state
+            )

-        context.push_actions = [
-            (uid, actions) for uid, actions in actions_by_user.items()
-        ]
+            context.push_actions = [
+                (uid, actions) for uid, actions in actions_by_user.items()
+            ]

View File

@ -71,13 +71,25 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks
def evaluator_for_event(event, hs, store):
room_id = event.room_id
+    # users in the room who have pushers need to get push rules run because
+    # that's how their pushers work
+    users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
+
+    # We will also want to generate notifs for other people in the room so
+    # their unread counts are correct in the event stream, but to avoid
+    # generating them for bot / AS users etc, we only do so for people who've
+    # sent a read receipt into the room.
+    all_in_room = yield store.get_users_in_room(room_id)
+    all_in_room = set(all_in_room)

     receipts = yield store.get_receipts_for_room(room_id, "m.read")

+    # any users with pushers must be ours: they have pushers
+    user_ids = set(users_with_pushers)
     for r in receipts:
-        if hs.is_mine_id(r['user_id']):
+        if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room:
             user_ids.add(r['user_id'])
# if this event is an invite event, we may need to run rules for the user

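Taken together, the evaluator now runs rules for: every user in the room with a pusher, plus every local user who has sent a read receipt and is still actually joined (the `in all_in_room` check is what filters out users who have since left). Condensed into one helper — a sketch reusing the store calls shown above:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def users_to_evaluate(hs, store, room_id):
        users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
        all_in_room = set((yield store.get_users_in_room(room_id)))
        receipts = yield store.get_receipts_for_room(room_id, "m.read")

        user_ids = set(users_with_pushers)  # pusher owners always get rules run
        for r in receipts:
            # local receipt senders only, and only while still in the room
            if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room:
                user_ids.add(r['user_id'])
        defer.returnValue(user_ids)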
View File

@ -21,6 +21,7 @@ import logging
import push_rule_evaluator
import push_tools
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@ -85,9 +86,8 @@ class HttpPusher(object):
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
with Measure(self.clock, "push.on_new_notifications"):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
yield self._process()
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
yield self._process()
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id):
@ -95,16 +95,16 @@ class HttpPusher(object):
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
with Measure(self.clock, "push.on_new_receipts"):
badge = yield push_tools.get_badge_count(
self.hs.get_datastore(), self.user_id
)
yield self.send_badge(badge)
with LoggingContext("push.on_new_receipts"):
with Measure(self.clock, "push.on_new_receipts"):
badge = yield push_tools.get_badge_count(
self.hs.get_datastore(), self.user_id
)
yield self._send_badge(badge)
@defer.inlineCallbacks
def on_timer(self):
with Measure(self.clock, "push.on_timer"):
yield self._process()
yield self._process()
def on_stop(self):
if self.timed_call:
@ -114,20 +114,23 @@ class HttpPusher(object):
def _process(self):
if self.processing:
return
-        try:
-            self.processing = True
-            # if the max ordering changes while we're running _unsafe_process,
-            # call it again, and so on until we've caught up.
-            while True:
-                starting_max_ordering = self.max_stream_ordering
-                try:
-                    yield self._unsafe_process()
-                except:
-                    logger.exception("Exception processing notifs")
-                if self.max_stream_ordering == starting_max_ordering:
-                    break
-        finally:
-            self.processing = False
+        with LoggingContext("push._process"):
+            with Measure(self.clock, "push._process"):
+                try:
+                    self.processing = True
+                    # if the max ordering changes while we're running _unsafe_process,
+                    # call it again, and so on until we've caught up.
+                    while True:
+                        starting_max_ordering = self.max_stream_ordering
+                        try:
+                            yield self._unsafe_process()
+                        except:
+                            logger.exception("Exception processing notifs")
+                        if self.max_stream_ordering == starting_max_ordering:
+                            break
+                finally:
+                    self.processing = False
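The processing flag plus the while loop implement a small catch-up protocol: a caller arriving while a run is in flight just bumps max_stream_ordering and returns, and the running loop goes around again until an iteration completes with the ordering unchanged. Stripped of the logging and metrics, the control flow is:

    @defer.inlineCallbacks
    def _process(self):
        if self.processing:
            return  # the in-flight run will see our updated max ordering
        try:
            self.processing = True
            while True:
                starting_max_ordering = self.max_stream_ordering
                yield self._unsafe_process()
                if self.max_stream_ordering == starting_max_ordering:
                    break  # nothing new arrived while we worked: caught up
        finally:
            self.processing = False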
@defer.inlineCallbacks
def _unsafe_process(self):
@ -146,7 +149,7 @@ class HttpPusher(object):
if processed:
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
-                self.store.update_pusher_last_stream_ordering_and_success(
+                yield self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.pushkey, self.user_id,
self.last_stream_ordering,
self.clock.time_msec()
@ -291,7 +294,7 @@ class HttpPusher(object):
defer.returnValue(rejected)
@defer.inlineCallbacks
-    def send_badge(self, badge):
+    def _send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
d = {
'notification': {

View File

@ -33,7 +33,6 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pushers = {}
-        self.last_pusher_started = -1
@defer.inlineCallbacks
def start(self):

View File

@ -43,7 +43,6 @@ CONDITIONAL_REQUIREMENTS = {
"matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"],
},
"preview_url": {
"lxml>=3.6.0": ["lxml"],
"netaddr>=0.7.18": ["netaddr"],
},
}

View File

@ -100,6 +100,11 @@ class RegisterRestServlet(RestServlet):
# == Application Service Registration ==
if appservice:
+            # Set the desired user according to the AS API (which uses the
+            # 'user' key not 'username'). Since this is a new addition, we'll
+            # fall back to 'username' if they gave one.
+            if isinstance(body.get("user"), basestring):
+                desired_username = body["user"]
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)

View File

@ -80,8 +80,4 @@ class MediaRepositoryResource(Resource):
self.putChild("thumbnail", ThumbnailResource(hs, filepaths))
self.putChild("identicon", IdenticonResource())
         if hs.config.url_preview_enabled:
-            try:
-                self.putChild("preview_url", PreviewUrlResource(hs, filepaths))
-            except Exception as e:
-                logger.warn("Failed to mount preview_url")
-                logger.exception(e)
+            self.putChild("preview_url", PreviewUrlResource(hs, filepaths))

View File

@ -17,7 +17,6 @@ from .base_resource import BaseMediaResource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
-from urlparse import urlparse, urlsplit, urlunparse
from synapse.api.errors import (
SynapseError, Codes,
@ -36,37 +35,16 @@ import re
import fnmatch
import cgi
import ujson as json
+import urlparse
import logging
logger = logging.getLogger(__name__)
-try:
-    from lxml import html
-except ImportError:
-    pass

 class PreviewUrlResource(BaseMediaResource):
     isLeaf = True

     def __init__(self, hs, filepaths):
-        try:
-            if html:
-                pass
-        except:
-            raise RuntimeError("Disabling PreviewUrlResource as lxml not available")
-
-        if not hasattr(hs.config, "url_preview_ip_range_blacklist"):
-            logger.warn(
-                "For security, you must specify an explicit target IP address "
-                "blacklist in url_preview_ip_range_blacklist for url previewing "
-                "to work"
-            )
-            raise RuntimeError(
-                "Disabling PreviewUrlResource as "
-                "url_preview_ip_range_blacklist not specified"
-            )
BaseMediaResource.__init__(self, hs, filepaths)
self.client = SpiderHttpClient(hs)
if hasattr(hs.config, "url_preview_url_blacklist"):
@ -101,7 +79,7 @@ class PreviewUrlResource(BaseMediaResource):
# impose the URL pattern blacklist
if hasattr(self, "url_preview_url_blacklist"):
-            url_tuple = urlsplit(url)
+            url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
@ -201,6 +179,8 @@ class PreviewUrlResource(BaseMediaResource):
elif self._is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
+            from lxml import html
try:
tree = html.parse(media_info['filename'])
og = yield self._calc_og(tree, media_info, requester)
@ -358,15 +338,15 @@ class PreviewUrlResource(BaseMediaResource):
defer.returnValue(og)
def _rebase_url(self, url, base):
-        base = list(urlparse(base))
-        url = list(urlparse(url))
+        base = list(urlparse.urlparse(base))
+        url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
-        return urlunparse(url)
+        return urlparse.urlunparse(url)
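_rebase_url resolves a possibly-relative URL (for example an og:image value scraped from a page) against the page's own URL. Two worked examples of the behaviour above (self omitted for brevity):

    # bare relative path: scheme, host and directory come from the base
    _rebase_url("logo.png", "http://example.com/a/page.html")
    # -> "http://example.com/a/logo.png"

    # protocol-relative URL: only the scheme is filled in from the base
    _rebase_url("//cdn.example.com/x.png", "https://example.com/page")
    # -> "https://cdn.example.com/x.png"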
@defer.inlineCallbacks
def _download_url(self, url, user):

View File

@ -72,10 +72,10 @@ class ThumbnailResource(BaseMediaResource):
self._respond_404(request)
return
-        if media_info["media_type"] == "image/svg+xml":
-            file_path = self.filepaths.local_media_filepath(media_id)
-            yield self._respond_with_file(request, media_info["media_type"], file_path)
-            return
+        # if media_info["media_type"] == "image/svg+xml":
+        #     file_path = self.filepaths.local_media_filepath(media_id)
+        #     yield self._respond_with_file(request, media_info["media_type"], file_path)
+        #     return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@ -108,10 +108,10 @@ class ThumbnailResource(BaseMediaResource):
self._respond_404(request)
return
-        if media_info["media_type"] == "image/svg+xml":
-            file_path = self.filepaths.local_media_filepath(media_id)
-            yield self._respond_with_file(request, media_info["media_type"], file_path)
-            return
+        # if media_info["media_type"] == "image/svg+xml":
+        #     file_path = self.filepaths.local_media_filepath(media_id)
+        #     yield self._respond_with_file(request, media_info["media_type"], file_path)
+        #     return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
@ -148,10 +148,10 @@ class ThumbnailResource(BaseMediaResource):
desired_method, desired_type):
media_info = yield self._get_remote_media(server_name, media_id)
-        if media_info["media_type"] == "image/svg+xml":
-            file_path = self.filepaths.remote_media_filepath(server_name, media_id)
-            yield self._respond_with_file(request, media_info["media_type"], file_path)
-            return
+        # if media_info["media_type"] == "image/svg+xml":
+        #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
+        #     yield self._respond_with_file(request, media_info["media_type"], file_path)
+        #     return
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@ -196,10 +196,10 @@ class ThumbnailResource(BaseMediaResource):
# We should proxy the thumbnail from the remote server instead.
media_info = yield self._get_remote_media(server_name, media_id)
-        if media_info["media_type"] == "image/svg+xml":
-            file_path = self.filepaths.remote_media_filepath(server_name, media_id)
-            yield self._respond_with_file(request, media_info["media_type"], file_path)
-            return
+        # if media_info["media_type"] == "image/svg+xml":
+        #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
+        #     yield self._respond_with_file(request, media_info["media_type"], file_path)
+        #     return
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,

View File

@ -116,26 +116,68 @@ class EventPushActionsStore(SQLBaseStore):
def get_unread_push_actions_for_user_in_range(self, user_id,
min_stream_ordering,
max_stream_ordering=None):
-        def f(txn):
+        def get_after_receipt(txn):
            sql = (
-                "SELECT event_id, stream_ordering, actions"
-                " FROM event_push_actions"
-                " WHERE user_id = ? AND stream_ordering > ?"
+                "SELECT ep.event_id, ep.stream_ordering, ep.actions "
+                "FROM event_push_actions AS ep, ("
+                " SELECT room_id, user_id,"
+                " max(topological_ordering) as topological_ordering,"
+                " max(stream_ordering) as stream_ordering"
+                " FROM events"
+                " NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
+                " GROUP BY room_id, user_id"
+                ") AS rl "
+                "WHERE"
+                " ep.room_id = rl.room_id"
+                " AND ("
+                " ep.topological_ordering > rl.topological_ordering"
+                " OR ("
+                " ep.topological_ordering = rl.topological_ordering"
+                " AND ep.stream_ordering > rl.stream_ordering"
+                " )"
+                " )"
+                " AND ep.stream_ordering > ?"
+                " AND ep.user_id = ?"
+                " AND ep.user_id = rl.user_id"
            )
-            args = [user_id, min_stream_ordering]
+            args = [min_stream_ordering, user_id]
            if max_stream_ordering is not None:
-                sql += " AND stream_ordering <= ?"
+                sql += " AND ep.stream_ordering <= ?"
                args.append(max_stream_ordering)
-            sql += " ORDER BY stream_ordering ASC"
+            sql += " ORDER BY ep.stream_ordering ASC"
            txn.execute(sql, args)
            return txn.fetchall()
-        ret = yield self.runInteraction("get_unread_push_actions_for_user_in_range", f)
+        after_read_receipt = yield self.runInteraction(
+            "get_unread_push_actions_for_user_in_range", get_after_receipt
+        )
+
+        def get_no_receipt(txn):
+            sql = (
+                "SELECT ep.event_id, ep.stream_ordering, ep.actions "
+                "FROM event_push_actions AS ep "
+                "WHERE ep.room_id not in ("
+                " SELECT room_id FROM events NATURAL JOIN receipts_linearized"
+                " WHERE receipt_type = 'm.read' AND user_id = ? "
+                " GROUP BY room_id"
+                ") AND ep.user_id = ? AND ep.stream_ordering > ?"
+            )
+            args = [user_id, user_id, min_stream_ordering]
+            if max_stream_ordering is not None:
+                sql += " AND ep.stream_ordering <= ?"
+                args.append(max_stream_ordering)
+            sql += " ORDER BY ep.stream_ordering ASC"
+            txn.execute(sql, args)
+            return txn.fetchall()
+        no_read_receipt = yield self.runInteraction(
+            "get_unread_push_actions_for_user_in_range", get_no_receipt
+        )

        defer.returnValue([
            {
                "event_id": row[0],
                "stream_ordering": row[1],
                "actions": json.loads(row[2]),
-            } for row in ret
+            } for row in after_read_receipt + no_read_receipt
])
@defer.inlineCallbacks

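Both queries hinge on the same ordering rule: an event counts as unread if it sorts after the user's latest read receipt, comparing topological_ordering first with stream_ordering as the tie-breaker. The predicate the first query's WHERE clause encodes, in miniature:

    def is_after_receipt(event_ordering, receipt_ordering):
        # orderings are (topological_ordering, stream_ordering) pairs;
        # topological ordering wins, stream ordering breaks ties
        ev_topo, ev_stream = event_ordering
        rl_topo, rl_stream = receipt_ordering
        return ev_topo > rl_topo or (
            ev_topo == rl_topo and ev_stream > rl_stream
        )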
View File

@ -137,7 +137,11 @@ class PusherStore(SQLBaseStore):
users = yield self.get_users_in_room(room_id)
result = yield self._simple_select_many_batch(
-            'pushers', 'user_name', users, ['user_name']
+            table='pushers',
+            column='user_name',
+            iterable=users,
+            retcols=['user_name'],
+            desc='get_users_with_pushers_in_room'
)
defer.returnValue([r['user_name'] for r in result])

View File

@ -0,0 +1,18 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX event_push_actions_stream_ordering on event_push_actions(
stream_ordering, user_id
);