Merge branch 'develop' of github.com:matrix-org/synapse into erikj/timings
commit d6a9d4e562
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 """This module contains classes for authenticating the user."""
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json, SignatureVerifyException

@@ -42,13 +41,20 @@ AuthEventTypes = (


 class Auth(object):
     """
     FIXME: This class contains a mix of functions for authenticating users
     of our client-server API and authenticating events added to room graphs.
     """
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
         # Docs for these currently live at
         # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
+        # In addition, we have type == delete_pusher which grants access only to
+        # delete pushers.
         self._KNOWN_CAVEAT_PREFIXES = set([
             "gen = ",
             "guest = ",

@@ -525,7 +531,7 @@ class Auth(object):
         return default

     @defer.inlineCallbacks
-    def get_user_by_req(self, request, allow_guest=False):
+    def get_user_by_req(self, request, allow_guest=False, rights="access"):
         """ Get a registered user's ID.

         Args:

@@ -547,7 +553,7 @@
             )

         access_token = request.args["access_token"][0]
-        user_info = yield self.get_user_by_access_token(access_token)
+        user_info = yield self.get_user_by_access_token(access_token, rights)
         user = user_info["user"]
         token_id = user_info["token_id"]
         is_guest = user_info["is_guest"]

@@ -608,7 +614,7 @@
         defer.returnValue(user_id)

     @defer.inlineCallbacks
-    def get_user_by_access_token(self, token):
+    def get_user_by_access_token(self, token, rights="access"):
         """ Get a registered user's ID.

         Args:

@@ -619,7 +625,7 @@
             AuthError if no user by that token exists or the token is invalid.
         """
         try:
-            ret = yield self.get_user_from_macaroon(token)
+            ret = yield self.get_user_from_macaroon(token, rights)
         except AuthError:
             # TODO(daniel): Remove this fallback when all existing access tokens
             # have been re-issued as macaroons.

@@ -627,11 +633,11 @@
         defer.returnValue(ret)

     @defer.inlineCallbacks
-    def get_user_from_macaroon(self, macaroon_str):
+    def get_user_from_macaroon(self, macaroon_str, rights="access"):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)

-            self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
+            self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)

             user_prefix = "user_id = "
             user = None

@@ -654,6 +660,13 @@
                     "is_guest": True,
                     "token_id": None,
                 }
+            elif rights == "delete_pusher":
+                # We don't store these tokens in the database
+                ret = {
+                    "user": user,
+                    "is_guest": False,
+                    "token_id": None,
+                }
             else:
                 # This codepath exists so that we can actually return a
                 # token ID, because we use token IDs in place of device

@@ -685,7 +698,8 @@

         Args:
             macaroon(pymacaroons.Macaroon): The macaroon to validate
-            type_string(str): The kind of token this is (e.g. "access", "refresh")
+            type_string(str): The kind of token required (e.g. "access", "refresh",
+                "delete_pusher")
             verify_expiry(bool): Whether to verify whether the macaroon has expired.
                 This should really always be True, but no clients currently implement
                 token refresh, so we can't enforce expiry yet.
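As context for the new "rights" plumbing: access tokens here are macaroons whose first-party caveats encode what the token may do. A minimal, self-contained sketch of minting and checking a delete_pusher-scoped token (the key and user ID are made up; Synapse's own checks live in Auth.validate_macaroon above):

import pymacaroons

key = "macaroon_secret_key"  # stand-in for hs.config.macaroon_secret_key

macaroon = pymacaroons.Macaroon(location="example.com", identifier="key", key=key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = @alice:example.com")
macaroon.add_first_party_caveat("type = delete_pusher")
token = macaroon.serialize()

# Checking mirrors the rights argument: the verifier accepts the token only
# for the one purpose named in its "type" caveat.
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("user_id = @alice:example.com")
v.satisfy_exact("type = delete_pusher")
v.satisfy_general(lambda caveat: caveat.startswith("time < "))
assert v.verify(pymacaroons.Macaroon.deserialize(token), key)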
@@ -21,6 +21,7 @@ from synapse.config._base import ConfigError
 from synapse.config.database import DatabaseConfig
 from synapse.config.logger import LoggingConfig
 from synapse.config.emailconfig import EmailConfig
+from synapse.config.key import KeyConfig
 from synapse.http.site import SynapseSite
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.storage.roommember import RoomMemberStore

@@ -63,6 +64,26 @@ class SlaveConfig(DatabaseConfig):
         self.pid_file = self.abspath(config.get("pid_file"))
         self.public_baseurl = config["public_baseurl"]

+        # some things used by the auth handler but not actually used in the
+        # pusher codebase
+        self.bcrypt_rounds = None
+        self.ldap_enabled = None
+        self.ldap_server = None
+        self.ldap_port = None
+        self.ldap_tls = None
+        self.ldap_search_base = None
+        self.ldap_search_property = None
+        self.ldap_email_property = None
+        self.ldap_full_name_property = None
+
+        # We would otherwise try to use the registration shared secret as the
+        # macaroon shared secret if there was no macaroon_shared_secret, but
+        # that means pulling in RegistrationConfig too. We don't need to be
+        # backwards compatible in the pusher codebase so just make people set
+        # macaroon_shared_secret. We set this to None to prevent it referencing
+        # an undefined key.
+        self.registration_shared_secret = None
+
     def default_config(self, server_name, **kwargs):
         pid_file = self.abspath("pusher.pid")
         return """\

@@ -95,7 +116,7 @@ class SlaveConfig(DatabaseConfig):
     """ % locals()


-class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig):
+class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
     pass
@@ -89,7 +89,7 @@ class EmailConfig(Config):
         # enable_notifs: false
         # smtp_host: "localhost"
         # smtp_port: 25
-        # notif_from: Your Friendly Matrix Home Server <noreply@example.com>
+        # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
         # app_name: Matrix
         # template_dir: res/templates
         # notif_template_html: notif_mail.html
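The %(app)s placeholder added above is expanded with old-style dict interpolation when the mailer builds the From header (see the mailer hunks further down); for illustration:

notif_from = "Your Friendly %(app)s Home Server <noreply@example.com>"
print(notif_from % {"app": "Matrix"})
# Your Friendly Matrix Home Server <noreply@example.com>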
@@ -24,7 +24,6 @@ from .federation import FederationHandler
 from .profile import ProfileHandler
 from .directory import DirectoryHandler
 from .admin import AdminHandler
-from .auth import AuthHandler
 from .identity import IdentityHandler
 from .receipts import ReceiptsHandler
 from .search import SearchHandler

@@ -50,7 +49,6 @@ class Handlers(object):
         self.directory_handler = DirectoryHandler(hs)
         self.admin_handler = AdminHandler(hs)
         self.receipts_handler = ReceiptsHandler(hs)
-        self.auth_handler = AuthHandler(hs)
         self.identity_handler = IdentityHandler(hs)
         self.search_handler = SearchHandler(hs)
         self.room_context_handler = RoomContextHandler(hs)
@@ -529,6 +529,11 @@ class AuthHandler(BaseHandler):
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         return macaroon.serialize()

+    def generate_delete_pusher_token(self, user_id):
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = delete_pusher")
+        return macaroon.serialize()
+
     def validate_short_term_login_token_and_get_user_id(self, login_token):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(login_token)
@@ -68,6 +68,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
 # How often to resend presence to remote servers
 FEDERATION_PING_INTERVAL = 25 * 60 * 1000

+# How long we will wait before assuming that the syncs from an external process
+# are dead.
+EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
+
 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER

@@ -158,10 +162,21 @@ class PresenceHandler(object):
         self.serial_to_user = {}
         self._next_serial = 1

-        # Keeps track of the number of *ongoing* syncs. While this is non zero
-        # a user will never go offline.
+        # Keeps track of the number of *ongoing* syncs on this process. While
+        # this is non zero a user will never go offline.
         self.user_to_num_current_syncs = {}

+        # Keeps track of the number of *ongoing* syncs on other processes.
+        # While any sync is ongoing on another process the user will never
+        # go offline.
+        # Each process has a unique identifier and an update frequency. If
+        # no update is received from that process within the update period then
+        # we assume that all the sync requests on that process have stopped.
+        # Stored as a dict from process_id to set of user_id, and a dict of
+        # process_id to millisecond timestamp last updated.
+        self.external_process_to_current_syncs = {}
+        self.external_process_last_updated_ms = {}
+
         # Start a LoopingCall in 30s that fires every 5s.
         # The initial delay is to allow disconnected clients a chance to
         # reconnect before we treat them as offline.

@@ -272,13 +287,26 @@ class PresenceHandler(object):
         # Fetch the list of users that *may* have timed out. Things may have
         # changed since the timeout was set, so we won't necessarily have to
         # take any action.
-        users_to_check = self.wheel_timer.fetch(now)
+        users_to_check = set(self.wheel_timer.fetch(now))
+
+        # Check whether the lists of syncing processes from an external
+        # process have expired.
+        expired_process_ids = [
+            process_id for process_id, last_update
+            in self.external_process_last_updated_ms.items()
+            if now - last_update > EXTERNAL_PROCESS_EXPIRY
+        ]
+        for process_id in expired_process_ids:
+            users_to_check.update(
+                self.external_process_to_current_syncs.pop(process_id, ())
+            )
+            self.external_process_last_updated_ms.pop(process_id)

         states = [
             self.user_to_current_state.get(
                 user_id, UserPresenceState.default(user_id)
             )
-            for user_id in set(users_to_check)
+            for user_id in users_to_check
         ]

         timers_fired_counter.inc_by(len(states))

@@ -286,7 +314,7 @@ class PresenceHandler(object):
         changes = handle_timeouts(
             states,
             is_mine_fn=self.is_mine_id,
-            user_to_num_current_syncs=self.user_to_num_current_syncs,
+            syncing_user_ids=self.get_currently_syncing_users(),
             now=now,
         )

@@ -363,6 +391,73 @@ class PresenceHandler(object):

         defer.returnValue(_user_syncing())

+    def get_currently_syncing_users(self):
+        """Get the set of user ids that are currently syncing on this HS.
+        Returns:
+            set(str): A set of user_id strings.
+        """
+        syncing_user_ids = {
+            user_id for user_id, count in self.user_to_num_current_syncs.items()
+            if count
+        }
+        for user_ids in self.external_process_to_current_syncs.values():
+            syncing_user_ids.update(user_ids)
+        return syncing_user_ids
+
+    @defer.inlineCallbacks
+    def update_external_syncs(self, process_id, syncing_user_ids):
+        """Update the syncing users for an external process
+
+        Args:
+            process_id(str): An identifier for the process the users are
+                syncing against. This allows synapse to process updates
+                as users start and stop syncing against a given process.
+            syncing_user_ids(set(str)): The set of user_ids that are
+                currently syncing on that server.
+        """
+
+        # Grab the previous list of user_ids that were syncing on that process
+        prev_syncing_user_ids = (
+            self.external_process_to_current_syncs.get(process_id, set())
+        )
+        # Grab the current presence state for both the users that are syncing
+        # now and the users that were syncing before this update.
+        prev_states = yield self.current_state_for_users(
+            syncing_user_ids | prev_syncing_user_ids
+        )
+        updates = []
+        time_now_ms = self.clock.time_msec()
+
+        # For each new user that is syncing check if we need to mark them as
+        # being online.
+        for new_user_id in syncing_user_ids - prev_syncing_user_ids:
+            prev_state = prev_states[new_user_id]
+            if prev_state.state == PresenceState.OFFLINE:
+                updates.append(prev_state.copy_and_replace(
+                    state=PresenceState.ONLINE,
+                    last_active_ts=time_now_ms,
+                    last_user_sync_ts=time_now_ms,
+                ))
+            else:
+                updates.append(prev_state.copy_and_replace(
+                    last_user_sync_ts=time_now_ms,
+                ))
+
+        # For each user that is still syncing or stopped syncing update the
+        # last sync time so that we will correctly apply the grace period when
+        # they stop syncing.
+        for old_user_id in prev_syncing_user_ids:
+            prev_state = prev_states[old_user_id]
+            updates.append(prev_state.copy_and_replace(
+                last_user_sync_ts=time_now_ms,
+            ))
+
+        yield self._update_states(updates)
+
+        # Update the last updated time for the process. We expire the entries
+        # if we don't receive an update in the given timeframe.
+        self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
+        self.external_process_to_current_syncs[process_id] = syncing_user_ids
+
     @defer.inlineCallbacks
     def current_state_for_user(self, user_id):
         """Get the current presence state for a user.

@@ -935,15 +1030,14 @@ class PresenceEventSource(object):
         return self.get_new_events(user, from_key=None, include_offline=False)


-def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
+def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
     """Checks the presence of users that have timed out and updates as
     appropriate.

     Args:
         user_states(list): List of UserPresenceState's to check.
         is_mine_fn (fn): Function that returns if a user_id is ours
-        user_to_num_current_syncs (dict): Mapping of user_id to number of currently
-            active syncs.
+        syncing_user_ids (set): Set of user_ids with active syncs.
         now (int): Current time in ms.

     Returns:

@@ -954,21 +1048,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
     for state in user_states:
         is_mine = is_mine_fn(state.user_id)

-        new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now)
+        new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
         if new_state:
             changes[state.user_id] = new_state

     return changes.values()


-def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
+def handle_timeout(state, is_mine, syncing_user_ids, now):
     """Checks the presence of the user to see if any of the timers have elapsed

     Args:
         state (UserPresenceState)
         is_mine (bool): Whether the user is ours
-        user_to_num_current_syncs (dict): Mapping of user_id to number of currently
-            active syncs.
+        syncing_user_ids (set): Set of user_ids with active syncs.
         now (int): Current time in ms.

     Returns:

@@ -1002,7 +1095,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now):

         # If there have been no syncs for a while (and none ongoing),
         # set presence to offline
-        if not user_to_num_current_syncs.get(user_id, 0):
+        if user_id not in syncing_user_ids:
             if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
                 state = state.copy_and_replace(
                     state=PresenceState.OFFLINE,
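The presence changes above amount to a heartbeat contract for worker processes. A sketch of one update cycle (the process id and user ids are invented for illustration):

# An external process periodically reports the users it is syncing for:
yield presence_handler.update_external_syncs(
    "synchrotron-1", set(["@alice:example.com", "@bob:example.com"])
)
# New users are flipped to ONLINE; the rest get last_user_sync_ts refreshed.
# If no update arrives within EXTERNAL_PROCESS_EXPIRY (5 minutes),
# _handle_timeouts() drops "synchrotron-1" and re-checks its users through
# the normal handle_timeout() path, so they eventually go OFFLINE.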
@@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
         defer.returnValue((user_id, token))

     def auth_handler(self):
-        return self.hs.get_handlers().auth_handler
+        return self.hs.get_auth_handler()

     @defer.inlineCallbacks
     def guest_access_token_for(self, medium, address, inviter_user_id):
@@ -636,6 +636,9 @@ class SyncHandler(object):
             )
             presence.extend(states)

+            # Deduplicate the presence entries so that there's at most one per user
+            presence = {p["content"]["user_id"]: p for p in presence}.values()
+
             presence = sync_config.filter_collection.filter_presence(
                 presence
             )
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)

 # A tiny object useful for storing a user's membership in a room, as a mapping
 # key
-RoomMember = namedtuple("RoomMember", ("room_id", "user"))
+RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))


 class TypingHandler(object):

@@ -38,7 +38,7 @@ class TypingHandler(object):
         self.store = hs.get_datastore()
         self.server_name = hs.config.server_name
         self.auth = hs.get_auth()
-        self.is_mine = hs.is_mine
+        self.is_mine_id = hs.is_mine_id
         self.notifier = hs.get_notifier()

         self.clock = hs.get_clock()

@@ -67,20 +67,23 @@ class TypingHandler(object):

     @defer.inlineCallbacks
     def started_typing(self, target_user, auth_user, room_id, timeout):
-        if not self.is_mine(target_user):
+        target_user_id = target_user.to_string()
+        auth_user_id = auth_user.to_string()
+
+        if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this Home Server")

-        if target_user != auth_user:
+        if target_user_id != auth_user_id:
             raise AuthError(400, "Cannot set another user's typing state")

-        yield self.auth.check_joined_room(room_id, target_user.to_string())
+        yield self.auth.check_joined_room(room_id, target_user_id)

         logger.debug(
-            "%s has started typing in %s", target_user.to_string(), room_id
+            "%s has started typing in %s", target_user_id, room_id
         )

         until = self.clock.time_msec() + timeout
-        member = RoomMember(room_id=room_id, user=target_user)
+        member = RoomMember(room_id=room_id, user_id=target_user_id)

         was_present = member in self._member_typing_until

@@ -104,25 +107,28 @@ class TypingHandler(object):

         yield self._push_update(
             room_id=room_id,
-            user=target_user,
+            user_id=target_user_id,
             typing=True,
         )

     @defer.inlineCallbacks
     def stopped_typing(self, target_user, auth_user, room_id):
-        if not self.is_mine(target_user):
+        target_user_id = target_user.to_string()
+        auth_user_id = auth_user.to_string()
+
+        if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this Home Server")

-        if target_user != auth_user:
+        if target_user_id != auth_user_id:
             raise AuthError(400, "Cannot set another user's typing state")

-        yield self.auth.check_joined_room(room_id, target_user.to_string())
+        yield self.auth.check_joined_room(room_id, target_user_id)

         logger.debug(
-            "%s has stopped typing in %s", target_user.to_string(), room_id
+            "%s has stopped typing in %s", target_user_id, room_id
         )

-        member = RoomMember(room_id=room_id, user=target_user)
+        member = RoomMember(room_id=room_id, user_id=target_user_id)

         if member in self._member_typing_timer:
             self.clock.cancel_call_later(self._member_typing_timer[member])

@@ -132,8 +138,9 @@ class TypingHandler(object):

     @defer.inlineCallbacks
     def user_left_room(self, user, room_id):
-        if self.is_mine(user):
-            member = RoomMember(room_id=room_id, user=user)
+        user_id = user.to_string()
+        if self.is_mine_id(user_id):
+            member = RoomMember(room_id=room_id, user_id=user_id)
             yield self._stopped_typing(member)

     @defer.inlineCallbacks

@@ -144,7 +151,7 @@ class TypingHandler(object):

         yield self._push_update(
             room_id=member.room_id,
-            user=member.user,
+            user_id=member.user_id,
             typing=False,
         )

@@ -156,7 +163,7 @@ class TypingHandler(object):
         del self._member_typing_timer[member]

     @defer.inlineCallbacks
-    def _push_update(self, room_id, user, typing):
+    def _push_update(self, room_id, user_id, typing):
         domains = yield self.store.get_joined_hosts_for_room(room_id)

         deferreds = []

@@ -164,7 +171,7 @@ class TypingHandler(object):
             if domain == self.server_name:
                 self._push_update_local(
                     room_id=room_id,
-                    user=user,
+                    user_id=user_id,
                     typing=typing
                 )
             else:

@@ -173,7 +180,7 @@ class TypingHandler(object):
                     edu_type="m.typing",
                     content={
                         "room_id": room_id,
-                        "user_id": user.to_string(),
+                        "user_id": user_id,
                         "typing": typing,
                     },
                 ))

@@ -183,23 +190,26 @@ class TypingHandler(object):
     @defer.inlineCallbacks
     def _recv_edu(self, origin, content):
         room_id = content["room_id"]
-        user = UserID.from_string(content["user_id"])
+        user_id = content["user_id"]
+
+        # Check that the string is a valid user id
+        UserID.from_string(user_id)

         domains = yield self.store.get_joined_hosts_for_room(room_id)

         if self.server_name in domains:
             self._push_update_local(
                 room_id=room_id,
-                user=user,
+                user_id=user_id,
                 typing=content["typing"]
             )

-    def _push_update_local(self, room_id, user, typing):
+    def _push_update_local(self, room_id, user_id, typing):
         room_set = self._room_typing.setdefault(room_id, set())
         if typing:
-            room_set.add(user)
+            room_set.add(user_id)
         else:
-            room_set.discard(user)
+            room_set.discard(user_id)

         self._latest_room_serial += 1
         self._room_serials[room_id] = self._latest_room_serial

@@ -215,9 +225,7 @@ class TypingHandler(object):
         for room_id, serial in self._room_serials.items():
             if last_id < serial and serial <= current_id:
                 typing = self._room_typing[room_id]
-                typing_bytes = json.dumps([
-                    u.to_string() for u in typing
-                ], ensure_ascii=False)
+                typing_bytes = json.dumps(list(typing), ensure_ascii=False)
                 rows.append((serial, room_id, typing_bytes))
         rows.sort()
         return rows

@@ -239,7 +247,7 @@ class TypingNotificationEventSource(object):
             "type": "m.typing",
             "room_id": room_id,
             "content": {
-                "user_ids": [u.to_string() for u in typing],
+                "user_ids": list(typing),
             },
         }
@@ -126,12 +126,6 @@ class DistributionMetric(object):


 class CacheMetric(object):
-    """A combination of two CounterMetrics, one to count cache hits and one to
-    count a total, and a callback metric to yield the current size.
-
-    This metric generates standard metric name pairs, so that monitoring rules
-    can easily be applied to measure hit ratio."""
-
     __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")

     def __init__(self, name, size_callback, cache_name):
@@ -279,5 +279,5 @@ class EmailPusher(object):
         logger.info("Sending notif email for user %r", self.user_id)

         yield self.mailer.send_notification_mail(
-            self.user_id, self.email, push_actions, reason
+            self.app_id, self.user_id, self.email, push_actions, reason
         )
@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)


 MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
-    "in the %s room..."
+    "in the %(room)s room..."
 MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
 MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
 MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."

@@ -81,9 +81,11 @@ class Mailer(object):
     def __init__(self, hs, app_name):
         self.hs = hs
         self.store = self.hs.get_datastore()
+        self.auth_handler = self.hs.get_auth_handler()
         self.state_handler = self.hs.get_state_handler()
         loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
         self.app_name = app_name
+        logger.info("Created Mailer for app_name %s" % app_name)
         env = jinja2.Environment(loader=loader)
         env.filters["format_ts"] = format_ts_filter
         env.filters["mxc_to_http"] = self.mxc_to_http_filter

@@ -95,8 +97,16 @@ class Mailer(object):
         )

     @defer.inlineCallbacks
-    def send_notification_mail(self, user_id, email_address, push_actions, reason):
-        raw_from = email.utils.parseaddr(self.hs.config.email_notif_from)[1]
+    def send_notification_mail(self, app_id, user_id, email_address,
+                               push_actions, reason):
+        try:
+            from_string = self.hs.config.email_notif_from % {
+                "app": self.app_name
+            }
+        except TypeError:
+            from_string = self.hs.config.email_notif_from
+
+        raw_from = email.utils.parseaddr(from_string)[1]
         raw_to = email.utils.parseaddr(email_address)[1]

         if raw_to == '':

@@ -159,7 +169,9 @@ class Mailer(object):

         template_vars = {
             "user_display_name": user_display_name,
-            "unsubscribe_link": self.make_unsubscribe_link(),
+            "unsubscribe_link": self.make_unsubscribe_link(
+                user_id, app_id, email_address
+            ),
             "summary_text": summary_text,
             "app_name": self.app_name,
             "rooms": rooms,

@@ -425,9 +437,18 @@ class Mailer(object):
             notif['room_id'], notif['event_id']
         )

-    def make_unsubscribe_link(self):
-        # XXX: matrix.to
-        return "https://vector.im/#/settings"
+    def make_unsubscribe_link(self, user_id, app_id, email_address):
+        params = {
+            "access_token": self.auth_handler.generate_delete_pusher_token(user_id),
+            "app_id": app_id,
+            "pushkey": email_address,
+        }
+
+        # XXX: make r0 once API is stable
+        return "%s_matrix/client/unstable/pushers/remove?%s" % (
+            self.hs.config.public_baseurl,
+            urllib.urlencode(params),
+        )

     def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
         if value[0:6] != "mxc://":
@@ -0,0 +1,59 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import respond_with_json_bytes, request_handler
+from synapse.http.servlet import parse_json_object_from_request
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+from twisted.internet import defer
+
+
+class PresenceResource(Resource):
+    """
+    HTTP endpoint for marking users as syncing.
+
+        POST /_synapse/replication/presence HTTP/1.1
+        Content-Type: application/json
+
+        {
+            "process_id": "<process_id>",
+            "syncing_users": ["<user_id>"]
+        }
+    """
+
+    def __init__(self, hs):
+        Resource.__init__(self)  # Resource is old-style, so no super()
+
+        self.version_string = hs.version_string
+        self.clock = hs.get_clock()
+        self.presence_handler = hs.get_presence_handler()
+
+    def render_POST(self, request):
+        self._async_render_POST(request)
+        return NOT_DONE_YET
+
+    @request_handler()
+    @defer.inlineCallbacks
+    def _async_render_POST(self, request):
+        content = parse_json_object_from_request(request)
+
+        process_id = content["process_id"]
+        syncing_user_ids = content["syncing_users"]
+
+        yield self.presence_handler.update_external_syncs(
+            process_id, set(syncing_user_ids)
+        )
+
+        respond_with_json_bytes(request, 200, "{}")
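What a caller of this resource sends, sketched with the requests library (an illustrative dependency, not used by synapse itself; note that resource.py below mounts this class at /_synapse/replication/syncing_users, while the docstring says /presence):

import json
import requests

resp = requests.post(
    "http://localhost:8008/_synapse/replication/syncing_users",
    headers={"Content-Type": "application/json"},
    data=json.dumps({
        "process_id": "synchrotron-1",
        "syncing_users": ["@alice:example.com"],
    }),
)
assert resp.status_code == 200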
@@ -16,6 +16,7 @@
 from synapse.http.servlet import parse_integer, parse_string
 from synapse.http.server import request_handler, finish_request
 from synapse.replication.pusher_resource import PusherResource
+from synapse.replication.presence_resource import PresenceResource

 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET

@@ -115,6 +116,7 @@ class ReplicationResource(Resource):
         self.clock = hs.get_clock()

         self.putChild("remove_pushers", PusherResource(hs))
+        self.putChild("syncing_users", PresenceResource(hs))

     def render_GET(self, request):
         self._async_render_GET(request)
@@ -15,7 +15,10 @@

 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
 from synapse.storage.account_data import AccountDataStore
+from synapse.storage.tags import TagsStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache


 class SlavedAccountDataStore(BaseSlavedStore):

@@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn, "account_data_max_stream_id", "stream_id",
         )
+        self._account_data_stream_cache = StreamChangeCache(
+            "AccountDataAndTagsChangeCache",
+            self._account_data_id_gen.get_current_token(),
+        )
+
+    get_account_data_for_user = (
+        AccountDataStore.__dict__["get_account_data_for_user"]
+    )

     get_global_account_data_by_type_for_users = (
         AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]

@@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore):
         AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
     )

+    get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
+
+    get_updated_tags = DataStore.get_updated_tags.__func__
+    get_updated_account_data_for_user = (
+        DataStore.get_updated_account_data_for_user.__func__
+    )
+
+    def get_max_account_data_stream_id(self):
+        return self._account_data_id_gen.get_current_token()
+
     def stream_positions(self):
         result = super(SlavedAccountDataStore, self).stream_positions()
         position = self._account_data_id_gen.get_current_token()

@@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore):
         if stream:
             self._account_data_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
-                user_id, data_type = row[1:3]
+                position, user_id, data_type = row[:3]
                 self.get_global_account_data_by_type_for_user.invalidate(
                     (data_type, user_id,)
                 )
+                self.get_account_data_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )
+
+        stream = result.get("room_account_data")
+        if stream:
+            self._account_data_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                position, user_id = row[:2]
+                self.get_account_data_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )

         stream = result.get("tag_account_data")
         if stream:
             self._account_data_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
+                position, user_id = row[:2]
                 self.get_tags_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )

         return super(SlavedAccountDataStore, self).process_replication(result)
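The StreamChangeCache wired up above lets the slave answer "has this user's account data changed since stream position X" without a database hit. A rough sketch of its semantics (not the real implementation):

class StreamChangeCacheSketch(object):
    def __init__(self, name, current_token):
        self._earliest_known = current_token
        self._last_change = {}  # entity -> latest stream position it changed at

    def entity_has_changed(self, entity, position):
        current = self._last_change.get(entity, 0)
        self._last_change[entity] = max(current, position)

    def has_entity_changed(self, entity, position):
        if position < self._earliest_known:
            # The cache can't prove anything this far back; caller hits the DB.
            return True
        return self._last_change.get(entity, 0) > position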
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+from synapse.config.appservice import load_appservices
+
+
+class SlavedApplicationServiceStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
+        self.services_cache = load_appservices(
+            hs.config.server_name,
+            hs.config.app_service_config_files
+        )
+
+    get_app_service_by_token = DataStore.get_app_service_by_token.__func__
+    get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
@@ -23,6 +23,7 @@ from synapse.storage.roommember import RoomMemberStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.event_push_actions import EventPushActionsStore
 from synapse.storage.state import StateStore
+from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache

 import ujson as json

@@ -57,6 +58,9 @@ class SlavedEventStore(BaseSlavedStore):
             "EventsRoomStreamChangeCache", min_event_val,
             prefilled_cache=event_cache_prefill,
         )
+        self._membership_stream_cache = StreamChangeCache(
+            "MembershipStreamChangeCache", events_max,
+        )

         # Cached functions can't be accessed through a class instance so we need
         # to reach inside the __dict__ to extract them.

@@ -87,6 +91,9 @@ class SlavedEventStore(BaseSlavedStore):
     _get_state_group_from_group = (
         StateStore.__dict__["_get_state_group_from_group"]
     )
+    get_recent_event_ids_for_room = (
+        StreamStore.__dict__["get_recent_event_ids_for_room"]
+    )

     get_unread_push_actions_for_user_in_range = (
         DataStore.get_unread_push_actions_for_user_in_range.__func__

@@ -109,10 +116,16 @@ class SlavedEventStore(BaseSlavedStore):
         DataStore.get_room_events_stream_for_room.__func__
     )
     get_events_around = DataStore.get_events_around.__func__
     get_state_for_event = DataStore.get_state_for_event.__func__
     get_state_for_events = DataStore.get_state_for_events.__func__
     get_state_groups = DataStore.get_state_groups.__func__
     get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
     get_room_events_stream_for_rooms = (
         DataStore.get_room_events_stream_for_rooms.__func__
     )
     get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__

-    _set_before_and_after = DataStore._set_before_and_after
+    _set_before_and_after = staticmethod(DataStore._set_before_and_after)

     _get_events = DataStore._get_events.__func__
     _get_events_from_cache = DataStore._get_events_from_cache.__func__

@@ -220,9 +233,9 @@ class SlavedEventStore(BaseSlavedStore):
         self.get_rooms_for_user.invalidate((event.state_key,))
         # self.get_joined_hosts_for_room.invalidate((event.room_id,))
         self.get_users_in_room.invalidate((event.room_id,))
-        # self._membership_stream_cache.entity_has_changed(
-        #     event.state_key, event.internal_metadata.stream_ordering
-        # )
+        self._membership_stream_cache.entity_has_changed(
+            event.state_key, event.internal_metadata.stream_ordering
+        )
         self.get_invited_rooms_for_user.invalidate((event.state_key,))

         if not event.is_state():
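A note on the two borrowing spellings used throughout these slaved stores. Plain DataStore methods are unbound methods in Python 2, so .__func__ unwraps the raw function and lets it re-bind to the slaved class; @cached() methods are descriptor objects with no __func__, so SomeStore.__dict__["name"] fetches the descriptor itself and each borrower gets its own cache. An illustrative sketch:

class Source(object):
    def read_thing(self):
        return "thing"

class Borrower(object):
    # Source.read_thing is unbound in Python 2 and type-checks its first
    # argument, so borrow the underlying function instead.
    read_thing = Source.read_thing.__func__

print(Borrower().read_thing())  # -> thing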
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage.filtering import FilteringStore
+
+
+class SlavedFilteringStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedFilteringStore, self).__init__(db_conn, hs)
+
+    # Filters are immutable so this cache doesn't need to be expired
+    get_user_filter = FilteringStore.__dict__["get_user_filter"]
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage import DataStore
+
+
+class SlavedPresenceStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedPresenceStore, self).__init__(db_conn, hs)
+        self._presence_id_gen = SlavedIdTracker(
+            db_conn, "presence_stream", "stream_id",
+        )
+
+        self._presence_on_startup = self._get_active_presence(db_conn)
+
+        self.presence_stream_cache = StreamChangeCache(
+            "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
+        )
+
+    _get_active_presence = DataStore._get_active_presence.__func__
+    take_presence_startup_info = DataStore.take_presence_startup_info.__func__
+    get_presence_for_users = DataStore.get_presence_for_users.__func__
+
+    def get_current_presence_token(self):
+        return self._presence_id_gen.get_current_token()
+
+    def stream_positions(self):
+        result = super(SlavedPresenceStore, self).stream_positions()
+        position = self._presence_id_gen.get_current_token()
+        result["presence"] = position
+        return result
+
+    def process_replication(self, result):
+        stream = result.get("presence")
+        if stream:
+            self._presence_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                position, user_id = row[:2]
+                self.presence_stream_cache.entity_has_changed(
+                    user_id, position
+                )
+
+        return super(SlavedPresenceStore, self).process_replication(result)
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .events import SlavedEventStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.storage.push_rule import PushRuleStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedPushRuleStore(SlavedEventStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
+        self._push_rules_stream_id_gen = SlavedIdTracker(
+            db_conn, "push_rules_stream", "stream_id",
+        )
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache",
+            self._push_rules_stream_id_gen.get_current_token(),
+        )
+
+    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
+    get_push_rules_enabled_for_user = (
+        PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
+    )
+    have_push_rules_changed_for_user = (
+        DataStore.have_push_rules_changed_for_user.__func__
+    )
+
+    def get_push_rules_stream_token(self):
+        return (
+            self._push_rules_stream_id_gen.get_current_token(),
+            self._stream_id_gen.get_current_token(),
+        )
+
+    def stream_positions(self):
+        result = super(SlavedPushRuleStore, self).stream_positions()
+        result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
+        return result
+
+    def process_replication(self, result):
+        stream = result.get("push_rules")
+        if stream:
+            for row in stream["rows"]:
+                position = row[0]
+                user_id = row[2]
+                self.get_push_rules_for_user.invalidate((user_id,))
+                self.get_push_rules_enabled_for_user.invalidate((user_id,))
+                self.push_rules_stream_cache.entity_has_changed(
+                    user_id, position
+                )
+
+            self._push_rules_stream_id_gen.advance(int(stream["position"]))
+
+        return super(SlavedPushRuleStore, self).process_replication(result)
@@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker

 from synapse.storage import DataStore
 from synapse.storage.receipts import ReceiptsStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache

 # So, um, we want to borrow a load of functions intended for reading from
 # a DataStore, but we don't want to take functions that either write to the

@@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore):
             db_conn, "receipts_linearized", "stream_id"
         )

+        self._receipts_stream_cache = StreamChangeCache(
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+        )
+
     get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
+    get_linearized_receipts_for_room = (
+        ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
+    )
+    _get_linearized_receipts_for_rooms = (
+        ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
+    )
+    get_last_receipt_event_id_for_user = (
+        ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
+    )

     get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
     get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__

+    get_linearized_receipts_for_rooms = (
+        DataStore.get_linearized_receipts_for_rooms.__func__
+    )
+
     def stream_positions(self):
         result = super(SlavedReceiptsStore, self).stream_positions()
         result["receipts"] = self._receipts_id_gen.get_current_token()

@@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore):
         if stream:
             self._receipts_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
-                room_id, receipt_type, user_id = row[1:4]
+                position, room_id, receipt_type, user_id = row[:4]
                 self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
+                self._receipts_stream_cache.entity_has_changed(room_id, position)

         return super(SlavedReceiptsStore, self).process_replication(result)

+    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+        self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self.get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self.get_last_receipt_event_id_for_user.invalidate(
+            (user_id, room_id, receipt_type)
+        )
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+from synapse.storage.registration import RegistrationStore
+
+
+class SlavedRegistrationStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedRegistrationStore, self).__init__(db_conn, hs)
+
+    # TODO: use the cached version and invalidate deleted tokens
+    get_user_by_access_token = RegistrationStore.__dict__[
+        "get_user_by_access_token"
+    ].orig
+
+    _query_for_auth = DataStore._query_for_auth.__func__
@@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
         self.cas_required_attributes = hs.config.cas_required_attributes
         self.servername = hs.config.server_name
         self.http_client = hs.get_simple_http_client()
+        self.auth_handler = self.hs.get_auth_handler()

     def on_GET(self, request):
         flows = []

@@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
             user_id, self.hs.hostname
         ).to_string()

-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_id, access_token, refresh_token = yield auth_handler.login_with_password(
             user_id=user_id,
             password=login_submission["password"])

@@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_token_login(self, login_submission):
         token = login_submission['token']
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )

@@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
             raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if user_exists:
             user_id, access_token, refresh_token = (

@@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
             raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if user_exists:
             user_id, access_token, refresh_token = (

@@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
             raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if not user_exists:
             user_id, _ = (
@@ -17,7 +17,10 @@ from twisted.internet import defer

 from synapse.api.errors import SynapseError, Codes
 from synapse.push import PusherConfigException
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+    parse_json_object_from_request, parse_string, RestServlet
+)
+from synapse.http.server import finish_request
+from synapse.api.errors import StoreError

 from .base import ClientV1RestServlet, client_path_patterns

@@ -136,6 +140,57 @@ class PushersSetRestServlet(ClientV1RestServlet):
         return 200, {}


+class PushersRemoveRestServlet(RestServlet):
+    """
+    To allow pushers to be deleted by clicking a link (i.e. a GET request).
+    """
+    PATTERNS = client_path_patterns("/pushers/remove$")
+    SUCCESS_HTML = "<html><body>You have been unsubscribed</body></html>"
+
+    def __init__(self, hs):
+        super(PushersRemoveRestServlet, self).__init__()
+        self.hs = hs
+        self.notifier = hs.get_notifier()
+        self.auth = hs.get_v1auth()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
+        user = requester.user
+
+        app_id = parse_string(request, "app_id", required=True)
+        pushkey = parse_string(request, "pushkey", required=True)
+
+        pusher_pool = self.hs.get_pusherpool()
+
+        try:
+            yield pusher_pool.remove_pusher(
+                app_id=app_id,
+                pushkey=pushkey,
+                user_id=user.to_string(),
+            )
+        except StoreError as se:
+            if se.code != 404:
+                raise
+            # A 404 is fine: they're already unsubscribed
+
+        self.notifier.on_new_replication_data()
+
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Server", self.hs.version_string)
+        request.setHeader(b"Content-Length", b"%d" % (
+            len(PushersRemoveRestServlet.SUCCESS_HTML),
+        ))
+        request.write(PushersRemoveRestServlet.SUCCESS_HTML)
+        finish_request(request)
+        defer.returnValue(None)
+
+    def on_OPTIONS(self, _):
+        return 200, {}
+
+
 def register_servlets(hs, http_server):
     PushersRestServlet(hs).register(http_server)
     PushersSetRestServlet(hs).register(http_server)
+    PushersRemoveRestServlet(hs).register(http_server)
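Sketch of the link a notification email now carries (values invented). Because the token is a delete_pusher macaroon, get_user_by_req(request, rights="delete_pusher") accepts it for this endpoint and nothing else, so a leaked link cannot be replayed elsewhere:

import urllib

params = {
    "access_token": "<delete_pusher macaroon>",
    "app_id": "m.email",
    "pushkey": "alice@example.com",
}
print("https://example.com/_matrix/client/unstable/pushers/remove?"
      + urllib.urlencode(params))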
@@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
         super(PasswordRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()

     @defer.inlineCallbacks
     def on_POST(self, request):

@@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()

     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
         super(AuthRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler

     @defer.inlineCallbacks
@@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
@@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
         body = parse_json_object_from_request(request)
         try:
             old_refresh_token = body["refresh_token"]
-            auth_handler = self.hs.get_handlers().auth_handler
+            auth_handler = self.hs.get_auth_handler()
             (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
                 old_refresh_token, auth_handler.generate_refresh_token)
             new_access_token = yield auth_handler.issue_access_token(user_id)
@@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import TypingHandler
 from synapse.handlers.room import RoomListHandler
+from synapse.handlers.auth import AuthHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.state import StateHandler
 from synapse.storage import DataStore

@@ -89,6 +90,7 @@ class HomeServer(object):
         'sync_handler',
         'typing_handler',
         'room_list_handler',
+        'auth_handler',
         'application_service_api',
         'application_service_scheduler',
         'application_service_handler',

@@ -190,6 +192,9 @@ class HomeServer(object):
     def build_room_list_handler(self):
         return RoomListHandler(self)

+    def build_auth_handler(self):
+        return AuthHandler(self)
+
     def build_application_service_api(self):
         return ApplicationServiceApi(self)
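The 'auth_handler' entry and build_auth_handler() follow the HomeServer dependency pattern: a get_X() accessor is generated for each listed name, lazily calling build_X() once and caching the result so every component shares one instance. A cut-down sketch of the idea (not synapse's actual metaprogramming):

class HomeServerSketch(object):
    def __init__(self):
        self._built = {}

    def _get(self, name, builder):
        # build on first use, then reuse
        if name not in self._built:
            self._built[name] = builder()
        return self._built[name]

    def build_auth_handler(self):
        return AuthHandler(self)  # assumes AuthHandler(hs), as in the diff

    def get_auth_handler(self):
        return self._get("auth_handler", self.build_auth_handler)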
@@ -149,7 +149,7 @@ class DataStore(RoomMemberStore, RoomStore,
             "AccountDataAndTagsChangeCache", account_max,
         )

-        self.__presence_on_startup = self._get_active_presence(db_conn)
+        self._presence_on_startup = self._get_active_presence(db_conn)

         presence_cache_prefill, min_presence_val = self._get_cache_dict(
             db_conn, "presence_stream",

@@ -190,8 +190,8 @@ class DataStore(RoomMemberStore, RoomStore,
         super(DataStore, self).__init__(hs)

     def take_presence_startup_info(self):
-        active_on_startup = self.__presence_on_startup
-        self.__presence_on_startup = None
+        active_on_startup = self._presence_on_startup
+        self._presence_on_startup = None
         return active_on_startup

     def _get_active_presence(self, db_conn):
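The double-to-single underscore rename likely matters because these methods are borrowed by the slaved stores: double-underscore attributes are name-mangled against the class that defines them, so a borrowed method would look up _DataStore__presence_on_startup on an object that never set it. A quick demonstration of the mangling:

class A(object):
    def __init__(self):
        self.__presence = "x"  # stored as _A__presence

print(A().__dict__.keys())  # ['_A__presence']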
@@ -152,6 +152,7 @@ class SQLBaseStore(object):

     def __init__(self, hs):
         self.hs = hs
+        self._clock = hs.get_clock()
         self._db_pool = hs.get_db_pool()

         self._previous_txn_total_time = 0
@@ -32,7 +32,7 @@ import os
 import functools
 import inspect
 import threading
-import itertools


 logger = logging.getLogger(__name__)

@@ -367,17 +367,16 @@ class CacheListDescriptor(object):
                 cached_defers[arg] = res

         if cached_defers:
+            def update_results_dict(res):
+                results.update(res)
+                return results
+
             return preserve_context_over_deferred(defer.gatherResults(
                 cached_defers.values(),
                 consumeErrors=True,
-            ).addCallback(
-                lambda res: {
-                    k: v
-                    for k, v in itertools.chain(results.items(), res)
-                }
-            )).addErrback(
+            ).addCallback(update_results_dict).addErrback(
                 unwrapFirstError
-            )
+            ))
         else:
             return results
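The rewritten callback relies on dict.update accepting an iterable of key/value pairs, which is the shape the gathered per-argument deferreds resolve to; for illustration:

results = {"a": 1}
gathered = [("b", 2), ("c", 3)]  # shape of the gatherResults payload
results.update(gathered)
print(results)  # {'a': 1, 'b': 2, 'c': 3}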
@@ -264,7 +264,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )

         new_state = handle_timeout(
-            state, is_mine=True, user_to_num_current_syncs={}, now=now
+            state, is_mine=True, syncing_user_ids=set(), now=now
         )

         self.assertIsNotNone(new_state)

@@ -282,7 +282,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )

         new_state = handle_timeout(
-            state, is_mine=True, user_to_num_current_syncs={}, now=now
+            state, is_mine=True, syncing_user_ids=set(), now=now
         )

         self.assertIsNotNone(new_state)

@@ -300,9 +300,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )

         new_state = handle_timeout(
-            state, is_mine=True, user_to_num_current_syncs={
-                user_id: 1,
-            }, now=now
+            state, is_mine=True, syncing_user_ids=set([user_id]), now=now
         )

         self.assertIsNotNone(new_state)

@@ -321,7 +319,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )

         new_state = handle_timeout(
-            state, is_mine=True, user_to_num_current_syncs={}, now=now
+            state, is_mine=True, syncing_user_ids=set(), now=now
         )

         self.assertIsNotNone(new_state)

@@ -340,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )

         new_state = handle_timeout(
-            state, is_mine=True, user_to_num_current_syncs={}, now=now
+            state, is_mine=True, syncing_user_ids=set(), now=now
         )

         self.assertIsNone(new_state)

@@ -358,7 +356,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )

         new_state = handle_timeout(
-            state, is_mine=False, user_to_num_current_syncs={}, now=now
+            state, is_mine=False, syncing_user_ids=set(), now=now
         )

         self.assertIsNotNone(new_state)

@@ -377,7 +375,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )

         new_state = handle_timeout(
-            state, is_mine=True, user_to_num_current_syncs={}, now=now
+            state, is_mine=True, syncing_user_ids=set(), now=now
         )

         self.assertIsNotNone(new_state)
@@ -251,12 +251,12 @@ class TypingNotificationsTestCase(unittest.TestCase):

         # Gut-wrenching
         from synapse.handlers.typing import RoomMember
-        member = RoomMember(self.room_id, self.u_apple)
+        member = RoomMember(self.room_id, self.u_apple.to_string())
         self.handler._member_typing_until[member] = 1002000
         self.handler._member_typing_timer[member] = (
             self.clock.call_later(1002, lambda: 0)
         )
-        self.handler._room_typing[self.room_id] = set((self.u_apple,))
+        self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),))

         self.assertEquals(self.event_source.get_current_key(), 0)
@@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase):

         # do the dance to hook it up to the hs global
         self.handlers = Mock(
-            auth_handler=self.auth_handler,
             registration_handler=self.registration_handler,
             identity_handler=self.identity_handler,
             login_handler=self.login_handler

@@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.hostname = "superbig~testing~thing.com"
         self.hs.get_auth = Mock(return_value=self.auth)
         self.hs.get_handlers = Mock(return_value=self.handlers)
+        self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
         self.hs.config.enable_registration = True

         # init the thing we're testing
@@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
     )

     # bcrypt is far too slow to be doing in unit tests
-    def swap_out_hash_for_testing(old_build_handlers):
-        def build_handlers():
-            handlers = old_build_handlers()
-            auth_handler = handlers.auth_handler
-            auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
-            auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
-            return handlers
-        return build_handlers
-
-    hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
+    # Need to let the HS build an auth handler and then mess with it
+    # because AuthHandler's constructor requires the HS, so we can't make one
+    # beforehand and pass it in to the HS's constructor (chicken / egg)
+    hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
+    hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h

     fed = kargs.get("resource_for_federation", None)
     if fed: