Merge branch 'develop' into markjh/split_pusher

markjh/split_pusher
Mark Haines 2016-04-14 17:00:40 +01:00
commit c214d3e36e
18 changed files with 240 additions and 130 deletions

View File

@ -36,7 +36,7 @@ then either responds with updates immediately if it already has updates or it
waits until the timeout for more updates. If the timeout expires and nothing waits until the timeout for more updates. If the timeout expires and nothing
happened then the server returns an empty response. happened then the server returns an empty response.
However until the /sync API this replication API is returning synapse specific However unlike the /sync API this replication API is returning synapse specific
data rather than trying to implement a matrix specification. The replication data rather than trying to implement a matrix specification. The replication
results are returned as arrays of rows where the rows are mostly lifted results are returned as arrays of rows where the rows are mostly lifted
directly from the database. This avoids unnecessary JSON parsing on the server directly from the database. This avoids unnecessary JSON parsing on the server

View File

@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -44,6 +45,7 @@ class Auth(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
@ -66,9 +68,9 @@ class Auth(object):
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
self.check_size_limits(event) with Measure(self.clock, "auth.check"):
self.check_size_limits(event)
try:
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event) raise AuthError(500, "Event has no room_id: %s" % event)
if auth_events is None: if auth_events is None:
@ -127,13 +129,6 @@ class Auth(object):
self.check_redaction(event, auth_events) self.check_redaction(event, auth_events)
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
except AuthError as e:
logger.info(
"Event auth check failed on event %s with msg: %s",
event, e.msg
)
logger.info("Denying! %s", event)
raise
def check_size_limits(self, event): def check_size_limits(self, event):
def too_big(field): def too_big(field):

View File

@ -13,10 +13,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config, ConfigError
from collections import namedtuple from collections import namedtuple
import sys
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
)
MISSING_LXML = (
"""Missing lxml library. This is required for URL preview API.
Install by running:
pip install lxml
Requires libxslt1-dev system package.
"""
)
ThumbnailRequirement = namedtuple( ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"] "ThumbnailRequirement", ["width", "height", "method", "media_type"]
@ -62,18 +76,32 @@ class ContentRepositoryConfig(Config):
self.thumbnail_requirements = parse_thumbnail_requirements( self.thumbnail_requirements = parse_thumbnail_requirements(
config["thumbnail_sizes"] config["thumbnail_sizes"]
) )
self.url_preview_enabled = config["url_preview_enabled"] self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled: if self.url_preview_enabled:
try: try:
from netaddr import IPSet import lxml
if "url_preview_ip_range_blacklist" in config: lxml # To stop unused lint.
self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"]
)
if "url_preview_url_blacklist" in config:
self.url_preview_url_blacklist = config["url_preview_url_blacklist"]
except ImportError: except ImportError:
sys.stderr.write("\nmissing netaddr dep - disabling preview_url API\n") raise ConfigError(MISSING_LXML)
try:
from netaddr import IPSet
except ImportError:
raise ConfigError(MISSING_NETADDR)
if "url_preview_ip_range_blacklist" in config:
self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"]
)
else:
raise ConfigError(
"For security, you must specify an explicit target IP address "
"blacklist in url_preview_ip_range_blacklist for url previewing "
"to work"
)
if "url_preview_url_blacklist" in config:
self.url_preview_url_blacklist = config["url_preview_url_blacklist"]
def default_config(self, **kwargs): def default_config(self, **kwargs):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")

View File

@ -316,7 +316,11 @@ class BaseHandler(object):
if ratelimit: if ratelimit:
self.ratelimit(requester) self.ratelimit(requester)
self.auth.check(event, auth_events=context.current_state) try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context.current_state.values())

View File

@ -681,9 +681,13 @@ class FederationHandler(BaseHandler):
"state_key": user_id, "state_key": user_id,
}) })
event, context = yield self._create_new_client_event( try:
builder=builder, event, context = yield self._create_new_client_event(
) builder=builder,
)
except AuthError as e:
logger.warn("Failed to create join %r because %s", event, e)
raise e
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
@ -915,7 +919,11 @@ class FederationHandler(BaseHandler):
builder=builder, builder=builder,
) )
self.auth.check(event, auth_events=context.current_state) try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
defer.returnValue(event) defer.returnValue(event)
@ -1512,8 +1520,9 @@ class FederationHandler(BaseHandler):
try: try:
self.auth.check(event, auth_events=auth_events) self.auth.check(event, auth_events=auth_events)
except AuthError: except AuthError as e:
raise logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
@defer.inlineCallbacks @defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth): def construct_auth_difference(self, local_auth, remote_auth):
@ -1689,7 +1698,12 @@ class FederationHandler(BaseHandler):
event_dict, event, context event_dict, event, context
) )
self.auth.check(event, context.current_state) try:
self.auth.check(event, context.current_state)
except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
@ -1714,7 +1728,11 @@ class FederationHandler(BaseHandler):
event_dict, event, context event_dict, event, context
) )
self.auth.check(event, auth_events=context.current_state) try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event) returned_invite = yield self.send_invite(origin, event)

View File

@ -358,8 +358,6 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_room(room_id): def handle_room(room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
# We pull each bit of state out indvidually to avoid pulling the # We pull each bit of state out indvidually to avoid pulling the
# full state into memory. Due to how the caching works this should # full state into memory. Due to how the caching works this should
# be fairly quick, even if not originally in the cache. # be fairly quick, even if not originally in the cache.
@ -374,6 +372,14 @@ class RoomListHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
result = {"room_id": room_id} result = {"room_id": room_id}
joined_users = yield self.store.get_users_in_room(room_id)
if len(joined_users) == 0:
return
result["num_joined_members"] = len(joined_users)
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases: if aliases:
result["aliases"] = aliases result["aliases"] = aliases
@ -413,9 +419,6 @@ class RoomListHandler(BaseHandler):
if avatar_url: if avatar_url:
result["avatar_url"] = avatar_url result["avatar_url"] = avatar_url
joined_users = yield self.store.get_users_in_room(room_id)
result["num_joined_members"] = len(joined_users)
results.append(result) results.append(result)
yield concurrently_execute(handle_room, room_ids, 10) yield concurrently_execute(handle_room, room_ids, 10)

View File

@ -17,6 +17,8 @@ from twisted.internet import defer
from .bulk_push_rule_evaluator import evaluator_for_event from .bulk_push_rule_evaluator import evaluator_for_event
from synapse.util.metrics import Measure
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,6 +27,7 @@ logger = logging.getLogger(__name__)
class ActionGenerator: class ActionGenerator:
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
# really we want to get all user ids and all profile tags too, # really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and # since we want the actions for each profile tag for every user and
@ -35,14 +38,15 @@ class ActionGenerator:
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context, handler): def handle_push_actions_for_event(self, event, context, handler):
bulk_evaluator = yield evaluator_for_event( with Measure(self.clock, "handle_push_actions_for_event"):
event, self.hs, self.store bulk_evaluator = yield evaluator_for_event(
) event, self.hs, self.store
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user( actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, handler, context.current_state event, handler, context.current_state
) )
context.push_actions = [ context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items() (uid, actions) for uid, actions in actions_by_user.items()
] ]

View File

@ -71,13 +71,25 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store): def evaluator_for_event(event, hs, store):
room_id = event.room_id room_id = event.room_id
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
users_with_pushers = yield store.get_users_with_pushers_in_room(room_id) users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
all_in_room = yield store.get_users_in_room(room_id)
all_in_room = set(all_in_room)
receipts = yield store.get_receipts_for_room(room_id, "m.read") receipts = yield store.get_receipts_for_room(room_id, "m.read")
# any users with pushers must be ours: they have pushers # any users with pushers must be ours: they have pushers
user_ids = set(users_with_pushers) user_ids = set(users_with_pushers)
for r in receipts: for r in receipts:
if hs.is_mine_id(r['user_id']): if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room:
user_ids.add(r['user_id']) user_ids.add(r['user_id'])
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user

View File

@ -21,6 +21,7 @@ import logging
import push_rule_evaluator import push_rule_evaluator
import push_tools import push_tools
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,9 +86,8 @@ class HttpPusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering): def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
with Measure(self.clock, "push.on_new_notifications"): self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) yield self._process()
yield self._process()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id): def on_new_receipts(self, min_stream_id, max_stream_id):
@ -95,16 +95,16 @@ class HttpPusher(object):
# We could check the receipts are actually m.read receipts here, # We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway... # but currently that's the only type of receipt anyway...
with Measure(self.clock, "push.on_new_receipts"): with LoggingContext("push.on_new_receipts"):
badge = yield push_tools.get_badge_count( with Measure(self.clock, "push.on_new_receipts"):
self.hs.get_datastore(), self.user_id badge = yield push_tools.get_badge_count(
) self.hs.get_datastore(), self.user_id
yield self.send_badge(badge) )
yield self._send_badge(badge)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_timer(self): def on_timer(self):
with Measure(self.clock, "push.on_timer"): yield self._process()
yield self._process()
def on_stop(self): def on_stop(self):
if self.timed_call: if self.timed_call:
@ -114,20 +114,23 @@ class HttpPusher(object):
def _process(self): def _process(self):
if self.processing: if self.processing:
return return
try:
self.processing = True with LoggingContext("push._process"):
# if the max ordering changes while we're running _unsafe_process, with Measure(self.clock, "push._process"):
# call it again, and so on until we've caught up.
while True:
starting_max_ordering = self.max_stream_ordering
try: try:
yield self._unsafe_process() self.processing = True
except: # if the max ordering changes while we're running _unsafe_process,
logger.exception("Exception processing notifs") # call it again, and so on until we've caught up.
if self.max_stream_ordering == starting_max_ordering: while True:
break starting_max_ordering = self.max_stream_ordering
finally: try:
self.processing = False yield self._unsafe_process()
except:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
finally:
self.processing = False
@defer.inlineCallbacks @defer.inlineCallbacks
def _unsafe_process(self): def _unsafe_process(self):
@ -146,7 +149,7 @@ class HttpPusher(object):
if processed: if processed:
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering'] self.last_stream_ordering = push_action['stream_ordering']
self.store.update_pusher_last_stream_ordering_and_success( yield self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.pushkey, self.user_id, self.app_id, self.pushkey, self.user_id,
self.last_stream_ordering, self.last_stream_ordering,
self.clock.time_msec() self.clock.time_msec()
@ -291,7 +294,7 @@ class HttpPusher(object):
defer.returnValue(rejected) defer.returnValue(rejected)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_badge(self, badge): def _send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_id) logger.info("Sending updated badge count %d to %r", badge, self.user_id)
d = { d = {
'notification': { 'notification': {

View File

@ -33,7 +33,6 @@ class PusherPool:
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.pushers = {} self.pushers = {}
self.last_pusher_started = -1
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):

View File

@ -43,7 +43,6 @@ CONDITIONAL_REQUIREMENTS = {
"matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"], "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"],
}, },
"preview_url": { "preview_url": {
"lxml>=3.6.0": ["lxml"],
"netaddr>=0.7.18": ["netaddr"], "netaddr>=0.7.18": ["netaddr"],
}, },
} }

View File

@ -100,6 +100,11 @@ class RegisterRestServlet(RestServlet):
# == Application Service Registration == # == Application Service Registration ==
if appservice: if appservice:
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
if isinstance(body.get("user"), basestring):
desired_username = body["user"]
result = yield self._do_appservice_registration( result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0] desired_username, request.args["access_token"][0]
) )

View File

@ -80,8 +80,4 @@ class MediaRepositoryResource(Resource):
self.putChild("thumbnail", ThumbnailResource(hs, filepaths)) self.putChild("thumbnail", ThumbnailResource(hs, filepaths))
self.putChild("identicon", IdenticonResource()) self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled: if hs.config.url_preview_enabled:
try: self.putChild("preview_url", PreviewUrlResource(hs, filepaths))
self.putChild("preview_url", PreviewUrlResource(hs, filepaths))
except Exception as e:
logger.warn("Failed to mount preview_url")
logger.exception(e)

View File

@ -17,7 +17,6 @@ from .base_resource import BaseMediaResource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
from urlparse import urlparse, urlsplit, urlunparse
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, SynapseError, Codes,
@ -36,37 +35,16 @@ import re
import fnmatch import fnmatch
import cgi import cgi
import ujson as json import ujson as json
import urlparse
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
from lxml import html
except ImportError:
pass
class PreviewUrlResource(BaseMediaResource): class PreviewUrlResource(BaseMediaResource):
isLeaf = True isLeaf = True
def __init__(self, hs, filepaths): def __init__(self, hs, filepaths):
try:
if html:
pass
except:
raise RuntimeError("Disabling PreviewUrlResource as lxml not available")
if not hasattr(hs.config, "url_preview_ip_range_blacklist"):
logger.warn(
"For security, you must specify an explicit target IP address "
"blacklist in url_preview_ip_range_blacklist for url previewing "
"to work"
)
raise RuntimeError(
"Disabling PreviewUrlResource as "
"url_preview_ip_range_blacklist not specified"
)
BaseMediaResource.__init__(self, hs, filepaths) BaseMediaResource.__init__(self, hs, filepaths)
self.client = SpiderHttpClient(hs) self.client = SpiderHttpClient(hs)
if hasattr(hs.config, "url_preview_url_blacklist"): if hasattr(hs.config, "url_preview_url_blacklist"):
@ -101,7 +79,7 @@ class PreviewUrlResource(BaseMediaResource):
# impose the URL pattern blacklist # impose the URL pattern blacklist
if hasattr(self, "url_preview_url_blacklist"): if hasattr(self, "url_preview_url_blacklist"):
url_tuple = urlsplit(url) url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist: for entry in self.url_preview_url_blacklist:
match = True match = True
for attrib in entry: for attrib in entry:
@ -201,6 +179,8 @@ class PreviewUrlResource(BaseMediaResource):
elif self._is_html(media_info['media_type']): elif self._is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM # TODO: somehow stop a big HTML tree from exploding synapse's RAM
from lxml import html
try: try:
tree = html.parse(media_info['filename']) tree = html.parse(media_info['filename'])
og = yield self._calc_og(tree, media_info, requester) og = yield self._calc_og(tree, media_info, requester)
@ -358,15 +338,15 @@ class PreviewUrlResource(BaseMediaResource):
defer.returnValue(og) defer.returnValue(og)
def _rebase_url(self, url, base): def _rebase_url(self, url, base):
base = list(urlparse(base)) base = list(urlparse.urlparse(base))
url = list(urlparse(url)) url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema if not url[0]: # fix up schema
url[0] = base[0] or "http" url[0] = base[0] or "http"
if not url[1]: # fix up hostname if not url[1]: # fix up hostname
url[1] = base[1] url[1] = base[1]
if not url[2].startswith('/'): if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlunparse(url) return urlparse.urlunparse(url)
@defer.inlineCallbacks @defer.inlineCallbacks
def _download_url(self, url, user): def _download_url(self, url, user):

View File

@ -72,10 +72,10 @@ class ThumbnailResource(BaseMediaResource):
self._respond_404(request) self._respond_404(request)
return return
if media_info["media_type"] == "image/svg+xml": # if media_info["media_type"] == "image/svg+xml":
file_path = self.filepaths.local_media_filepath(media_id) # file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file(request, media_info["media_type"], file_path) # yield self._respond_with_file(request, media_info["media_type"], file_path)
return # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@ -108,10 +108,10 @@ class ThumbnailResource(BaseMediaResource):
self._respond_404(request) self._respond_404(request)
return return
if media_info["media_type"] == "image/svg+xml": # if media_info["media_type"] == "image/svg+xml":
file_path = self.filepaths.local_media_filepath(media_id) # file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file(request, media_info["media_type"], file_path) # yield self._respond_with_file(request, media_info["media_type"], file_path)
return # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos: for info in thumbnail_infos:
@ -148,10 +148,10 @@ class ThumbnailResource(BaseMediaResource):
desired_method, desired_type): desired_method, desired_type):
media_info = yield self._get_remote_media(server_name, media_id) media_info = yield self._get_remote_media(server_name, media_id)
if media_info["media_type"] == "image/svg+xml": # if media_info["media_type"] == "image/svg+xml":
file_path = self.filepaths.remote_media_filepath(server_name, media_id) # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
yield self._respond_with_file(request, media_info["media_type"], file_path) # yield self._respond_with_file(request, media_info["media_type"], file_path)
return # return
thumbnail_infos = yield self.store.get_remote_media_thumbnails( thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id, server_name, media_id,
@ -196,10 +196,10 @@ class ThumbnailResource(BaseMediaResource):
# We should proxy the thumbnail from the remote server instead. # We should proxy the thumbnail from the remote server instead.
media_info = yield self._get_remote_media(server_name, media_id) media_info = yield self._get_remote_media(server_name, media_id)
if media_info["media_type"] == "image/svg+xml": # if media_info["media_type"] == "image/svg+xml":
file_path = self.filepaths.remote_media_filepath(server_name, media_id) # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
yield self._respond_with_file(request, media_info["media_type"], file_path) # yield self._respond_with_file(request, media_info["media_type"], file_path)
return # return
thumbnail_infos = yield self.store.get_remote_media_thumbnails( thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id, server_name, media_id,

View File

@ -116,26 +116,68 @@ class EventPushActionsStore(SQLBaseStore):
def get_unread_push_actions_for_user_in_range(self, user_id, def get_unread_push_actions_for_user_in_range(self, user_id,
min_stream_ordering, min_stream_ordering,
max_stream_ordering=None): max_stream_ordering=None):
def f(txn): def get_after_receipt(txn):
sql = ( sql = (
"SELECT event_id, stream_ordering, actions" "SELECT ep.event_id, ep.stream_ordering, ep.actions "
" FROM event_push_actions" "FROM event_push_actions AS ep, ("
" WHERE user_id = ? AND stream_ordering > ?" " SELECT room_id, user_id,"
" max(topological_ordering) as topological_ordering,"
" max(stream_ordering) as stream_ordering"
" FROM events"
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
" GROUP BY room_id, user_id"
") AS rl "
"WHERE"
" ep.room_id = rl.room_id"
" AND ("
" ep.topological_ordering > rl.topological_ordering"
" OR ("
" ep.topological_ordering = rl.topological_ordering"
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
" AND ep.stream_ordering > ?"
" AND ep.user_id = ?"
" AND ep.user_id = rl.user_id"
) )
args = [user_id, min_stream_ordering] args = [min_stream_ordering, user_id]
if max_stream_ordering is not None: if max_stream_ordering is not None:
sql += " AND stream_ordering <= ?" sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering) args.append(max_stream_ordering)
sql += " ORDER BY stream_ordering ASC" sql += " ORDER BY ep.stream_ordering ASC"
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
ret = yield self.runInteraction("get_unread_push_actions_for_user_in_range", f) after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_after_receipt
)
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.stream_ordering, ep.actions "
"FROM event_push_actions AS ep "
"WHERE ep.room_id not in ("
" SELECT room_id FROM events NATURAL JOIN receipts_linearized"
" WHERE receipt_type = 'm.read' AND user_id = ? "
" GROUP BY room_id"
") AND ep.user_id = ? AND ep.stream_ordering > ?"
)
args = [user_id, user_id, min_stream_ordering]
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
sql += " ORDER BY ep.stream_ordering ASC"
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_no_receipt
)
defer.returnValue([ defer.returnValue([
{ {
"event_id": row[0], "event_id": row[0],
"stream_ordering": row[1], "stream_ordering": row[1],
"actions": json.loads(row[2]), "actions": json.loads(row[2]),
} for row in ret } for row in after_read_receipt + no_read_receipt
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -137,7 +137,11 @@ class PusherStore(SQLBaseStore):
users = yield self.get_users_in_room(room_id) users = yield self.get_users_in_room(room_id)
result = yield self._simple_select_many_batch( result = yield self._simple_select_many_batch(
'pushers', 'user_name', users, ['user_name'] table='pushers',
column='user_name',
iterable=users,
retcols=['user_name'],
desc='get_users_with_pushers_in_room'
) )
defer.returnValue([r['user_name'] for r in result]) defer.returnValue([r['user_name'] for r in result])

View File

@ -0,0 +1,18 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX event_push_actions_stream_ordering on event_push_actions(
stream_ordering, user_id
);