Merge branch 'markjh/external_presence' into markjh/synchrotron

markjh/synchrotron
Mark Haines 2016-06-02 11:28:22 +01:00
commit b161fae864
17 changed files with 365 additions and 42 deletions

View File

@ -17,11 +17,15 @@
</td> </td>
<td class="message_contents"> <td class="message_contents">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
<div class="sender_name">{{ message.sender_name }}</div> <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
{% endif %} {% endif %}
<div class="message_body"> <div class="message_body">
{% if message.msgtype == "m.text" %} {% if message.msgtype == "m.text" %}
{{ message.body_text_html }} {{ message.body_text_html }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.image" %} {% elif message.msgtype == "m.image" %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" /> <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{% elif message.msgtype == "m.file" %} {% elif message.msgtype == "m.file" %}

View File

@ -1,7 +1,11 @@
{% for message in notif.messages %} {% for message in notif.messages %}
{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }}) {% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
{% if message.msgtype == "m.text" %} {% if message.msgtype == "m.text" %}
{{ message.body_text_plain }} {{ message.body_text_plain }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.image" %} {% elif message.msgtype == "m.image" %}
{{ message.body_text_plain }} {{ message.body_text_plain }}
{% elif message.msgtype == "m.file" %} {% elif message.msgtype == "m.file" %}

View File

@ -29,6 +29,7 @@ class ServerConfig(Config):
self.user_agent_suffix = config.get("user_agent_suffix") self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.public_baseurl = config.get("public_baseurl") self.public_baseurl = config.get("public_baseurl")
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/': if self.public_baseurl[-1] != '/':
@ -156,6 +157,15 @@ class ServerConfig(Config):
# hard limit. # hard limit.
soft_file_limit: 0 soft_file_limit: 0
# A list of other Home Servers to fetch the public room directory from
# and include in the public room directory of this home server
# This is a temporary stopgap solution to populate new server with a
# list of rooms until there exists a good solution of a decentralized
# room directory.
# secondary_directory_servers:
# - matrix.org
# - vector.im
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
listeners: listeners:

View File

@ -24,6 +24,7 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -550,6 +551,25 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def get_public_rooms(self, destinations):
results_by_server = {}
@defer.inlineCallbacks
def _get_result(s):
if s == self.server_name:
defer.returnValue()
try:
result = yield self.transport_layer.get_public_rooms(s)
results_by_server[s] = result
except:
logger.exception("Error getting room list from server %r", s)
yield concurrently_execute(_get_result, destinations, 3)
defer.returnValue(results_by_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth): def query_auth(self, destination, room_id, event_id, local_auth):
""" """

View File

@ -224,6 +224,18 @@ class TransportLayerClient(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def get_public_rooms(self, remote_server):
path = PREFIX + "/publicRooms"
response = yield self.client.get_json(
destination=remote_server,
path=path,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def exchange_third_party_invite(self, destination, room_id, event_dict): def exchange_third_party_invite(self, destination, room_id, event_dict):

View File

@ -134,10 +134,12 @@ class Authenticator(object):
class BaseFederationServlet(object): class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter, server_name): def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
self.handler = handler self.handler = handler
self.authenticator = authenticator self.authenticator = authenticator
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler
def _wrap(self, code): def _wrap(self, code):
authenticator = self.authenticator authenticator = self.authenticator
@ -492,6 +494,50 @@ class OpenIdUserInfo(BaseFederationServlet):
return code return code
class PublicRoomList(BaseFederationServlet):
"""
Fetch the public room list for this server.
This API returns information in the same format as /publicRooms on the
client API, but will only ever include local public rooms and hence is
intended for consumption by other home servers.
GET /publicRooms HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"chunk": [
{
"aliases": [
"#test:localhost"
],
"guest_can_join": false,
"name": "test room",
"num_joined_members": 3,
"room_id": "!whkydVegtvatLfXmPN:localhost",
"world_readable": false
}
],
"end": "END",
"start": "START"
}
"""
PATH = "/publicRooms"
@defer.inlineCallbacks
def on_GET(self, request):
data = yield self.room_list_handler.get_local_public_room_list()
defer.returnValue((200, data))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationSendServlet, FederationSendServlet,
FederationPullServlet, FederationPullServlet,
@ -513,6 +559,7 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
OpenIdUserInfo, OpenIdUserInfo,
PublicRoomList,
) )
@ -523,4 +570,5 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
authenticator=authenticator, authenticator=authenticator,
ratelimiter=ratelimiter, ratelimiter=ratelimiter,
server_name=hs.hostname, server_name=hs.hostname,
room_list_handler=hs.get_room_list_handler(),
).register(resource) ).register(resource)

View File

@ -68,6 +68,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
# How often to resend presence to remote servers # How often to resend presence to remote servers
FEDERATION_PING_INTERVAL = 25 * 60 * 1000 FEDERATION_PING_INTERVAL = 25 * 60 * 1000
# How long we will wait before assuming that the syncs from an external process
# are dead.
EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
@ -158,10 +162,21 @@ class PresenceHandler(object):
self.serial_to_user = {} self.serial_to_user = {}
self._next_serial = 1 self._next_serial = 1
# Keeps track of the number of *ongoing* syncs. While this is non zero # Keeps track of the number of *ongoing* syncs on this process. While
# a user will never go offline. # this is non zero a user will never go offline.
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
# go offline.
# Each process has a unique identifier and an update frequency. If
# no update is received from that process within the update period then
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
self.external_process_to_current_syncs = {}
self.external_process_last_updated_ms = []
# Start a LoopingCall in 30s that fires every 5s. # Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to # The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline. # reconnect before we treat them as offline.
@ -272,13 +287,26 @@ class PresenceHandler(object):
# Fetch the list of users that *may* have timed out. Things may have # Fetch the list of users that *may* have timed out. Things may have
# changed since the timeout was set, so we won't necessarily have to # changed since the timeout was set, so we won't necessarily have to
# take any action. # take any action.
users_to_check = self.wheel_timer.fetch(now) users_to_check = set(self.wheel_timer.fetch(now))
# Check whether the lists of syncing processes from an external
# process have expired.
expired_process_ids = [
process_id for process_id, last_update
in self.external_process_last_update.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
users_to_check.update(
self.external_process_to_current_syncs.pop(process_id, ())
)
self.external_process_last_update.pop(process_id)
states = [ states = [
self.user_to_current_state.get( self.user_to_current_state.get(
user_id, UserPresenceState.default(user_id) user_id, UserPresenceState.default(user_id)
) )
for user_id in set(users_to_check) for user_id in users_to_check
] ]
timers_fired_counter.inc_by(len(states)) timers_fired_counter.inc_by(len(states))
@ -286,7 +314,7 @@ class PresenceHandler(object):
changes = handle_timeouts( changes = handle_timeouts(
states, states,
is_mine_fn=self.is_mine_id, is_mine_fn=self.is_mine_id,
user_to_num_current_syncs=self.user_to_num_current_syncs, syncing_users=self.get_syncing_users(),
now=now, now=now,
) )
@ -363,6 +391,73 @@ class PresenceHandler(object):
defer.returnValue(_user_syncing()) defer.returnValue(_user_syncing())
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
Returns:
set(str): A set of user_id strings.
"""
syncing_user_ids = {
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count
}
syncing_user_ids.update(self.external_process_to_current_syncs.values())
return syncing_user_ids
@defer.inlineCallbacks
def update_external_syncs(self, process_id, syncing_user_ids):
"""Update the syncing users for an external process
Args:
process_id(str): An identifier for the process the users are
syncing against. This allows synapse to process updates
as user start and stop syncing against a given process.
syncing_user_ids(set(str)): The set of user_ids that are
currently syncing on that server.
"""
# Grab the previous list of user_ids that were syncing on that process
prev_syncing_user_ids = (
self.external_process_to_current_syncs.get(process_id, set())
)
# Grab the current presence state for both the users that are syncing
# now and the users that were syncing before this update.
prev_states = yield self.current_state_for_users(
syncing_user_ids + prev_syncing_user_ids
)
updates = []
time_now_ms = self.clock.time_msec()
# For each new user that is syncing check if we need to mark them as
# being online.
for new_user_id in syncing_user_ids - prev_syncing_user_ids:
prev_state = prev_states[new_user_id]
if prev_state.state == PresenceState.OFFLINE:
updates.append(prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=time_now_ms,
last_user_sync_ts=time_now_ms,
))
else:
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
# For each user that is still syncing or stopped syncing update the
# last sync time so that we will correctly apply the grace period when
# they stop syncing.
for old_user_id in prev_syncing_user_ids:
prev_state = prev_states[old_user_id]
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
yield self._update_states(updates)
# Update the last updated time for the process. We expire the entries
# if we don't receive an update in the given timeframe.
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
self.external_process_to_current_syncs[process_id] = syncing_user_ids
@defer.inlineCallbacks @defer.inlineCallbacks
def current_state_for_user(self, user_id): def current_state_for_user(self, user_id):
"""Get the current presence state for a user. """Get the current presence state for a user.
@ -935,15 +1030,14 @@ class PresenceEventSource(object):
return self.get_new_events(user, from_key=None, include_offline=False) return self.get_new_events(user, from_key=None, include_offline=False)
def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as
appropriate. appropriate.
Args: Args:
user_states(list): List of UserPresenceState's to check. user_states(list): List of UserPresenceState's to check.
is_mine_fn (fn): Function that returns if a user_id is ours is_mine_fn (fn): Function that returns if a user_id is ours
user_to_num_current_syncs (dict): Mapping of user_id to number of currently syncing_user_ids (set): Set of user_ids with active syncs.
active syncs.
now (int): Current time in ms. now (int): Current time in ms.
Returns: Returns:
@ -954,21 +1048,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
for state in user_states: for state in user_states:
is_mine = is_mine_fn(state.user_id) is_mine = is_mine_fn(state.user_id)
new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
if new_state: if new_state:
changes[state.user_id] = new_state changes[state.user_id] = new_state
return changes.values() return changes.values()
def handle_timeout(state, is_mine, user_to_num_current_syncs, now): def handle_timeout(state, is_mine, syncing_user_ids, now):
"""Checks the presence of the user to see if any of the timers have elapsed """Checks the presence of the user to see if any of the timers have elapsed
Args: Args:
state (UserPresenceState) state (UserPresenceState)
is_mine (bool): Whether the user is ours is_mine (bool): Whether the user is ours
user_to_num_current_syncs (dict): Mapping of user_id to number of currently syncing_user_ids (set): Set of user_ids with active syncs.
active syncs.
now (int): Current time in ms. now (int): Current time in ms.
Returns: Returns:
@ -1002,7 +1095,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
# If there are have been no sync for a while (and none ongoing), # If there are have been no sync for a while (and none ongoing),
# set presence to offline # set presence to offline
if not user_to_num_current_syncs.get(user_id, 0): if user_id not in syncing_user_ids:
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace( state = state.copy_and_replace(
state=PresenceState.OFFLINE, state=PresenceState.OFFLINE,

View File

@ -36,6 +36,8 @@ import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
id_server_scheme = "https://" id_server_scheme = "https://"
@ -344,8 +346,14 @@ class RoomListHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomListHandler, self).__init__(hs) super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache() self.response_cache = ResponseCache()
self.remote_list_request_cache = ResponseCache()
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
)
self.fetch_all_remote_lists()
def get_public_room_list(self): def get_local_public_room_list(self):
result = self.response_cache.get(()) result = self.response_cache.get(())
if not result: if not result:
result = self.response_cache.set((), self._get_public_room_list()) result = self.response_cache.set((), self._get_public_room_list())
@ -427,6 +435,55 @@ class RoomListHandler(BaseHandler):
# FIXME (erikj): START is no longer a valid value # FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": results}) defer.returnValue({"start": "START", "end": "END", "chunk": results})
@defer.inlineCallbacks
def fetch_all_remote_lists(self):
deferred = self.hs.get_replication_layer().get_public_rooms(
self.hs.config.secondary_directory_servers
)
self.remote_list_request_cache.set((), deferred)
yield deferred
@defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
specified in the secondary_directory_servers config option.
XXX: Pagination...
"""
# We return the results from out cache which is updated by a looping call,
# unless we're missing a cache entry, in which case wait for the result
# of the fetch if there's one in progress. If not, omit that server.
wait = False
for s in self.hs.config.secondary_directory_servers:
if s not in self.remote_list_cache:
logger.warn("No cached room list from %s: waiting for fetch", s)
wait = True
break
if wait and self.remote_list_request_cache.get(()):
yield self.remote_list_request_cache.get(())
public_rooms = yield self.get_local_public_room_list()
# keep track of which room IDs we've seen so we can de-dup
room_ids = set()
# tag all the ones in our list with our server name.
# Also add the them to the de-deping set
for room in public_rooms['chunk']:
room["server_name"] = self.hs.hostname
room_ids.add(room["room_id"])
# Now add the results from federation
for server_name, server_result in self.remote_list_cache.items():
for room in server_result["chunk"]:
if room["room_id"] not in room_ids:
room["server_name"] = server_name
public_rooms["chunk"].append(room)
room_ids.add(room["room_id"])
defer.returnValue(public_rooms)
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
def decode_rule_json(rule): def decode_rule_json(rule):
rule = dict(rule)
rule['conditions'] = json.loads(rule['conditions']) rule['conditions'] = json.loads(rule['conditions'])
rule['actions'] = json.loads(rule['actions']) rule['actions'] = json.loads(rule['actions'])
return rule return rule
@ -39,6 +40,8 @@ def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids) rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
rules_by_user = { rules_by_user = {
uid: list_with_base_rules([ uid: list_with_base_rules([
decode_rule_json(rule_list) decode_rule_json(rule_list)
@ -51,11 +54,10 @@ def _get_rules(room_id, user_ids, store):
# fetch disabled rules, but this won't account for any server default # fetch disabled rules, but this won't account for any server default
# rules the user has disabled, so we need to do this too. # rules the user has disabled, so we need to do this too.
for uid in user_ids: for uid in user_ids:
if uid not in rules_enabled_by_user: user_enabled_map = rules_enabled_by_user.get(uid)
if not user_enabled_map:
continue continue
user_enabled_map = rules_enabled_by_user[uid]
for i, rule in enumerate(rules_by_user[uid]): for i, rule in enumerate(rules_by_user[uid]):
rule_id = rule['rule_id'] rule_id = rule['rule_id']

View File

@ -0,0 +1,58 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes, request_handler
from synapse.http.servlet import parse_json_object_from_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
class PresenceResource(Resource):
"""
HTTP endpoint for marking users as syncing.
POST /_synapse/replication/presence HTTP/1.1
Content-Type: application/json
{
"process_id": "<process_id>",
"syncing_users": ["<user_id>"]
}
"""
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.presence_handler = hs.get_presence_handler()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler()
@defer.inlineCallbacks
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
process_id = content["process_id"]
syncing_user_ids = content["syncing_users"]
yield self.presence_handler.update_external_syncs(
process_id, set(syncing_user_ids)
)
respond_with_json_bytes(request, 200, "{}")

View File

@ -16,6 +16,7 @@
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -115,6 +116,7 @@ class ReplicationResource(Resource):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.putChild("remove_pushers", PusherResource(hs)) self.putChild("remove_pushers", PusherResource(hs))
self.putChild("syncing_users", PresenceResource(hs))
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)

View File

@ -280,7 +280,8 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
data = yield handler.get_public_room_list() data = yield handler.get_aggregated_public_room_list()
defer.returnValue((200, data)) defer.returnValue((200, data))

View File

@ -119,7 +119,8 @@ class EventPushActionsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range(self, user_id, def get_unread_push_actions_for_user_in_range(self, user_id,
min_stream_ordering, min_stream_ordering,
max_stream_ordering=None): max_stream_ordering=None,
limit=20):
def get_after_receipt(txn): def get_after_receipt(txn):
sql = ( sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, " "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
@ -151,7 +152,8 @@ class EventPushActionsStore(SQLBaseStore):
if max_stream_ordering is not None: if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?" sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering) args.append(max_stream_ordering)
sql += " ORDER BY ep.stream_ordering ASC" sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
args.append(limit)
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = yield self.runInteraction( after_read_receipt = yield self.runInteraction(

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks() @cachedInlineCallbacks(lru=True)
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="push_rules", table="push_rules",
@ -44,7 +44,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cachedInlineCallbacks() @cachedInlineCallbacks(lru=True)
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table="push_rules_enable", table="push_rules_enable",
@ -60,12 +60,16 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results r['rule_id']: False if r['enabled'] == 0 else True for r in results
}) })
@defer.inlineCallbacks @cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids): def bulk_get_push_rules(self, user_ids):
if not user_ids: if not user_ids:
defer.returnValue({}) defer.returnValue({})
results = {} results = {
user_id: []
for user_id in user_ids
}
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table="push_rules", table="push_rules",
@ -75,18 +79,24 @@ class PushRuleStore(SQLBaseStore):
desc="bulk_get_push_rules", desc="bulk_get_push_rules",
) )
rows.sort(key=lambda e: (-e["priority_class"], -e["priority"])) rows.sort(
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
for row in rows: for row in rows:
results.setdefault(row['user_name'], []).append(row) results.setdefault(row['user_name'], []).append(row)
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks @cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules_enabled(self, user_ids): def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids: if not user_ids:
defer.returnValue({}) defer.returnValue({})
results = {} results = {
user_id: {}
for user_id in user_ids
}
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table="push_rules_enable", table="push_rules_enable",

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from synapse.storage.appservice import ApplicationServiceStore from synapse.config.appservice import load_appservices
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,7 +38,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
logger.warning("Could not get app_service_config_files from config") logger.warning("Could not get app_service_config_files from config")
pass pass
appservices = ApplicationServiceStore.load_appservices( appservices = load_appservices(
config.server_name, config_files config.server_name, config_files
) )

View File

@ -264,7 +264,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
new_state = handle_timeout( new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now state, is_mine=True, syncing_user_ids=set(), now=now
) )
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
@ -282,7 +282,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
new_state = handle_timeout( new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now state, is_mine=True, syncing_user_ids=set(), now=now
) )
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
@ -300,9 +300,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
new_state = handle_timeout( new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={ state, is_mine=True, syncing_user_ids=set([user_id]), now=now
user_id: 1,
}, now=now
) )
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
@ -321,7 +319,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
new_state = handle_timeout( new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now state, is_mine=True, syncing_user_ids=set(), now=now
) )
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
@ -340,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
new_state = handle_timeout( new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now state, is_mine=True, syncing_user_ids=set(), now=now
) )
self.assertIsNone(new_state) self.assertIsNone(new_state)
@ -358,7 +356,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
new_state = handle_timeout( new_state = handle_timeout(
state, is_mine=False, user_to_num_current_syncs={}, now=now state, is_mine=False, syncing_user_ids=set(), now=now
) )
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
@ -377,7 +375,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
) )
new_state = handle_timeout( new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now state, is_mine=True, syncing_user_ids=set(), now=now
) )
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)

View File

@ -67,6 +67,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine(config.database_config), database_engine=create_engine(config.database_config),
get_db_conn=db_pool.get_db_conn, get_db_conn=db_pool.get_db_conn,
room_list_handler=object(),
**kargs **kargs
) )
hs.setup() hs.setup()
@ -75,6 +76,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
name, db_pool=None, datastore=datastore, config=config, name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine(config.database_config), database_engine=create_engine(config.database_config),
room_list_handler=object(),
**kargs **kargs
) )