Merge branch 'develop' of github.com:matrix-org/synapse into erikj/createroom_content

pull/2717/head
Erik Johnston 2017-12-07 14:24:01 +00:00
commit d8a6c734fa
73 changed files with 1778 additions and 822 deletions

View File

@ -5,7 +5,8 @@ To use it, first install prometheus by following the instructions at
http://prometheus.io/
Then add a new job to the main prometheus.conf file:
### for Prometheus v1
Add a new job to the main prometheus.conf file:
job: {
name: "synapse"
@ -15,6 +16,22 @@ Then add a new job to the main prometheus.conf file:
}
}
### for Prometheus v2
Add a new job to the main prometheus.yml file:
- job_name: "synapse"
metrics_path: "/_synapse/metrics"
# when endpoint uses https:
scheme: "https"
static_configs:
- targets: ['SERVER.LOCATION:PORT']
To use `synapse.rules` add
rule_files:
- "/PATH/TO/synapse-v2.rules"
Metrics are disabled by default when running synapse; they must be enabled
with the 'enable-metrics' option, either in the synapse config file or as a
command-line option.

View File

@ -0,0 +1,60 @@
groups:
- name: synapse
rules:
- record: "synapse_federation_transaction_queue_pendingEdus:total"
expr: "sum(synapse_federation_transaction_queue_pendingEdus or absent(synapse_federation_transaction_queue_pendingEdus)*0)"
- record: "synapse_federation_transaction_queue_pendingPdus:total"
expr: "sum(synapse_federation_transaction_queue_pendingPdus or absent(synapse_federation_transaction_queue_pendingPdus)*0)"
- record: 'synapse_http_server_requests:method'
labels:
servlet: ""
expr: "sum(synapse_http_server_requests) by (method)"
- record: 'synapse_http_server_requests:servlet'
labels:
method: ""
expr: 'sum(synapse_http_server_requests) by (servlet)'
- record: 'synapse_http_server_requests:total'
labels:
servlet: ""
expr: 'sum(synapse_http_server_requests:by_method) by (servlet)'
- record: 'synapse_cache:hit_ratio_5m'
expr: 'rate(synapse_util_caches_cache:hits[5m]) / rate(synapse_util_caches_cache:total[5m])'
- record: 'synapse_cache:hit_ratio_30s'
expr: 'rate(synapse_util_caches_cache:hits[30s]) / rate(synapse_util_caches_cache:total[30s])'
- record: 'synapse_federation_client_sent'
labels:
type: "EDU"
expr: 'synapse_federation_client_sent_edus + 0'
- record: 'synapse_federation_client_sent'
labels:
type: "PDU"
expr: 'synapse_federation_client_sent_pdu_destinations:count + 0'
- record: 'synapse_federation_client_sent'
labels:
type: "Query"
expr: 'sum(synapse_federation_client_sent_queries) by (job)'
- record: 'synapse_federation_server_received'
labels:
type: "EDU"
expr: 'synapse_federation_server_received_edus + 0'
- record: 'synapse_federation_server_received'
labels:
type: "PDU"
expr: 'synapse_federation_server_received_pdus + 0'
- record: 'synapse_federation_server_received'
labels:
type: "Query"
expr: 'sum(synapse_federation_server_received_queries) by (job)'
- record: 'synapse_federation_transaction_queue_pending'
labels:
type: "EDU"
expr: 'synapse_federation_transaction_queue_pending_edus + 0'
- record: 'synapse_federation_transaction_queue_pending'
labels:
type: "PDU"
expr: 'synapse_federation_transaction_queue_pending_pdus + 0'

View File

@ -298,10 +298,6 @@ It can be used like this:
# this will now be logged against the request context
logger.debug("Request handling complete")
XXX: I think ``preserve_context_over_fn`` is supposed to do the first option,
but the fact that it does ``preserve_context_over_deferred`` on its results
means that its use is fraught with difficulty.
Passing synapse deferreds into third-party functions
----------------------------------------------------

View File

@ -1,11 +1,15 @@
Scaling synapse via workers
---------------------------
===========================
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All of the below is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
@ -16,6 +20,16 @@ TCP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state.
Configuration
-------------
To make effective use of the workers, you will need to configure an HTTP
reverse-proxy such as nginx or haproxy, which will direct incoming requests to
the correct worker, or to the main synapse instance. Note that this includes
requests made to the federation port. The caveats regarding running a
reverse-proxy on the federation port still apply (see
https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port).
To enable workers, you need to add a replication listener to the master synapse, e.g.::
listeners:
@ -27,26 +41,19 @@ Under **no circumstances** should this replication API listener be exposed to th
public internet; it currently implements no authentication whatsoever and is
unencrypted.
You then create a set of configs for the various worker processes. These should be
worker configuration files should be stored in a dedicated subdirectory, to allow
synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
* synapse.app.client_reader - handles client API endpoints like /publicRooms
You then create a set of configs for the various worker processes. These
should be worker configuration files, and should be stored in a dedicated
subdirectory, to allow synctl to manipulate them.
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_host
and worker_replication_port).
You must specify the type of worker application (``worker_app``). The currently
available worker applications are listed below. You must also specify the
replication endpoint that it's talking to on the main synapse process
(``worker_replication_host`` and ``worker_replication_port``).
For instance::
@ -68,11 +75,11 @@ For instance::
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
plain HTTP ``/sync`` endpoint on port 8083 separately from the ``/sync`` endpoint provided
by the main synapse.
Obviously you should configure your loadbalancer to route the /sync endpoint to
the synchrotron instance(s) in this instance.
Obviously you should configure your reverse-proxy to route the relevant
endpoints to the worker (``localhost:8083`` in the above example).
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
@ -89,6 +96,114 @@ To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!
Available worker applications
-----------------------------
``synapse.app.pusher``
~~~~~~~~~~~~~~~~~~~~~~
Handles sending push notifications to sygnal and email. Doesn't handle any
REST endpoints itself, but you should set ``start_pushers: False`` in the
shared configuration file to stop the main synapse sending these notifications.
Note this worker cannot be load-balanced: only one instance should be active.
``synapse.app.synchrotron``
~~~~~~~~~~~~~~~~~~~~~~~~~~~
The synchrotron handles ``sync`` requests from clients. In particular, it can
handle REST endpoints matching the following regular expressions::
^/_matrix/client/(v2_alpha|r0)/sync$
^/_matrix/client/(api/v1|v2_alpha|r0)/events$
^/_matrix/client/(api/v1|r0)/initialSync$
^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$
The above endpoints should all be routed to the synchrotron worker by the
reverse-proxy configuration.
It is possible to run multiple instances of the synchrotron to scale
horizontally. In this case the reverse-proxy should be configured to
load-balance across the instances, though it will be more efficient if all
requests from a particular user are routed to a single instance. Extracting
a userid from the access token is currently left as an exercise for the reader.
``synapse.app.appservice``
~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles sending output traffic to Application Services. Doesn't handle any
REST endpoints itself, but you should set ``notify_appservices: False`` in the
shared configuration file to stop the main synapse sending these notifications.
Note this worker cannot be load-balanced: only one instance should be active.
``synapse.app.federation_reader``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles a subset of federation endpoints. In particular, it can handle REST
endpoints matching the following regular expressions::
^/_matrix/federation/v1/event/
^/_matrix/federation/v1/state/
^/_matrix/federation/v1/state_ids/
^/_matrix/federation/v1/backfill/
^/_matrix/federation/v1/get_missing_events/
^/_matrix/federation/v1/publicRooms
The above endpoints should all be routed to the federation_reader worker by the
reverse-proxy configuration.
``synapse.app.federation_sender``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles sending federation traffic to other servers. Doesn't handle any
REST endpoints itself, but you should set ``send_federation: False`` in the
shared configuration file to stop the main synapse sending this traffic.
Note this worker cannot be load-balanced: only one instance should be active.
``synapse.app.media_repository``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles the media repository. It can handle all endpoints starting with::
/_matrix/media/
You should also set ``enable_media_repo: False`` in the shared configuration
file to stop the main synapse running background jobs related to managing the
media repository.
Note this worker cannot be load-balanced: only one instance should be active.
``synapse.app.client_reader``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles client API endpoints. It can handle REST endpoints matching the
following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
``synapse.app.user_dir``
~~~~~~~~~~~~~~~~~~~~~~~~
Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
``synapse.app.frontend_proxy``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Proxies some frequently-requested client endpoints to add caching and remove
load from the main synapse. It can handle REST endpoints matching the following
regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/keys/upload
It will proxy any requests it cannot handle to the main synapse instance. It
must therefore be configured with the location of the main instance, via
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
file. For example::
worker_main_http_uri: http://127.0.0.1:8008

View File

@ -123,15 +123,25 @@ def lookup(destination, path):
except:
return "https://%s:%d%s" % (destination, 8448, path)
def get_json(origin_name, origin_key, destination, path):
request_json = {
"method": "GET",
def request_json(method, origin_name, origin_key, destination, path, content):
if method is None:
if content is None:
method = "GET"
else:
method = "POST"
json_to_sign = {
"method": method,
"uri": path,
"origin": origin_name,
"destination": destination,
}
signed_json = sign_json(request_json, origin_key, origin_name)
if content is not None:
json_to_sign["content"] = json.loads(content)
signed_json = sign_json(json_to_sign, origin_key, origin_name)
authorization_headers = []
@ -145,10 +155,12 @@ def get_json(origin_name, origin_key, destination, path):
dest = lookup(destination, path)
print ("Requesting %s" % dest, file=sys.stderr)
result = requests.get(
dest,
result = requests.request(
method=method,
url=dest,
headers={"Authorization": authorization_headers[0]},
verify=False,
data=content,
)
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
return result.json()
@ -186,6 +198,17 @@ def main():
"connect appropriately.",
)
parser.add_argument(
"-X", "--method",
help="HTTP method to use for the request. Defaults to GET if --data is"
"unspecified, POST if it is."
)
parser.add_argument(
"--body",
help="Data to send as the body of the HTTP request"
)
parser.add_argument(
"path",
help="request path. We will add '/_matrix/federation/v1/' to this."
@ -199,8 +222,11 @@ def main():
with open(args.signing_key_path) as f:
key = read_signing_keys(f)[0]
result = get_json(
args.server_name, key, args.destination, "/_matrix/federation/v1/" + args.path
result = request_json(
args.method,
args.server_name, key, args.destination,
"/_matrix/federation/v1/" + args.path,
content=args.body,
)
json.dump(result, sys.stdout)

45
scripts/sync_room_to_group.pl Executable file
View File

@ -0,0 +1,45 @@
#!/usr/bin/env perl
use strict;
use warnings;
use JSON::XS;
use LWP::UserAgent;
use URI::Escape;
if (@ARGV < 4) {
die "usage: $0 <homeserver url> <access_token> <room_id|room_alias> <group_id>\n";
}
my ($hs, $access_token, $room_id, $group_id) = @ARGV;
my $ua = LWP::UserAgent->new();
$ua->timeout(10);
if ($room_id =~ /^#/) {
$room_id = uri_escape($room_id);
$room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id};
}
my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ];
my $group_users = [
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}),
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}),
];
die "refusing to sync from empty room" unless (@$room_users);
die "refusing to sync to empty group" unless (@$group_users);
my $diff = {};
foreach my $user (@$room_users) { $diff->{$user}++ }
foreach my $user (@$group_users) { $diff->{$user}-- }
foreach my $user (keys %$diff) {
if ($diff->{$user} == 1) {
warn "inviting $user";
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
}
elsif ($diff->{$user} == -1) {
warn "removing $user";
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
}
}

View File

@ -270,7 +270,11 @@ class Auth(object):
rights (str): The operation being performed; the access token must
allow this.
Returns:
dict : dict that includes the user and the ID of their access token.
Deferred[dict]: dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
AuthError if no user by that token exists or the token is invalid.
"""

View File

@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
pass
class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete
(This indicates we should return a 401 with 'result' as the body)
Attributes:
result (dict): the server response to the request, which should be
passed back to the client
"""
def __init__(self, result):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete",
)
self.result = result
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):

View File

@ -43,7 +43,6 @@ from synapse.rest import ClientRestResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage import are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
@ -195,14 +194,19 @@ class SynapseHomeServer(HomeServer):
})
if name in ["media", "federation", "client"]:
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
if self.get_config().enable_media_repo:
media_repo = self.get_media_repository_resource()
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
elif name == "media":
raise ConfigError(
"'media' resource conflicts with enable_media_repo=False",
)
if name in ["keys", "federation"]:
resources.update({

View File

@ -35,7 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
@ -89,7 +88,7 @@ class MediaRepositoryServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
media_repo = self.get_media_repository_resource()
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
@ -151,6 +150,13 @@ def start(config_options):
assert config.worker_app == "synapse.app.media_repository"
if config.enable_media_repo:
_base.quit_with_error(
"enable_media_repo must be disabled in the main synapse process\n"
"before the media repo can be run in a separate worker.\n"
"Please add ``enable_media_repo: false`` to the main config\n"
)
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts

View File

@ -340,11 +340,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
# NB this is a SynchrotronPresence, not a normal PresenceHandler
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()
self.presence_handler.sync_callback = self.send_user_sync
def on_rdata(self, stream_name, token, rows):
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)

View File

@ -14,6 +14,7 @@
# limitations under the License.
from synapse.api.constants import EventTypes
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.types import GroupID, get_domain_from_id
from twisted.internet import defer
@ -81,12 +82,13 @@ class ApplicationService(object):
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None, rate_limited=True):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces)
self.id = id
@ -125,6 +127,24 @@ class ApplicationService(object):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
group_id = regex_obj.get("group_id")
if group_id:
if not isinstance(group_id, str):
raise ValueError(
"Expected string for 'group_id' in ns '%s'" % ns
)
try:
GroupID.from_string(group_id)
except Exception:
raise ValueError(
"Expected valid group ID for 'group_id' in ns '%s'" % ns
)
if get_domain_from_id(group_id) != self.server_name:
raise ValueError(
"Expected 'group_id' to be this host in ns '%s'" % ns
)
regex = regex_obj.get("regex")
if isinstance(regex, basestring):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
@ -251,6 +271,21 @@ class ApplicationService(object):
if regex_obj["exclusive"]
]
def get_groups_for_user(self, user_id):
"""Get the groups that this user is associated with by this AS
Args:
user_id (str): The ID of the user.
Returns:
iterable[str]: an iterable that yields group_id strings.
"""
return (
regex_obj["group_id"]
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
)
def is_rate_limited(self):
return self.rate_limited

View File

@ -154,6 +154,7 @@ def _load_appservice(hostname, as_info, config_filename):
)
return ApplicationService(
token=as_info["as_token"],
hostname=hostname,
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],

View File

@ -36,6 +36,7 @@ from .workers import WorkerConfig
from .push import PushConfig
from .spam_checker import SpamCheckerConfig
from .groups import GroupsConfig
from .user_directory import UserDirectoryConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@ -44,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig, GroupsConfig,):
SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
pass

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,28 +19,43 @@ from ._base import Config
class PushConfig(Config):
def read_config(self, config):
self.push_redact_content = False
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
# There was a a 'redact_content' setting but mistakenly read from the
# 'email'section'. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade.
if push_config.get("redact_content") is not None:
print(
"The push.redact_content content option has never worked. "
"Please set push.include_content if you want this behaviour"
)
# Now check for the one in the 'email' section and honour it,
# with a warning.
push_config = config.get("email", {})
self.push_redact_content = push_config.get("redact_content", False)
redact_content = push_config.get("redact_content")
if redact_content is not None:
print(
"The 'email.redact_content' option is deprecated: "
"please set push.include_content instead"
)
self.push_include_content = not redact_content
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Control how push messages are sent to google/apple to notifications.
# Normally every message said in a room with one or more people using
# mobile devices will be posted to a push server hosted by matrix.org
# which is registered with google and apple in order to allow push
# notifications to be sent to these mobile devices.
#
# Setting redact_content to true will make the push messages contain no
# message content which will provide increased privacy. This is a
# temporary solution pending improvements to Android and iPhone apps
# to get content from the app rather than the notification.
#
# Clients requesting push notifications can either have the body of
# the message sent in the notification poke along with other details
# like the sender, or just the event ID and room ID (`event_id_only`).
# If clients choose the former, this option controls whether the
# notification request includes the content of the event (other details
# like the sender are still included). For `event_id_only` push, it
# has no effect.
# For modern android devices the notification content will still appear
# because it is loaded by the app. iPhone, however will send a
# notification saying only that a message arrived and who it came from.
#
#push:
# redact_content: false
# include_content: true
"""

View File

@ -41,6 +41,12 @@ class ServerConfig(Config):
# false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True)
# whether to enable the media repository endpoints. This should be set
# to false if the media repository is running as a separate endpoint;
# doing so ensures that we will not run cache cleanup jobs on the
# master, potentially causing inconsistency.
self.enable_media_repo = config.get("enable_media_repo", True)
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
# Whether we should block invites sent to users on this server

View File

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class UserDirectoryConfig(Config):
"""User Directory Configuration
Configuration for the behaviour of the /user_directory API
"""
def read_config(self, config):
self.user_directory_search_all_users = False
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_all_users = (
user_directory_config.get("search_all_users", False)
)
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# User Directory configuration
#
# 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible
# in public rooms. Defaults to false. If you set it True, you'll have to run
# UPDATE user_directory_stream_pos SET stream_id = NULL;
# on your database to tell it to rebuild the user_directory search indexes.
#
#user_directory:
# search_all_users: false
"""

View File

@ -32,15 +32,22 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event, hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
if name not in event.hashes:
# some malformed events lack a 'hashes'. Protect against it being missing
# or a weird type by basically treating it the same as an unhashed event.
hashes = event.get("hashes")
if not isinstance(hashes, dict):
raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
if name not in hashes:
raise SynapseError(
400,
"Algorithm %s not in hashes %s" % (
name, list(event.hashes),
name, list(hashes),
),
Codes.UNAUTHORIZED,
)
message_hash_base64 = event.hashes[name]
message_hash_base64 = hashes[name]
try:
message_hash_bytes = decode_base64(message_hash_base64)
except Exception:

View File

@ -25,7 +25,7 @@ from synapse.api.errors import (
from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.events import FrozenEvent, builder
import synapse.metrics
@ -420,7 +420,7 @@ class FederationClient(FederationBase):
for e_id in batch
]
res = yield preserve_context_over_deferred(
res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:

View File

@ -20,7 +20,7 @@ from .persistence import TransactionActions
from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException
from synapse.util import logcontext
from synapse.util import logcontext, PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
@ -146,7 +146,6 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
@defer.inlineCallbacks
def notify_new_events(self, current_id):
"""This gets called when we have some new events we might want to
send out to other servers.
@ -156,6 +155,13 @@ class TransactionQueue(object):
if self._is_processing:
return
# fire off a processing loop in the background. It's likely it will
# outlast the current request, so run it in the sentinel logcontext.
with PreserveLoggingContext():
self._process_event_queue_loop()
@defer.inlineCallbacks
def _process_event_queue_loop(self):
try:
self._is_processing = True
while True:

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
import logging
@ -159,7 +159,7 @@ class ApplicationServicesHandler(object):
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield preserve_context_over_deferred(defer.DeferredList([
results = yield make_deferred_yieldable(defer.DeferredList([
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
for service in services
], consumeErrors=True))

View File

@ -17,7 +17,10 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.api.errors import (
AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
SynapseError,
)
from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor
@ -46,7 +49,6 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
@ -75,15 +77,76 @@ class AuthHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
login_types = set()
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
login_types = []
if self._password_enabled:
login_types.add(LoginType.PASSWORD)
login_types.append(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
login_types.update(
provider.get_supported_login_types().keys()
)
self._supported_login_types = frozenset(login_types)
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
self._supported_login_types = login_types
@defer.inlineCallbacks
def validate_user_via_ui_auth(self, requester, request_body, clientip):
"""
Checks that the user is who they claim to be, via a UI auth.
This is used for things like device deletion and password reset where
the user already has a valid access token, but we want to double-check
that it isn't stolen by re-authenticating them.
Args:
requester (Requester): The user, as given by the access token
request_body (dict): The body of the request sent by the client
clientip (str): The IP address of the client.
Returns:
defer.Deferred[dict]: the parameters for this request (which may
have been given only in a previous call).
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
any of the permitted login flows
AuthError if the client has completed a login flow, and it gives
a different user to `requester`
"""
# build a list of supported flows
flows = [
[login_type] for login_type in self._supported_login_types
]
result, params, _ = yield self.check_auth(
flows, request_body, clientip,
)
# find the completed login type
for login_type in self._supported_login_types:
if login_type not in result:
continue
user_id = result[login_type]
break
else:
# this can't happen
raise Exception(
"check_auth returned True but no successful login type",
)
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
defer.returnValue(params)
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@ -95,26 +158,36 @@ class AuthHandler(BaseHandler):
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
If no auth flows have been completed successfully, raises an
InteractiveAuthIncompleteError. To handle this, you can use
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
decorator.
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns:
A tuple of (authed, dict, dict, session_id) where authed is true if
the client has successfully completed an auth flow. If it is true
the first dict contains the authenticated credentials of each stage.
defer.Deferred[dict, dict, str]: a deferred tuple of
(creds, params, session_id).
If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.
'creds' contains the authenticated credentials of each stage.
In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).
'params' contains the parameters for this request (which may
have been given only in a previous call).
session_id is the ID of this session, either passed in by the client
or assigned by the call to check_auth
'session_id' is the ID of this session, either passed in by the
client or assigned by this call
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
all the stages in any of the permitted flows.
"""
authdict = None
@ -142,11 +215,8 @@ class AuthHandler(BaseHandler):
clientdict = session['clientdict']
if not authdict:
defer.returnValue(
(
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session),
)
if 'creds' not in session:
@ -157,14 +227,12 @@ class AuthHandler(BaseHandler):
errordict = {}
if 'type' in authdict:
login_type = authdict['type']
if login_type not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED)
try:
result = yield self.checkers[login_type](authdict, clientip)
result = yield self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
except LoginError, e:
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
@ -173,7 +241,7 @@ class AuthHandler(BaseHandler):
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
raise e
raise
# this step failed. Merge the error dict into the response
# so that the client can have another go.
@ -190,12 +258,14 @@ class AuthHandler(BaseHandler):
"Auth completed with creds: %r. Client dict has keys: %r",
creds, clientdict.keys()
)
defer.returnValue((True, creds, clientdict, session['id']))
defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id']))
raise InteractiveAuthIncompleteError(
ret,
)
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@ -268,17 +338,35 @@ class AuthHandler(BaseHandler):
return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
def _check_auth_dict(self, authdict, clientip):
"""Attempt to validate the auth dict provided by a client
user_id = authdict["user"]
password = authdict["password"]
Args:
authdict (object): auth dict provided by the client
clientip (str): IP address of the client
(canonical_id, callback) = yield self.validate_login(user_id, {
"type": LoginType.PASSWORD,
"password": password,
})
Returns:
Deferred: result of the stage verification.
Raises:
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
login_type = authdict['type']
checker = self.checkers.get(login_type)
if checker is not None:
res = yield checker(authdict, clientip)
defer.returnValue(res)
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
user_id = authdict.get("user")
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
defer.returnValue(canonical_id)
@defer.inlineCallbacks
@ -649,41 +737,6 @@ class AuthHandler(BaseHandler):
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id,
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id
)
@defer.inlineCallbacks
def deactivate_account(self, user_id):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
Returns:
Deferred
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
yield self.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
@defer.inlineCallbacks
def delete_access_token(self, access_token):
"""Invalidate a single access token
@ -706,6 +759,12 @@ class AuthHandler(BaseHandler):
access_token=access_token,
)
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"], )
)
@defer.inlineCallbacks
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
device_id=None):
@ -728,13 +787,18 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, device_id in tokens_and_devices:
for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out(
user_id=user_id,
device_id=device_id,
access_token=token,
)
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
user_id, (token_id for _, token_id, _ in tokens_and_devices),
)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case.

View File

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def deactivate_account(self, user_id):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
Returns:
Deferred
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
# first delete any devices belonging to the user, which will also
# delete corresponding access tokens.
yield self._device_handler.delete_all_devices_for_user(user_id)
# then delete any remaining access tokens which weren't associated with
# a device.
yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)

View File

@ -170,13 +170,31 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def delete_all_devices_for_user(self, user_id, except_device_id=None):
"""Delete all of the user's devices
Args:
user_id (str):
except_device_id (str|None): optional device id which should not
be deleted
Returns:
defer.Deferred:
"""
device_map = yield self.store.get_devices_by_user(user_id)
device_ids = device_map.keys()
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
yield self.delete_devices(user_id, device_ids)
@defer.inlineCallbacks
def delete_devices(self, user_id, device_ids):
""" Delete several devices
Args:
user_id (str):
device_ids (str): The list of device IDs to delete
device_ids (List[str]): The list of device IDs to delete
Returns:
defer.Deferred:

View File

@ -375,6 +375,12 @@ class GroupsLocalHandler(object):
def get_publicised_groups_for_user(self, user_id):
if self.hs.is_mine_id(user_id):
result = yield self.store.get_publicised_groups_for_user(user_id)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
for app_service in self.store.get_app_services():
result.extend(app_service.get_groups_for_user(user_id))
defer.returnValue({"groups": result})
else:
result = yield self.transport_client.get_publicised_groups_for_user(
@ -415,4 +421,9 @@ class GroupsLocalHandler(object):
uid
)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
for app_service in self.store.get_app_services():
results[uid].extend(app_service.get_groups_for_user(uid))
defer.returnValue({"users": results})

View File

@ -27,7 +27,7 @@ from synapse.types import (
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@ -163,7 +163,7 @@ class InitialSyncHandler(BaseHandler):
lambda states: states[event.event_id]
)
(messages, token), current_state = yield preserve_context_over_deferred(
(messages, token), current_state = yield make_deferred_yieldable(
defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(

View File

@ -1199,7 +1199,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
)
changed = True
else:
# We expect to be poked occaisonally by the other side.
# We expect to be poked occasionally by the other side.
# This is to protect against forgetful/buggy servers, so that
# no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:

View File

@ -36,6 +36,8 @@ class ProfileHandler(BaseHandler):
"profile", self.on_profile_query
)
self.user_directory_handler = hs.get_user_directory_handler()
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
@defer.inlineCallbacks
@ -139,6 +141,12 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname
)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
@ -183,6 +191,12 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url
)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks

View File

@ -38,6 +38,7 @@ class RegistrationHandler(BaseHandler):
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
@ -165,6 +166,13 @@ class RegistrationHandler(BaseHandler):
),
admin=admin,
)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart)
yield self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
else:
# autogen a sequential user ID
attempts = 0

View File

@ -154,6 +154,8 @@ class RoomListHandler(BaseHandler):
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
logger.info("Getting ordering for %i rooms since %s",
len(room_ids), stream_token)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
@ -181,34 +183,42 @@ class RoomListHandler(BaseHandler):
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
# Actually generate the entries. _append_room_entry_to_chunk will append to
# chunk but will stop if len(chunk) > limit
chunk = []
if limit and not search_filter:
logger.info("After sorting and filtering, %i rooms remain",
len(rooms_to_scan))
# _append_room_entry_to_chunk will append to chunk but will stop if
# len(chunk) > limit
#
# Normally we will generate enough results on the first iteration here,
# but if there is a search filter, _append_room_entry_to_chunk may
# filter some results out, in which case we loop again.
#
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
#
# XXX if there is no limit, we may end up DoSing the server with
# calls to get_current_state_ids for every single room on the
# server. Surely we should cap this somehow?
#
if limit:
step = limit + 1
for i in xrange(0, len(rooms_to_scan), step):
# We iterate here because the vast majority of cases we'll stop
# at first iteration, but occaisonally _append_room_entry_to_chunk
# won't append to the chunk and so we need to loop again.
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan[i:i + step], 10
)
if len(chunk) >= limit + 1:
break
else:
step = len(rooms_to_scan)
chunk = []
for i in xrange(0, len(rooms_to_scan), step):
batch = rooms_to_scan[i:i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan, 5
batch, 5,
)
logger.info("Now %i rooms in result", len(chunk))
if len(chunk) >= limit + 1:
break
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))

View File

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError
from ._base import BaseHandler
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
def __init__(self, hs):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
# we want to log out all of the user's other sessions. First delete
# all his other devices.
yield self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id,
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id,
)

View File

@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo
from synapse.util.metrics import Measure
from synapse.util.async import sleep
from synapse.types import get_localpart_from_id
logger = logging.getLogger(__name__)
class UserDirectoyHandler(object):
class UserDirectoryHandler(object):
"""Handles querying of and keeping updated the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
@ -41,9 +42,10 @@ class UserDirectoyHandler(object):
one public room.
"""
INITIAL_SLEEP_MS = 50
INITIAL_SLEEP_COUNT = 100
INITIAL_BATCH_SIZE = 100
INITIAL_ROOM_SLEEP_MS = 50
INITIAL_ROOM_SLEEP_COUNT = 100
INITIAL_ROOM_BATCH_SIZE = 100
INITIAL_USER_SLEEP_MS = 10
def __init__(self, hs):
self.store = hs.get_datastore()
@ -53,6 +55,7 @@ class UserDirectoyHandler(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory
self.search_all_users = hs.config.user_directory_search_all_users
# When start up for the first time we need to populate the user_directory.
# This is a set of user_id's we've inserted already
@ -110,6 +113,15 @@ class UserDirectoyHandler(object):
finally:
self._is_processing = False
@defer.inlineCallbacks
def handle_local_profile_change(self, user_id, profile):
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
yield self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url, None,
)
@defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
@ -148,16 +160,30 @@ class UserDirectoyHandler(object):
room_ids = yield self.store.get_all_rooms()
logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
num_processed_rooms = 1
num_processed_rooms = 0
for room_id in room_ids:
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
yield self._handle_initial_room(room_id)
num_processed_rooms += 1
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
logger.info("Processed all rooms.")
if self.search_all_users:
num_processed_users = 0
user_ids = yield self.store.get_all_local_users()
logger.info("Doing initial update of user directory. %d users", len(user_ids))
for user_id in user_ids:
# We add profiles for all users even if they don't match the
# include pattern, just in case we want to change it in future
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
yield self._handle_local_user(user_id)
num_processed_users += 1
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
logger.info("Processed all users")
self.initially_handled_users = None
self.initially_handled_users_in_public = None
self.initially_handled_users_share = None
@ -201,8 +227,8 @@ class UserDirectoyHandler(object):
to_update = set()
count = 0
for user_id in user_ids:
if count % self.INITIAL_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id):
count += 1
@ -216,8 +242,8 @@ class UserDirectoyHandler(object):
if user_id == other_user_id:
continue
if count % self.INITIAL_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
count += 1
user_set = (user_id, other_user_id)
@ -237,13 +263,13 @@ class UserDirectoyHandler(object):
else:
self.initially_handled_users_share_private_room.add(user_set)
if len(to_insert) > self.INITIAL_BATCH_SIZE:
if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert,
)
to_insert.clear()
if len(to_update) > self.INITIAL_BATCH_SIZE:
if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.update_users_who_share_room(
room_id, not is_public, to_update,
)
@ -384,15 +410,29 @@ class UserDirectoyHandler(object):
for user_id in users:
yield self._handle_remove_user(room_id, user_id)
@defer.inlineCallbacks
def _handle_local_user(self, user_id):
"""Adds a new local roomless user into the user_directory_search table.
Used to populate up the user index when we have an
user_directory_search_all_users specified.
"""
logger.debug("Adding new local user to dir, %r", user_id)
profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
row = yield self.store.get_user_in_directory(user_id)
if not row:
yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
@defer.inlineCallbacks
def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory
Args:
room_id (str): room_id that user joined or started being public that
room_id (str): room_id that user joined or started being public
user_id (str)
"""
logger.debug("Adding user to dir, %r", user_id)
logger.debug("Adding new user to dir, %r", user_id)
row = yield self.store.get_user_in_directory(user_id)
if not row:
@ -407,7 +447,7 @@ class UserDirectoyHandler(object):
if not row:
yield self.store.add_users_to_public_room(room_id, [user_id])
else:
logger.debug("Not adding user to public dir, %r", user_id)
logger.debug("Not adding new user to public dir, %r", user_id)
# Now we update users who share rooms with users. We do this by getting
# all the current users in the room and seeing which aren't already

View File

@ -362,8 +362,10 @@ def _get_hosts_for_srv_record(dns_client, host):
return res
# no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
d1 = dns_client.lookupAddress(host).addCallbacks(
cb, eb, errbackArgs=("A", ))
d2 = dns_client.lookupIPV6Address(host).addCallbacks(
cb, eb, errbackArgs=("AAAA", ))
results = yield defer.DeferredList(
[d1, d2], consumeErrors=True)

View File

@ -28,6 +28,7 @@ from canonicaljson import (
)
from twisted.internet import defer
from twisted.python import failure
from twisted.web import server, resource
from twisted.web.server import NOT_DONE_YET
from twisted.web.util import redirectTo
@ -131,12 +132,17 @@ def wrap_request_handler(request_handler, include_metrics=False):
version_string=self.version_string,
)
except Exception:
logger.exception(
"Failed handle request %s.%s on %r: %r",
# failure.Failure() fishes the original Failure out
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
logger.error(
"Failed handle request %s.%s on %r: %r: %s",
request_handler.__module__,
request_handler.__name__,
self,
request
request,
f.getTraceback().rstrip(),
)
respond_with_json(
request,

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.types import UserID
@ -81,6 +82,7 @@ class ModuleApi(object):
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)
@defer.inlineCallbacks
def invalidate_access_token(self, access_token):
"""Invalidate an access token for a user
@ -94,8 +96,16 @@ class ModuleApi(object):
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
return self._auth_handler.delete_access_token(access_token)
# see if the access token corresponds to a device
user_info = yield self._auth.get_user_by_access_token(access_token)
device_id = user_info.get("device_id")
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
yield self.hs.get_device_handler().delete_device(user_id, device_id)
else:
# no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token)
def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection

View File

@ -255,9 +255,7 @@ class Notifier(object):
)
if self.federation_sender:
preserve_fn(self.federation_sender.notify_new_events)(
room_stream_id
)
self.federation_sender.notify_new_events(room_stream_id)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
@ -297,8 +295,7 @@ class Notifier(object):
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
with PreserveLoggingContext():
self.notify_replication()
self.notify_replication()
@defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@ -516,8 +513,14 @@ class Notifier(object):
self.replication_deferred = ObservableDeferred(defer.Deferred())
deferred.callback(None)
for cb in self.replication_callbacks:
preserve_fn(cb)()
# the callbacks may well outlast the current request, so we run
# them in the sentinel logcontext.
#
# (ideally it would be up to the callbacks to know if they were
# starting off background processes and drop the logcontext
# accordingly, but that requires more changes)
for cb in self.replication_callbacks:
cb()
@defer.inlineCallbacks
def wait_for_replication(self, callback, timeout):

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -295,7 +296,7 @@ class HttpPusher(object):
if event.type == 'm.room.member':
d['notification']['membership'] = event.content['membership']
d['notification']['user_is_target'] = event.state_key == self.user_id
if not self.hs.config.push_redact_content and 'content' in event:
if self.hs.config.push_include_content and 'content' in event:
d['notification']['content'] = event.content
# We no longer send aliases separately, instead, we send the human

View File

@ -17,7 +17,7 @@
from twisted.internet import defer
from .pusher import PusherFactory
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.async import run_on_reactor
import logging
@ -103,19 +103,25 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access tokens id %r",
user_id, except_access_token_id
)
for p in all:
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
def remove_pushers_by_access_token(self, user_id, access_tokens):
"""Remove the pushers for a given user corresponding to a set of
access_tokens.
Args:
user_id (str): user to remove pushers for
access_tokens (Iterable[int]): access token *ids* to remove pushers
for
"""
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
if p['access_token'] in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
yield self.remove_pusher(
p['app_id'], p['pushkey'], p['user_name'],
)
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
@ -136,7 +142,7 @@ class PusherPool:
)
)
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
yield make_deferred_yieldable(defer.gatherResults(deferreds))
except Exception:
logger.exception("Exception in pusher on_new_notifications")
@ -161,7 +167,7 @@ class PusherPool:
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
)
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
yield make_deferred_yieldable(defer.gatherResults(deferreds))
except Exception:
logger.exception("Exception in pusher on_new_receipts")

View File

@ -12,20 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
import logging
from synapse.api.constants import EventTypes
from synapse.storage import DataStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.state import StateGroupReadStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import logging
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__)
@ -39,7 +37,7 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(BaseSlavedStore):
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
@ -90,25 +88,9 @@ class SlavedEventStore(BaseSlavedStore):
_get_unread_counts_by_pos_txn = (
DataStore._get_unread_counts_by_pos_txn.__func__
)
_get_state_group_for_events = (
StateStore.__dict__["_get_state_group_for_events"]
)
_get_state_group_for_event = (
StateStore.__dict__["_get_state_group_for_event"]
)
_get_state_groups_from_groups = (
StateStore.__dict__["_get_state_groups_from_groups"]
)
_get_state_groups_from_groups_txn = (
DataStore._get_state_groups_from_groups_txn.__func__
)
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
get_current_state_ids = (
StateStore.__dict__["get_current_state_ids"]
)
get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__
@ -134,12 +116,6 @@ class SlavedEventStore(BaseSlavedStore):
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = (
@ -169,10 +145,7 @@ class SlavedEventStore(BaseSlavedStore):
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__

View File

@ -216,11 +216,12 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
@defer.inlineCallbacks
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
self.presence_handler.update_external_syncs_row(
yield self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms,
)
@ -244,11 +245,12 @@ class ReplicationStreamer(object):
getattr(self.store, cache_func).invalidate(tuple(keys))
@measure_func("repl.on_user_ip")
@defer.inlineCallbacks
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
self.store.insert_client_ip(
yield self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen,
)

View File

@ -137,8 +137,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self._auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__(hs)
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
@ -149,7 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
yield self._auth_handler.deactivate_account(target_user_id)
yield self._deactivate_account_handler.deactivate_account(target_user_id)
defer.returnValue((200, {}))
@ -309,7 +309,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
@ -330,7 +330,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
logger.info("new_password: %r", new_password)
yield self.auth_handler.set_password(
yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
defer.returnValue((200, {}))

View File

@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
access_token = get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
try:
requester = yield self.auth.get_user_by_req(request)
except AuthError:
# this implies the access token has already been deleted.
pass
else:
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
requester.user.to_string(), requester.device_id)
defer.returnValue((200, {}))
@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
super(LogoutAllRestServlet, self).__init__(hs)
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# first delete all of the user's devices
yield self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with
# devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {}))

View File

@ -15,12 +15,13 @@
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
import logging
import re
import logging
from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
logger = logging.getLogger(__name__)
@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'],
filter_timeline_limit)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns a deferred (errcode, body) response
and adds exception handling to turn a InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
# ...
yield self.auth_handler.check_auth
"""
def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs)
res.addErrback(_catch_incomplete_interactive_auth)
return res
return wrapped
def _catch_incomplete_interactive_auth(f):
"""helper for interactive_auth_handler
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
Args:
f (failure.Failure):
"""
f.trap(InteractiveAuthIncompleteError)
return 401, f.value.result

View File

@ -19,14 +19,14 @@ from twisted.internet import defer
from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet, assert_params_in_request,
parse_json_object_from_request,
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@ -98,56 +98,61 @@ class PasswordRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY],
[LoginType.MSISDN],
], body, self.hs.get_ip_from_request(request))
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
#
# In the first case, we offer a couple of means of identifying
# themselves (email and msisdn, though it's unclear if msisdn actually
# works).
#
# In the second case, we require a password to confirm their identity.
if not authed:
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
if has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
if user_id != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
elif LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
threepid['medium'], threepid['address']
params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
)
if not threepid_user_id:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
user_id = requester.user.to_string()
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
requester = None
result, params, _ = yield self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
body, self.hs.get_ip_from_request(request),
)
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
threepid['medium'], threepid['address']
)
if not threepid_user_id:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password']
yield self.auth_handler.set_password(
yield self._set_password_handler.set_password(
user_id, new_password, requester
)
@ -161,52 +166,32 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__()
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
# if the caller provides an access token, it ought to be valid.
requester = None
if has_access_token(request):
requester = yield self.auth.get_user_by_req(
request,
) # type: synapse.types.Requester
requester = yield self.auth.get_user_by_req(request)
# allow ASes to dectivate their own users
if requester and requester.app_service:
yield self.auth_handler.deactivate_account(
if requester.app_service:
yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string()
)
defer.returnValue((200, {}))
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
if LoginType.PASSWORD in result:
user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in
if requester is None:
raise SynapseError(
400,
"Deactivate account requires an access_token",
errcode=Codes.MISSING_TOKEN
)
if requester.user.to_string() != user_id:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
yield self.auth_handler.deactivate_account(user_id)
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
)
yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string(),
)
defer.returnValue((200, {}))

View File

@ -17,9 +17,9 @@ import logging
from twisted.internet import defer
from synapse.api import constants, errors
from synapse.api import errors
from synapse.http import servlet
from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
try:
body = servlet.parse_json_object_from_request(request)
except errors.SynapseError as e:
@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
authed, result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
)
if not authed:
defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices(
requester.user.to_string(),
body['devices'],
@ -115,6 +114,7 @@ class DeviceRestServlet(servlet.RestServlet):
)
defer.returnValue((200, device))
@interactive_auth_handler
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
@ -130,19 +130,13 @@ class DeviceRestServlet(servlet.RestServlet):
else:
raise
authed, result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
)
if not authed:
defer.returnValue((401, result))
# check that the UI auth matched the access token
user_id = result[constants.LoginType.PASSWORD]
if user_id != requester.user.to_string():
raise errors.AuthError(403, "Invalid auth")
yield self.device_handler.delete_device(user_id, device_id)
yield self.device_handler.delete_device(
requester.user.to_string(), device_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks

View File

@ -38,7 +38,7 @@ class GroupServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
group_description = yield self.groups_handler.get_group_profile(
@ -74,7 +74,7 @@ class GroupSummaryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
get_group_summary = yield self.groups_handler.get_group_summary(
@ -148,7 +148,7 @@ class GroupCategoryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_category(
@ -200,7 +200,7 @@ class GroupCategoriesServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_categories(
@ -225,7 +225,7 @@ class GroupRoleServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_role(
@ -277,7 +277,7 @@ class GroupRolesServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_roles(
@ -348,7 +348,7 @@ class GroupRoomServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
@ -369,7 +369,7 @@ class GroupUsersServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
@ -672,7 +672,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
yield self.auth.get_user_by_req(request)
yield self.auth.get_user_by_req(request, allow_guest=True)
result = yield self.groups_handler.get_publicised_groups_for_user(
user_id
@ -697,7 +697,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
yield self.auth.get_user_by_req(request)
yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
user_ids = content["user_ids"]
@ -724,7 +724,7 @@ class GroupsForUserServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_joined_groups(requester_user_id)

View File

@ -27,7 +27,7 @@ from synapse.http.servlet import (
)
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler
import logging
import hmac
@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, auth_result))
return
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",

View File

@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet):
"r0.0.1",
"r0.1.0",
"r0.2.0",
"r0.3.0",
]
})

View File

@ -25,7 +25,8 @@ from synapse.util.stringutils import random_string
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.http.client import SpiderHttpClient
from synapse.http.server import (
request_handler, respond_with_json_bytes
request_handler, respond_with_json_bytes,
respond_with_json,
)
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
@ -78,6 +79,9 @@ class PreviewUrlResource(Resource):
self._expire_url_cache_data, 10 * 1000
)
def render_OPTIONS(self, request):
return respond_with_json(request, 200, {}, send_cors=True)
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@ -348,11 +352,16 @@ class PreviewUrlResource(Resource):
def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
now = self.clock.time_msec()
logger.info("Running url preview cache expiry")
if not (yield self.store.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries
media_ids = yield self.store.get_expired_url_cache(now)
@ -426,8 +435,7 @@ class PreviewUrlResource(Resource):
yield self.store.delete_url_cache_media(removed_media)
if removed_media:
logger.info("Deleted %d media from url cache", len(removed_media))
logger.info("Deleted %d media from url cache", len(removed_media))
def decode_and_calc_og(body, media_uri, request_encoding=None):

View File

@ -39,18 +39,20 @@ from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler
from synapse.groups.groups_server import GroupsServerHandler
@ -60,7 +62,10 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
)
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
@ -90,17 +95,12 @@ class HomeServer(object):
"""
DEPENDENCIES = [
'config',
'clock',
'http_client',
'db_pool',
'persistence_service',
'replication_layer',
'datastore',
'handlers',
'v1auth',
'auth',
'rest_servlet_factory',
'state_handler',
'presence_handler',
'sync_handler',
@ -117,19 +117,10 @@ class HomeServer(object):
'application_service_handler',
'device_message_handler',
'profile_handler',
'deactivate_account_handler',
'set_password_handler',
'notifier',
'distributor',
'client_resource',
'resource_for_federation',
'resource_for_static_content',
'resource_for_web_client',
'resource_for_content_repo',
'resource_for_server_key',
'resource_for_server_key_v2',
'resource_for_media_repository',
'resource_for_metrics',
'event_sources',
'ratelimiter',
'keyring',
'pusherpool',
'event_builder_factory',
@ -137,6 +128,7 @@ class HomeServer(object):
'http_client_context_factory',
'simple_http_client',
'media_repository',
'media_repository_resource',
'federation_transport_client',
'federation_sender',
'receipts_handler',
@ -183,6 +175,21 @@ class HomeServer(object):
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
def get_clock(self):
return self.clock
def get_datastore(self):
return self.datastore
def get_config(self):
return self.config
def get_distributor(self):
return self.distributor
def get_ratelimiter(self):
return self.ratelimiter
def build_replication_layer(self):
return initialize_http_replication(self)
@ -265,6 +272,12 @@ class HomeServer(object):
def build_profile_handler(self):
return ProfileHandler(self)
def build_deactivate_account_handler(self):
return DeactivateAccountHandler(self)
def build_set_password_handler(self):
return SetPasswordHandler(self)
def build_event_sources(self):
return EventSources(self)
@ -294,6 +307,11 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
return MediaRepositoryResource(self)
def build_media_repository(self):
return MediaRepository(self)
@ -321,7 +339,7 @@ class HomeServer(object):
return ActionGenerator(self)
def build_user_directory_handler(self):
return UserDirectoyHandler(self)
return UserDirectoryHandler(self)
def build_groups_local_handler(self):
return GroupsLocalHandler(self)

View File

@ -3,10 +3,14 @@ import synapse.federation.transaction_queue
import synapse.federation.transport.client
import synapse.handlers
import synapse.handlers.auth
import synapse.handlers.deactivate_account
import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.storage
import synapse.handlers.set_password
import synapse.rest.media.v1.media_repository
import synapse.state
import synapse.storage
class HomeServer(object):
def get_auth(self) -> synapse.api.auth.Auth:
@ -30,8 +34,20 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler:
pass
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
pass
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
pass
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
pass
def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
pass
def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass

View File

@ -16,8 +16,6 @@ import logging
from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine
import synapse.metrics
@ -180,10 +178,6 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
@ -475,23 +469,53 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert", lock=True):
"""
`lock` should generally be set to True (the default), but can be set
to False if either of the following are true:
* there is a UNIQUE INDEX on the key columns. In this case a conflict
will cause an IntegrityError in which case this function will retry
the update.
* we somehow know that we are the only thread which will be updating
this table.
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): key/values to use when inserting
insertion_values (dict): additional key/values to use only when
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
Deferred(bool): True if a new entry was created, False if an
existing one was updated.
"""
return self.runInteraction(
desc,
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
lock
)
attempts = 0
while True:
try:
result = yield self.runInteraction(
desc,
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
lock=lock
)
defer.returnValue(result)
except self.database_engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
# don't retry forever, because things other than races
# can cause IntegrityErrors
raise
# presumably we raced with another transaction: let's retry.
logger.warn(
"IntegrityError when upserting into %s; retrying: %s",
table, e
)
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
lock=True):
@ -499,7 +523,7 @@ class SQLBaseStore(object):
if lock:
self.database_engine.lock_table(txn, table)
# Try to update
# First try to update.
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
@ -508,28 +532,29 @@ class SQLBaseStore(object):
sqlargs = values.values() + keyvalues.values()
txn.execute(sql, sqlargs)
if txn.rowcount == 0:
# We didn't update and rows so insert a new one
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
txn.execute(sql, allvalues.values())
return True
else:
if txn.rowcount > 0:
# successfully updated at least one row.
return False
# We didn't update any rows so insert a new one
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
txn.execute(sql, allvalues.values())
# successfully inserted
return True
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
return a single row, returning multiple columns from it.
Args:
table : string giving the table name
@ -582,20 +607,18 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
sql = (
"SELECT %(retcol)s FROM %(table)s %(where)s"
"SELECT %(retcol)s FROM %(table)s"
) % {
"retcol": retcol,
"table": table,
"where": where,
}
txn.execute(sql, keyvalues.values())
if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
txn.execute(sql, keyvalues.values())
else:
txn.execute(sql)
return [r[0] for r in txn]
@ -606,7 +629,7 @@ class SQLBaseStore(object):
Args:
table (str): table name
keyvalues (dict): column names and values to select the rows with
keyvalues (dict|None): column names and values to select the rows with
retcol (str): column whos value we wish to retrieve.
Returns:

View File

@ -222,9 +222,12 @@ class AccountDataStore(SQLBaseStore):
"""
content_json = json.dumps(content)
def add_account_data_txn(txn, next_id):
self._simple_upsert_txn(
txn,
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so _simple_upsert will
# retry if there is a conflict.
yield self._simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
"user_id": user_id,
@ -234,19 +237,20 @@ class AccountDataStore(SQLBaseStore):
values={
"stream_id": next_id,
"content": content_json,
}
},
lock=False,
)
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
self._update_max_stream_id(txn, next_id)
with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_room_account_data", add_account_data_txn, next_id
)
# it's theoretically possible for the above to succeed and the
# below to fail - in which case we might reuse a stream id on
# restart, and the above update might not get propagated. That
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@ -263,9 +267,12 @@ class AccountDataStore(SQLBaseStore):
"""
content_json = json.dumps(content)
def add_account_data_txn(txn, next_id):
self._simple_upsert_txn(
txn,
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so _simple_upsert will retry if
# there is a conflict.
yield self._simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={
"user_id": user_id,
@ -274,40 +281,46 @@ class AccountDataStore(SQLBaseStore):
values={
"stream_id": next_id,
"content": content_json,
}
},
lock=False,
)
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
# it's theoretically possible for the above to succeed and the
# below to fail - in which case we might reuse a stream id on
# restart, and the above update might not get propagated. That
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(
user_id, next_id,
)
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
txn.call_after(
self.get_global_account_data_by_type_for_user.invalidate,
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
(account_data_type, user_id,)
)
self._update_max_stream_id(txn, next_id)
with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id
)
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):
def _update_max_stream_id(self, next_id):
"""Update the max stream_id
Args:
txn: The database cursor
next_id(int): The the revision to advance to.
"""
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
def _update(txn):
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
return self.runInteraction(
"update_account_data_max_stream_id",
_update,
)
txn.execute(update_max_id_sql, (next_id, next_id))
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):

View File

@ -85,6 +85,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
self._all_done = False
@defer.inlineCallbacks
def start_doing_background_updates(self):
@ -106,8 +107,40 @@ class BackgroundUpdateStore(SQLBaseStore):
"No more background updates to do."
" Unscheduling background update task."
)
self._all_done = True
defer.returnValue(None)
@defer.inlineCallbacks
def has_completed_background_updates(self):
"""Check if all the background updates have completed
Returns:
Deferred[bool]: True if all background updates have completed
"""
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
defer.returnValue(True)
# obviously, if we have things in our queue, we're not done.
if self._background_update_queue:
defer.returnValue(False)
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
updates = yield self._simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
desc="check_background_updates",
)
if not updates:
self._all_done = True
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on the next queued background update
@ -269,7 +302,7 @@ class BackgroundUpdateStore(SQLBaseStore):
# Sqlite doesn't support concurrent creation of indexes.
#
# We don't use partial indices on SQLite as it wasn't introduced
# until 3.8, and wheezy has 3.7
# until 3.8, and wheezy and CentOS 7 have 3.7
#
# We assume that sqlite doesn't give us invalid indices; however
# we may still end up with the index existing but the

View File

@ -12,13 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
class MediaRepositoryStore(SQLBaseStore):
class MediaRepositoryStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
self.register_background_index_update(
update_name='local_media_repository_url_idx',
index_name='local_media_repository_url_idx',
table='local_media_repository',
columns=['created_ts'],
where_clause='url_cache IS NOT NULL',
)
def get_default_thumbnails(self, top_level_type, sub_type):
return []

View File

@ -15,6 +15,9 @@
from twisted.internet import defer
from synapse.storage.roommember import ProfileInfo
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
@ -26,6 +29,30 @@ class ProfileStore(SQLBaseStore):
desc="create_profile",
)
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
profile = yield self._simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
)
except StoreError as e:
if e.code == 404:
# no match
defer.returnValue(ProfileInfo(None, None))
return
else:
raise
defer.returnValue(
ProfileInfo(
avatar_url=profile['avatar_url'],
display_name=profile['displayname'],
)
)
def get_profile_displayname(self, user_localpart):
return self._simple_select_one_onecol(
table="profiles",

View File

@ -204,34 +204,35 @@ class PusherStore(SQLBaseStore):
pushkey, pushkey_ts, lang, data, last_stream_ordering,
profile_tag=""):
with self._pushers_id_gen.get_next() as stream_id:
def f(txn):
newly_inserted = self._simple_upsert_txn(
txn,
"pushers",
{
"app_id": app_id,
"pushkey": pushkey,
"user_name": user_id,
},
{
"access_token": access_token,
"kind": kind,
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
"data": encode_canonical_json(data),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
},
)
if newly_inserted:
# get_if_user_has_pusher only cares if the user has
# at least *one* pusher.
txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry
newly_inserted = yield self._simple_upsert(
table="pushers",
keyvalues={
"app_id": app_id,
"pushkey": pushkey,
"user_name": user_id,
},
values={
"access_token": access_token,
"kind": kind,
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
"data": encode_canonical_json(data),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
},
desc="add_pusher",
lock=False,
)
yield self.runInteraction("add_pusher", f)
if newly_inserted:
# get_if_user_has_pusher only cares if the user has
# at least *one* pusher.
self.get_if_user_has_pusher.invalidate(user_id,)
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
@ -243,11 +244,19 @@ class PusherStore(SQLBaseStore):
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
)
self._simple_upsert_txn(
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
self._simple_insert_txn(
txn,
"deleted_pushers",
{"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
{"stream_id": stream_id},
table="deleted_pushers",
values={
"stream_id": stream_id,
"app_id": app_id,
"pushkey": pushkey,
"user_id": user_id,
},
)
with self._pushers_id_gen.get_next() as stream_id:
@ -310,9 +319,12 @@ class PusherStore(SQLBaseStore):
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so _simple_upsert will retry
yield self._simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
desc="set_throttle_params"
desc="set_throttle_params",
lock=False,
)

View File

@ -254,8 +254,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
If None, tokens associated with any device (or no device) will
be deleted
Returns:
defer.Deferred[list[str, str|None]]: a list of the deleted tokens
and device IDs
defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens
"""
def f(txn):
keyvalues = {
@ -272,12 +272,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values.append(except_token_id)
txn.execute(
"SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
"SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
values
)
tokens_and_devices = [(r[0], r[1]) for r in txn]
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
for token, _ in tokens_and_devices:
for token, _, _ in tokens_and_devices:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,)
)

View File

@ -13,7 +13,10 @@
* limitations under the License.
*/
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
-- removed and replaced with 46/local_media_repository_url_idx.sql.
--
-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
-- indices on expressions until 3.9.

View File

@ -0,0 +1,35 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- drop the unique constraint on deleted_pushers so that we can just insert
-- into it rather than upserting.
CREATE TABLE deleted_pushers2 (
stream_id BIGINT NOT NULL,
app_id TEXT NOT NULL,
pushkey TEXT NOT NULL,
user_id TEXT NOT NULL
);
INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
DROP TABLE deleted_pushers;
ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
-- create the index after doing the inserts because that's more efficient.
-- it also means we can give it the same name as the old one without renaming.
CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);

View File

@ -0,0 +1,24 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- register a background update which will recreate the
-- local_media_repository_url_idx index.
--
-- We do this as a bg update not because it is a particularly onerous
-- operation, but because we'd like it to be a partial index if possible, and
-- the background_index_update code will understand whether we are on
-- postgres or sqlite and behave accordingly.
INSERT INTO background_updates (update_name, progress_json) VALUES
('local_media_repository_url_idx', '{}');

View File

@ -0,0 +1,35 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- change the user_directory table to also cover global local user profiles
-- rather than just profiles within specific rooms.
CREATE TABLE user_directory2 (
user_id TEXT NOT NULL,
room_id TEXT,
display_name TEXT,
avatar_url TEXT
);
INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
SELECT user_id, room_id, display_name, avatar_url from user_directory;
DROP TABLE user_directory;
ALTER TABLE user_directory2 RENAME TO user_directory;
-- create indexes after doing the inserts because that's more efficient.
-- it also means we can give it the same name as the old one without renaming.
CREATE INDEX user_directory_room_idx ON user_directory(room_id);
CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);

View File

@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches import intern_string
from synapse.util.stringutils import to_ascii
from synapse.storage.engines import PostgresEngine
from collections import namedtuple
import logging
from twisted.internet import defer
from collections import namedtuple
import logging
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine
from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@ -40,23 +42,11 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event.
class StateGroupReadStore(SQLBaseStore):
"""The read-only parts of StateGroupStore
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
None of these functions write to the state tables, so are suitable for
including in the SlavedStores.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@ -64,21 +54,10 @@ class StateStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
super(StateGroupReadStore, self).__init__(db_conn, hs)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
@cached(max_entries=100000, iterable=True)
@ -190,178 +169,6 @@ class StateStore(SQLBaseStore):
for group, event_id_map in group_to_ids.iteritems()
})
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state_ids is None:
# AFAIK, this can never happen
logger.error(
"Non-outlier event %s had current_state_ids==None",
event.event_id)
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": context.prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=context.state_group,
value=dict(context.current_state_ids),
full=True,
)
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in state_groups.iteritems()
],
)
for event_id, state_group_id in state_groups.iteritems():
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
@ -742,6 +549,220 @@ class StateStore(SQLBaseStore):
defer.returnValue(results)
class StateStore(StateGroupReadStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state_ids is None:
# AFAIK, this can never happen
logger.error(
"Non-outlier event %s had current_state_ids==None",
event.event_id)
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": context.prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=context.state_group,
value=dict(context.current_state_ids),
full=True,
)
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in state_groups.iteritems()
],
)
for event_id, state_group_id in state_groups.iteritems():
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
def get_next_state_group(self):
return self._state_groups_id_gen.get_next()

View File

@ -39,7 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
@ -234,7 +234,7 @@ class StreamStore(SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield preserve_context_over_deferred(defer.gatherResults([
res = yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit, order=order,
)

View File

@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore):
)
if isinstance(self.database_engine, PostgresEngine):
# We weight the loclpart most highly, then display name and finally
# We weight the localpart most highly, then display name and finally
# server name
if new_entry:
sql = """
@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore):
rows = yield self._execute("get_all_rooms", None, sql)
defer.returnValue([room_id for room_id, in rows])
@defer.inlineCallbacks
def get_all_local_users(self):
"""Get all local users
"""
sql = """
SELECT name FROM users
"""
rows = yield self._execute("get_all_local_users", None, sql)
defer.returnValue([name for name, in rows])
def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
"""Insert entries into the users_who_share_rooms table. The first
user should be a local user.
@ -629,6 +639,20 @@ class UserDirectoryStore(SQLBaseStore):
]
}
"""
if self.hs.config.user_directory_search_all_users:
join_clause = ""
where_clause = "?<>''" # naughty hack to keep the same number of binds
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
"""
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@ -641,13 +665,9 @@ class UserDirectoryStore(SQLBaseStore):
SELECT d.user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_public_rooms AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
%s
WHERE
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
%s
AND vector @@ to_tsquery('english', ?)
ORDER BY
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@ -671,7 +691,7 @@ class UserDirectoryStore(SQLBaseStore):
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
"""
""" % (join_clause, where_clause)
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
@ -680,20 +700,16 @@ class UserDirectoryStore(SQLBaseStore):
SELECT d.user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_public_rooms AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
%s
WHERE
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
%s
AND value MATCH ?
ORDER BY
rank(matchinfo(user_directory_search)) DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
"""
""" % (join_clause, where_clause)
args = (user_id, search_query, limit + 1)
else:
# This should be unreachable.
@ -723,7 +739,7 @@ def _parse_query_sqlite(search_term):
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
return " & ".join("(%s* | %s)" % (result, result,) for result in results)
return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
def _parse_query_postgres(search_term):

View File

@ -17,7 +17,7 @@
from twisted.internet import defer, reactor
from .logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
PreserveLoggingContext, make_deferred_yieldable, preserve_fn
)
from synapse.util import logcontext, unwrapFirstError
@ -351,7 +351,7 @@ class ReadWriteLock(object):
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
yield curr_writer
yield make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager():
@ -380,7 +380,7 @@ class ReadWriteLock(object):
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():

View File

@ -13,32 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_fn
)
from synapse.util import unwrapFirstError
import logging
from twisted.internet import defer
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
with PreserveLoggingContext():
distributor.fire("user_left_room", user=user, room_id=room_id)
def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
with PreserveLoggingContext():
distributor.fire("user_joined_room", user=user, room_id=room_id)
class Distributor(object):

View File

@ -261,67 +261,6 @@ class PreserveLoggingContext(object):
)
class _PreservingContextDeferred(defer.Deferred):
"""A deferred that ensures that all callbacks and errbacks are called with
the given logging context.
"""
def __init__(self, context):
self._log_context = context
defer.Deferred.__init__(self)
def addCallbacks(self, callback, errback=None,
callbackArgs=None, callbackKeywords=None,
errbackArgs=None, errbackKeywords=None):
callback = self._wrap_callback(callback)
errback = self._wrap_callback(errback)
return defer.Deferred.addCallbacks(
self, callback,
errback=errback,
callbackArgs=callbackArgs,
callbackKeywords=callbackKeywords,
errbackArgs=errbackArgs,
errbackKeywords=errbackKeywords,
)
def _wrap_callback(self, f):
def g(res, *args, **kwargs):
with PreserveLoggingContext(self._log_context):
res = f(res, *args, **kwargs)
return res
return g
def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so.
If the result is a deferred, call preserve_context_over_deferred before
returning it.
"""
with PreserveLoggingContext():
res = fn(*args, **kwargs)
if isinstance(res, defer.Deferred):
return preserve_context_over_deferred(res)
else:
return res
def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
Deprecated: this almost certainly doesn't do want you want, ie make
the deferred follow the synapse logcontext rules: try
``make_deferred_yieldable`` instead.
"""
if context is None:
context = LoggingContext.current_context()
d = _PreservingContextDeferred(context)
deferred.chainDeferred(d)
return d
def preserve_fn(f):
"""Wraps a function, to ensure that the current context is restored after
return from the function, and that the sentinel context is set once the

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership, EventTypes
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
import logging
@ -58,7 +58,7 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
always_include_ids (set(event_id)): set of event ids to specifically
include (unless sender is ignored)
"""
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
forgotten = yield make_deferred_yieldable(defer.gatherResults([
defer.maybeDeferred(
preserve_fn(store.who_forgot_in_room),
room_id,

View File

@ -36,6 +36,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
id="unique_identifier",
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user
namespaces={
ApplicationService.NS_USERS: [],
ApplicationService.NS_ROOMS: [],

View File

@ -58,7 +58,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource()
mock_notifier = Mock(spec=["on_new_event"])
mock_notifier = Mock()
self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[])
@ -76,6 +76,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
"set_received_txn_response",
"get_destination_retry_timings",
"get_devices_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
]),
state_handler=self.state_handler,
handlers=None,
@ -122,6 +125,15 @@ class TypingNotificationsTestCase(unittest.TestCase):
return set(str(u) for u in self.room_members)
self.state_handler.get_current_user_in_room = get_current_user_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
defer.succeed(1)
)
self.datastore.get_current_state_deltas.return_value = (
None
)
self.auth.check_joined_room = check_joined_room
self.datastore.get_to_device_stream_token = lambda: 0

View File

@ -1,5 +1,7 @@
from twisted.python import failure
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
from twisted.internet import defer
from mock import Mock
from tests import unittest
@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
side_effect=lambda x: self.appservice)
)
self.auth_result = (False, None, None, None)
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)
@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)