Merge branch 'develop' of github.com:matrix-org/synapse into erikj/timings

erikj/timings
Erik Johnston 2016-06-03 13:25:36 +01:00
commit d6a9d4e562
36 changed files with 677 additions and 119 deletions

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains classes for authenticating the user."""
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
@ -42,13 +41,20 @@ AuthEventTypes = (
class Auth(object):
"""
FIXME: This class contains a mix of functions for authenticating users
of our client-server API and authenticating events added to room graphs.
"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"guest = ",
@ -525,7 +531,7 @@ class Auth(object):
return default
@defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False):
def get_user_by_req(self, request, allow_guest=False, rights="access"):
""" Get a registered user's ID.
Args:
@ -547,7 +553,7 @@ class Auth(object):
)
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token)
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
@ -608,7 +614,7 @@ class Auth(object):
defer.returnValue(user_id)
@defer.inlineCallbacks
def get_user_by_access_token(self, token):
def get_user_by_access_token(self, token, rights="access"):
""" Get a registered user's ID.
Args:
@ -619,7 +625,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
ret = yield self.get_user_from_macaroon(token)
ret = yield self.get_user_from_macaroon(token, rights)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
@ -627,11 +633,11 @@ class Auth(object):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_user_from_macaroon(self, macaroon_str):
def get_user_from_macaroon(self, macaroon_str, rights="access"):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)
user_prefix = "user_id = "
user = None
@ -654,6 +660,13 @@ class Auth(object):
"is_guest": True,
"token_id": None,
}
elif rights == "delete_pusher":
# We don't store these tokens in the database
ret = {
"user": user,
"is_guest": False,
"token_id": None,
}
else:
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
@ -685,7 +698,8 @@ class Auth(object):
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token this is (e.g. "access", "refresh")
type_string(str): The kind of token required (e.g. "access", "refresh",
"delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.

View File

@ -21,6 +21,7 @@ from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig
from synapse.config.logger import LoggingConfig
from synapse.config.emailconfig import EmailConfig
from synapse.config.key import KeyConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.storage.roommember import RoomMemberStore
@ -63,6 +64,26 @@ class SlaveConfig(DatabaseConfig):
self.pid_file = self.abspath(config.get("pid_file"))
self.public_baseurl = config["public_baseurl"]
# some things used by the auth handler but not actually used in the
# pusher codebase
self.bcrypt_rounds = None
self.ldap_enabled = None
self.ldap_server = None
self.ldap_port = None
self.ldap_tls = None
self.ldap_search_base = None
self.ldap_search_property = None
self.ldap_email_property = None
self.ldap_full_name_property = None
# We would otherwise try to use the registration shared secret as the
# macaroon shared secret if there was no macaroon_shared_secret, but
# that means pulling in RegistrationConfig too. We don't need to be
# backwards compaitible in the pusher codebase so just make people set
# macaroon_shared_secret. We set this to None to prevent it referencing
# an undefined key.
self.registration_shared_secret = None
def default_config(self, server_name, **kwargs):
pid_file = self.abspath("pusher.pid")
return """\
@ -95,7 +116,7 @@ class SlaveConfig(DatabaseConfig):
""" % locals()
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig):
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
pass

View File

@ -89,7 +89,7 @@ class EmailConfig(Config):
# enable_notifs: false
# smtp_host: "localhost"
# smtp_port: 25
# notif_from: Your Friendly Matrix Home Server <noreply@example.com>
# notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
# app_name: Matrix
# template_dir: res/templates
# notif_template_html: notif_mail.html

View File

@ -24,7 +24,6 @@ from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
from .admin import AdminHandler
from .auth import AuthHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
from .search import SearchHandler
@ -50,7 +49,6 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs)

View File

@ -529,6 +529,11 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)

View File

@ -68,6 +68,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
# How often to resend presence to remote servers
FEDERATION_PING_INTERVAL = 25 * 60 * 1000
# How long we will wait before assuming that the syncs from an external process
# are dead.
EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
@ -158,10 +162,21 @@ class PresenceHandler(object):
self.serial_to_user = {}
self._next_serial = 1
# Keeps track of the number of *ongoing* syncs. While this is non zero
# a user will never go offline.
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
self.user_to_num_current_syncs = {}
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
# go offline.
# Each process has a unique identifier and an update frequency. If
# no update is received from that process within the update period then
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
self.external_process_to_current_syncs = {}
self.external_process_last_updated_ms = {}
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
@ -272,13 +287,26 @@ class PresenceHandler(object):
# Fetch the list of users that *may* have timed out. Things may have
# changed since the timeout was set, so we won't necessarily have to
# take any action.
users_to_check = self.wheel_timer.fetch(now)
users_to_check = set(self.wheel_timer.fetch(now))
# Check whether the lists of syncing processes from an external
# process have expired.
expired_process_ids = [
process_id for process_id, last_update
in self.external_process_last_update.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
users_to_check.update(
self.external_process_to_current_syncs.pop(process_id, ())
)
self.external_process_last_update.pop(process_id)
states = [
self.user_to_current_state.get(
user_id, UserPresenceState.default(user_id)
)
for user_id in set(users_to_check)
for user_id in users_to_check
]
timers_fired_counter.inc_by(len(states))
@ -286,7 +314,7 @@ class PresenceHandler(object):
changes = handle_timeouts(
states,
is_mine_fn=self.is_mine_id,
user_to_num_current_syncs=self.user_to_num_current_syncs,
syncing_users=self.get_syncing_users(),
now=now,
)
@ -363,6 +391,73 @@ class PresenceHandler(object):
defer.returnValue(_user_syncing())
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
Returns:
set(str): A set of user_id strings.
"""
syncing_user_ids = {
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count
}
syncing_user_ids.update(self.external_process_to_current_syncs.values())
return syncing_user_ids
@defer.inlineCallbacks
def update_external_syncs(self, process_id, syncing_user_ids):
"""Update the syncing users for an external process
Args:
process_id(str): An identifier for the process the users are
syncing against. This allows synapse to process updates
as user start and stop syncing against a given process.
syncing_user_ids(set(str)): The set of user_ids that are
currently syncing on that server.
"""
# Grab the previous list of user_ids that were syncing on that process
prev_syncing_user_ids = (
self.external_process_to_current_syncs.get(process_id, set())
)
# Grab the current presence state for both the users that are syncing
# now and the users that were syncing before this update.
prev_states = yield self.current_state_for_users(
syncing_user_ids | prev_syncing_user_ids
)
updates = []
time_now_ms = self.clock.time_msec()
# For each new user that is syncing check if we need to mark them as
# being online.
for new_user_id in syncing_user_ids - prev_syncing_user_ids:
prev_state = prev_states[new_user_id]
if prev_state.state == PresenceState.OFFLINE:
updates.append(prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=time_now_ms,
last_user_sync_ts=time_now_ms,
))
else:
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
# For each user that is still syncing or stopped syncing update the
# last sync time so that we will correctly apply the grace period when
# they stop syncing.
for old_user_id in prev_syncing_user_ids:
prev_state = prev_states[old_user_id]
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
yield self._update_states(updates)
# Update the last updated time for the process. We expire the entries
# if we don't receive an update in the given timeframe.
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
self.external_process_to_current_syncs[process_id] = syncing_user_ids
@defer.inlineCallbacks
def current_state_for_user(self, user_id):
"""Get the current presence state for a user.
@ -935,15 +1030,14 @@ class PresenceEventSource(object):
return self.get_new_events(user, from_key=None, include_offline=False)
def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as
appropriate.
Args:
user_states(list): List of UserPresenceState's to check.
is_mine_fn (fn): Function that returns if a user_id is ours
user_to_num_current_syncs (dict): Mapping of user_id to number of currently
active syncs.
syncing_user_ids (set): Set of user_ids with active syncs.
now (int): Current time in ms.
Returns:
@ -954,21 +1048,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
for state in user_states:
is_mine = is_mine_fn(state.user_id)
new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now)
new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
if new_state:
changes[state.user_id] = new_state
return changes.values()
def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
def handle_timeout(state, is_mine, syncing_user_ids, now):
"""Checks the presence of the user to see if any of the timers have elapsed
Args:
state (UserPresenceState)
is_mine (bool): Whether the user is ours
user_to_num_current_syncs (dict): Mapping of user_id to number of currently
active syncs.
syncing_user_ids (set): Set of user_ids with active syncs.
now (int): Current time in ms.
Returns:
@ -1002,7 +1095,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
# If there are have been no sync for a while (and none ongoing),
# set presence to offline
if not user_to_num_current_syncs.get(user_id, 0):
if user_id not in syncing_user_ids:
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace(
state=PresenceState.OFFLINE,

View File

@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token))
def auth_handler(self):
return self.hs.get_handlers().auth_handler
return self.hs.get_auth_handler()
@defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id):

View File

@ -636,6 +636,9 @@ class SyncHandler(object):
)
presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user
presence = {p["content"]["user_id"]: p for p in presence}.values()
presence = sync_config.filter_collection.filter_presence(
presence
)

View File

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
# A tiny object useful for storing a user's membership in a room, as a mapping
# key
RoomMember = namedtuple("RoomMember", ("room_id", "user"))
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
class TypingHandler(object):
@ -38,7 +38,7 @@ class TypingHandler(object):
self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
@ -67,20 +67,23 @@ class TypingHandler(object):
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
if not self.is_mine(target_user):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user.to_string())
yield self.auth.check_joined_room(room_id, target_user_id)
logger.debug(
"%s has started typing in %s", target_user.to_string(), room_id
"%s has started typing in %s", target_user_id, room_id
)
until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user=target_user)
member = RoomMember(room_id=room_id, user_id=target_user_id)
was_present = member in self._member_typing_until
@ -104,25 +107,28 @@ class TypingHandler(object):
yield self._push_update(
room_id=room_id,
user=target_user,
user_id=target_user_id,
typing=True,
)
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
if not self.is_mine(target_user):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user.to_string())
yield self.auth.check_joined_room(room_id, target_user_id)
logger.debug(
"%s has stopped typing in %s", target_user.to_string(), room_id
"%s has stopped typing in %s", target_user_id, room_id
)
member = RoomMember(room_id=room_id, user=target_user)
member = RoomMember(room_id=room_id, user_id=target_user_id)
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
@ -132,8 +138,9 @@ class TypingHandler(object):
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
if self.is_mine(user):
member = RoomMember(room_id=room_id, user=user)
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user=user_id)
yield self._stopped_typing(member)
@defer.inlineCallbacks
@ -144,7 +151,7 @@ class TypingHandler(object):
yield self._push_update(
room_id=member.room_id,
user=member.user,
user_id=member.user_id,
typing=False,
)
@ -156,7 +163,7 @@ class TypingHandler(object):
del self._member_typing_timer[member]
@defer.inlineCallbacks
def _push_update(self, room_id, user, typing):
def _push_update(self, room_id, user_id, typing):
domains = yield self.store.get_joined_hosts_for_room(room_id)
deferreds = []
@ -164,7 +171,7 @@ class TypingHandler(object):
if domain == self.server_name:
self._push_update_local(
room_id=room_id,
user=user,
user_id=user_id,
typing=typing
)
else:
@ -173,7 +180,7 @@ class TypingHandler(object):
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user.to_string(),
"user_id": user_id,
"typing": typing,
},
))
@ -183,23 +190,26 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
room_id = content["room_id"]
user = UserID.from_string(content["user_id"])
user_id = content["user_id"]
# Check that the string is a valid user id
UserID.from_string(user_id)
domains = yield self.store.get_joined_hosts_for_room(room_id)
if self.server_name in domains:
self._push_update_local(
room_id=room_id,
user=user,
user_id=user_id,
typing=content["typing"]
)
def _push_update_local(self, room_id, user, typing):
def _push_update_local(self, room_id, user_id, typing):
room_set = self._room_typing.setdefault(room_id, set())
if typing:
room_set.add(user)
room_set.add(user_id)
else:
room_set.discard(user)
room_set.discard(user_id)
self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial
@ -215,9 +225,7 @@ class TypingHandler(object):
for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id:
typing = self._room_typing[room_id]
typing_bytes = json.dumps([
u.to_string() for u in typing
], ensure_ascii=False)
typing_bytes = json.dumps(list(typing), ensure_ascii=False)
rows.append((serial, room_id, typing_bytes))
rows.sort()
return rows
@ -239,7 +247,7 @@ class TypingNotificationEventSource(object):
"type": "m.typing",
"room_id": room_id,
"content": {
"user_ids": [u.to_string() for u in typing],
"user_ids": list(typing),
},
}

View File

@ -126,12 +126,6 @@ class DistributionMetric(object):
class CacheMetric(object):
"""A combination of two CounterMetrics, one to count cache hits and one to
count a total, and a callback metric to yield the current size.
This metric generates standard metric name pairs, so that monitoring rules
can easily be applied to measure hit ratio."""
__slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
def __init__(self, name, size_callback, cache_name):

View File

@ -279,5 +279,5 @@ class EmailPusher(object):
logger.info("Sending notif email for user %r", self.user_id)
yield self.mailer.send_notification_mail(
self.user_id, self.email, push_actions, reason
self.app_id, self.user_id, self.email, push_actions, reason
)

View File

@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
"in the %s room..."
"in the %(room)s room..."
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
@ -81,9 +81,11 @@ class Mailer(object):
def __init__(self, hs, app_name):
self.hs = hs
self.store = self.hs.get_datastore()
self.auth_handler = self.hs.get_auth_handler()
self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name)
env = jinja2.Environment(loader=loader)
env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = self.mxc_to_http_filter
@ -95,8 +97,16 @@ class Mailer(object):
)
@defer.inlineCallbacks
def send_notification_mail(self, user_id, email_address, push_actions, reason):
raw_from = email.utils.parseaddr(self.hs.config.email_notif_from)[1]
def send_notification_mail(self, app_id, user_id, email_address,
push_actions, reason):
try:
from_string = self.hs.config.email_notif_from % {
"app": self.app_name
}
except TypeError:
from_string = self.hs.config.email_notif_from
raw_from = email.utils.parseaddr(from_string)[1]
raw_to = email.utils.parseaddr(email_address)[1]
if raw_to == '':
@ -159,7 +169,9 @@ class Mailer(object):
template_vars = {
"user_display_name": user_display_name,
"unsubscribe_link": self.make_unsubscribe_link(),
"unsubscribe_link": self.make_unsubscribe_link(
user_id, app_id, email_address
),
"summary_text": summary_text,
"app_name": self.app_name,
"rooms": rooms,
@ -425,9 +437,18 @@ class Mailer(object):
notif['room_id'], notif['event_id']
)
def make_unsubscribe_link(self):
# XXX: matrix.to
return "https://vector.im/#/settings"
def make_unsubscribe_link(self, user_id, app_id, email_address):
params = {
"access_token": self.auth_handler.generate_delete_pusher_token(user_id),
"app_id": app_id,
"pushkey": email_address,
}
# XXX: make r0 once API is stable
return "%s_matrix/client/unstable/pushers/remove?%s" % (
self.hs.config.public_baseurl,
urllib.urlencode(params),
)
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":

View File

@ -0,0 +1,59 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes, request_handler
from synapse.http.servlet import parse_json_object_from_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
class PresenceResource(Resource):
"""
HTTP endpoint for marking users as syncing.
POST /_synapse/replication/presence HTTP/1.1
Content-Type: application/json
{
"process_id": "<process_id>",
"syncing_users": ["<user_id>"]
}
"""
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.clock = hs.get_clock()
self.presence_handler = hs.get_presence_handler()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler()
@defer.inlineCallbacks
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
process_id = content["process_id"]
syncing_user_ids = content["syncing_users"]
yield self.presence_handler.update_external_syncs(
process_id, set(syncing_user_ids)
)
respond_with_json_bytes(request, 200, "{}")

View File

@ -16,6 +16,7 @@
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@ -115,6 +116,7 @@ class ReplicationResource(Resource):
self.clock = hs.get_clock()
self.putChild("remove_pushers", PusherResource(hs))
self.putChild("syncing_users", PresenceResource(hs))
def render_GET(self, request):
self._async_render_GET(request)

View File

@ -15,7 +15,10 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.account_data import AccountDataStore
from synapse.storage.tags import TagsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedAccountDataStore(BaseSlavedStore):
@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore):
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_current_token(),
)
get_account_data_for_user = (
AccountDataStore.__dict__["get_account_data_for_user"]
)
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore):
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore):
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
user_id, data_type = row[1:3]
position, user_id, data_type = row[:3]
self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,)
)
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("room_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("tag_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_tags_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedAccountDataStore, self).process_replication(result)

View File

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.config.appservice import load_appservices
class SlavedApplicationServiceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
self.services_cache = load_appservices(
hs.config.server_name,
hs.config.app_service_config_files
)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__

View File

@ -23,6 +23,7 @@ from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
@ -57,6 +58,9 @@ class SlavedEventStore(BaseSlavedStore):
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
@ -87,6 +91,9 @@ class SlavedEventStore(BaseSlavedStore):
_get_state_group_from_group = (
StateStore.__dict__["_get_state_group_from_group"]
)
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
get_unread_push_actions_for_user_in_range = (
DataStore.get_unread_push_actions_for_user_in_range.__func__
@ -109,10 +116,16 @@ class SlavedEventStore(BaseSlavedStore):
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = DataStore._set_before_and_after
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
@ -220,9 +233,9 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
# self._membership_stream_cache.entity_has_changed(
# event.state_key, event.internal_metadata.stream_ordering
# )
self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering
)
self.get_invited_rooms_for_user.invalidate((event.state_key,))
if not event.is_state():

View File

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.filtering import FilteringStore
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedFilteringStore, self).__init__(db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
get_user_filter = FilteringStore.__dict__["get_user_filter"]

View File

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore
class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPresenceStore, self).__init__(db_conn, hs)
self._presence_id_gen = SlavedIdTracker(
db_conn, "presence_stream", "stream_id",
)
self._presence_on_startup = self._get_active_presence(db_conn)
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
_get_active_presence = DataStore._get_active_presence.__func__
take_presence_startup_info = DataStore.take_presence_startup_info.__func__
get_presence_for_users = DataStore.get_presence_for_users.__func__
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication(self, result):
stream = result.get("presence")
if stream:
self._presence_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.presence_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedPresenceStore, self).process_replication(result)

View File

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedPushRuleStore(SlavedEventStore):
def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("push_rules")
if stream:
for row in stream["rows"]:
position = row[0]
user_id = row[2]
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
self.push_rules_stream_cache.entity_has_changed(
user_id, position
)
self._push_rules_stream_id_gen.advance(int(stream["position"]))
return super(SlavedPushRuleStore, self).process_replication(result)

View File

@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.receipts import ReceiptsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore):
db_conn, "receipts_linearized", "stream_id"
)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
get_linearized_receipts_for_room = (
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
)
_get_linearized_receipts_for_rooms = (
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
)
get_last_receipt_event_id_for_user = (
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
)
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
get_linearized_receipts_for_rooms = (
DataStore.get_linearized_receipts_for_rooms.__func__
)
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore):
if stream:
self._receipts_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
room_id, receipt_type, user_id = row[1:4]
position, room_id, receipt_type, user_id = row[:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
self._receipts_stream_cache.entity_has_changed(room_id, position)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)

View File

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.registration import RegistrationStore
class SlavedRegistrationStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
].orig
_query_for_auth = DataStore._query_for_auth.__func__

View File

@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
def on_GET(self, request):
flows = []
@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_id, self.hs.hostname
).to_string()
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id,
password=login_submission["password"])
@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (

View File

@ -17,7 +17,11 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import (
parse_json_object_from_request, parse_string, RestServlet
)
from synapse.http.server import finish_request
from synapse.api.errors import StoreError
from .base import ClientV1RestServlet, client_path_patterns
@ -136,6 +140,57 @@ class PushersSetRestServlet(ClientV1RestServlet):
return 200, {}
class PushersRemoveRestServlet(RestServlet):
"""
To allow pusher to be delete by clicking a link (ie. GET request)
"""
PATTERNS = client_path_patterns("/pushers/remove$")
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
super(RestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_v1auth()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
pusher_pool = self.hs.get_pusherpool()
try:
yield pusher_pool.remove_pusher(
app_id=app_id,
pushkey=pushkey,
user_id=user.to_string(),
)
except StoreError as se:
if se.code != 404:
# This is fine: they're already unsubscribed
raise
self.notifier.on_new_replication_data()
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (
len(PushersRemoveRestServlet.SUCCESS_HTML),
))
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
finish_request(request)
defer.returnValue(None)
def on_OPTIONS(self, _):
return 200, {}
def register_servlets(hs, http_server):
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)

View File

@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_POST(self, request):
@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_GET(self, request):

View File

@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
super(AuthRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks

View File

@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler

View File

@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
body = parse_json_object_from_request(request)
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler
auth_handler = self.hs.get_auth_handler()
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id)

View File

@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.state import StateHandler
from synapse.storage import DataStore
@ -89,6 +90,7 @@ class HomeServer(object):
'sync_handler',
'typing_handler',
'room_list_handler',
'auth_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@ -190,6 +192,9 @@ class HomeServer(object):
def build_room_list_handler(self):
return RoomListHandler(self)
def build_auth_handler(self):
return AuthHandler(self)
def build_application_service_api(self):
return ApplicationServiceApi(self)

View File

@ -149,7 +149,7 @@ class DataStore(RoomMemberStore, RoomStore,
"AccountDataAndTagsChangeCache", account_max,
)
self.__presence_on_startup = self._get_active_presence(db_conn)
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
@ -190,8 +190,8 @@ class DataStore(RoomMemberStore, RoomStore,
super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
active_on_startup = self.__presence_on_startup
self.__presence_on_startup = None
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
return active_on_startup
def _get_active_presence(self, db_conn):

View File

@ -152,6 +152,7 @@ class SQLBaseStore(object):
def __init__(self, hs):
self.hs = hs
self._clock = hs.get_clock()
self._db_pool = hs.get_db_pool()
self._previous_txn_total_time = 0

View File

@ -32,7 +32,7 @@ import os
import functools
import inspect
import threading
import itertools
logger = logging.getLogger(__name__)
@ -367,17 +367,16 @@ class CacheListDescriptor(object):
cached_defers[arg] = res
if cached_defers:
def update_results_dict(res):
results.update(res)
return results
return preserve_context_over_deferred(defer.gatherResults(
cached_defers.values(),
consumeErrors=True,
).addCallback(
lambda res: {
k: v
for k, v in itertools.chain(results.items(), res)
}
)).addErrback(
).addCallback(update_results_dict).addErrback(
unwrapFirstError
)
))
else:
return results

View File

@ -264,7 +264,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
@ -282,7 +282,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
@ -300,9 +300,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={
user_id: 1,
}, now=now
state, is_mine=True, syncing_user_ids=set([user_id]), now=now
)
self.assertIsNotNone(new_state)
@ -321,7 +319,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
@ -340,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNone(new_state)
@ -358,7 +356,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
state, is_mine=False, user_to_num_current_syncs={}, now=now
state, is_mine=False, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
@ -377,7 +375,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
state, is_mine=True, user_to_num_current_syncs={}, now=now
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)

View File

@ -251,12 +251,12 @@ class TypingNotificationsTestCase(unittest.TestCase):
# Gut-wrenching
from synapse.handlers.typing import RoomMember
member = RoomMember(self.room_id, self.u_apple)
member = RoomMember(self.room_id, self.u_apple.to_string())
self.handler._member_typing_until[member] = 1002000
self.handler._member_typing_timer[member] = (
self.clock.call_later(1002, lambda: 0)
)
self.handler._room_typing[self.room_id] = set((self.u_apple,))
self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),))
self.assertEquals(self.event_source.get_current_key(), 0)

View File

@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
# do the dance to hook it up to the hs global
self.handlers = Mock(
auth_handler=self.auth_handler,
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.config.enable_registration = True
# init the thing we're testing

View File

@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
)
# bcrypt is far too slow to be doing in unit tests
def swap_out_hash_for_testing(old_build_handlers):
def build_handlers():
handlers = old_build_handlers()
auth_handler = handlers.auth_handler
auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
return handlers
return build_handlers
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
fed = kargs.get("resource_for_federation", None)
if fed: