Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

pull/4692/head
Erik Johnston 2019-02-19 13:24:37 +00:00
commit 0e07d2c7d5
44 changed files with 775 additions and 373 deletions

View File

@ -12,9 +12,6 @@ cache:
#
- $HOME/.cache/pip/wheels
addons:
postgresql: "9.4"
# don't clone the whole repo history, one commit will do
git:
depth: 1
@ -25,6 +22,7 @@ branches:
- master
- develop
- /^release-v/
- rav/pg95
# When running the tox environments that call Twisted Trial, we can pass the -j
# flag to run the tests concurrently. We set this to 2 for CPU bound tests
@ -32,36 +30,53 @@ branches:
matrix:
fast_finish: true
include:
- python: 2.7
env: TOX_ENV=packaging
- name: "pep8"
python: 3.6
env: TOX_ENV="pep8,check_isort,packaging"
- python: 3.6
env: TOX_ENV="pep8,check_isort"
- python: 2.7
- name: "py2.7 / sqlite"
python: 2.7
env: TOX_ENV=py27,codecov TRIAL_FLAGS="-j 2"
- python: 2.7
- name: "py2.7 / sqlite / olddeps"
python: 2.7
env: TOX_ENV=py27-old TRIAL_FLAGS="-j 2"
- python: 2.7
- name: "py2.7 / postgres9.5"
python: 2.7
addons:
postgresql: "9.5"
env: TOX_ENV=py27-postgres,codecov TRIAL_FLAGS="-j 4"
services:
- postgresql
- python: 3.5
- name: "py3.5 / sqlite"
python: 3.5
env: TOX_ENV=py35,codecov TRIAL_FLAGS="-j 2"
- python: 3.6
- name: "py3.6 / sqlite"
python: 3.6
env: TOX_ENV=py36,codecov TRIAL_FLAGS="-j 2"
- python: 3.6
- name: "py3.6 / postgres9.4"
python: 3.6
addons:
postgresql: "9.4"
env: TOX_ENV=py36-postgres TRIAL_FLAGS="-j 4"
services:
- postgresql
- name: "py3.6 / postgres9.5"
python: 3.6
addons:
postgresql: "9.5"
env: TOX_ENV=py36-postgres,codecov TRIAL_FLAGS="-j 4"
services:
- postgresql
- # we only need to check for the newsfragment if it's a PR build
if: type = pull_request
name: "check-newsfragment"
python: 3.6
env: TOX_ENV=check-newsfragment
script:
@ -70,6 +85,9 @@ matrix:
- tox -e $TOX_ENV
install:
# this just logs the postgres version we will be testing against (if any)
- psql -At -U postgres -c 'select version();'
- pip install tox
# if we don't have python3.6 in this environment, travis unhelpfully gives us

View File

@ -39,7 +39,7 @@ instructions that may be required are listed later in this document.
./synctl restart
To check whether your update was sucessful, you can check the Server header
To check whether your update was successful, you can check the Server header
returned by the Client-Server API:
.. code:: bash

1
changelog.d/4632.feature Normal file
View File

@ -0,0 +1 @@
Add basic optional sentry integration

1
changelog.d/4642.feature Normal file
View File

@ -0,0 +1 @@
Transfer bans on room upgrade.

1
changelog.d/4643.misc Normal file
View File

@ -0,0 +1 @@
Reduce number of exceptions we log

1
changelog.d/4652.feature Normal file
View File

@ -0,0 +1 @@
Support .well-known delegation when issuing certificates through ACME.

1
changelog.d/4657.misc Normal file
View File

@ -0,0 +1 @@
Fix various spelling mistakes.

1
changelog.d/4666.feature Normal file
View File

@ -0,0 +1 @@
Allow registration and login to be handled by a worker instance.

1
changelog.d/4667.bugfix Normal file
View File

@ -0,0 +1 @@
Fix kicking guest users on guest access revocation in worker mode.

1
changelog.d/4668.misc Normal file
View File

@ -0,0 +1 @@
Reduce number of exceptions we log

1
changelog.d/4669.misc Normal file
View File

@ -0,0 +1 @@
Cleanup request exception logging

1
changelog.d/4670.feature Normal file
View File

@ -0,0 +1 @@
Allow registration and login to be handled by a worker instance.

1
changelog.d/4671.misc Normal file
View File

@ -0,0 +1 @@
Improve replication performance by reducing cache invalidation traffic.

1
changelog.d/4674.feature Normal file
View File

@ -0,0 +1 @@
Reduce the overhead of creating outbound federation connections over TLS by caching the TLS client options.

1
changelog.d/4676.misc Normal file
View File

@ -0,0 +1 @@
Test against Postgres 9.5 as well as 9.4

View File

@ -137,7 +137,6 @@ for each stream so that on reconneciton it can start streaming from the correct
place. Note: not all RDATA have valid tokens due to batching. See
``RdataCommand`` for more details.
Example
~~~~~~~
@ -221,3 +220,28 @@ SYNC (S, C)
See ``synapse/replication/tcp/commands.py`` for a detailed description and the
format of each command.
Cache Invalidation Stream
~~~~~~~~~~~~~~~~~~~~~~~~~
The cache invalidation stream is used to inform workers when they need to
invalidate any of their caches in the data store. This is done by streaming all
cache invalidations done on master down to the workers, assuming that any caches
on the workers also exist on the master.
Each individual cache invalidation results in a row being sent down replication,
which includes the cache name (the name of the function) and they key to
invalidate. For example::
> RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251]
However, there are times when a number of caches need to be invalidated at the
same time with the same key. To reduce traffic we batch those invalidations into
a single poke by defining a special cache name that workers understand to mean
to expand to invalidate the correct caches.
Currently the special cache names are declared in ``synapse/storage/_base.py``
and are:
1. ``cs_cache_fake`` ─ invalidates caches that depend on the current state

View File

@ -222,6 +222,13 @@ following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
^/_matrix/client/(api/v1|r0|unstable)/login$
Additionally, the following REST endpoints can be handled, but all requests must
be routed to the same instance::
^/_matrix/client/(r0|unstable)/register$
``synapse.app.user_dir``
~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -25,10 +25,12 @@ from daemonize import Daemonize
from twisted.internet import error, reactor
from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
from synapse.crypto import context_factory
from synapse.util import PreserveLoggingContext
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
@ -270,9 +272,37 @@ def start(hs, listeners=None):
# It is now safe to start your Synapse.
hs.start_listening(listeners)
hs.get_datastore().start_profiling()
setup_sentry(hs)
except Exception:
traceback.print_exc(file=sys.stderr)
reactor = hs.get_reactor()
if reactor.running:
reactor.stop()
sys.exit(1)
def setup_sentry(hs):
"""Enable sentry integration, if enabled in configuration
Args:
hs (synapse.server.HomeServer)
"""
if not hs.config.sentry_enabled:
return
import sentry_sdk
sentry_sdk.init(
dsn=hs.config.sentry_dsn,
release=get_version_string(synapse),
)
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
scope.set_tag("matrix_server_name", hs.config.server_name)
app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver"
name = hs.config.worker_name if hs.config.worker_name else "master"
scope.set_tag("worker_app", app)
scope.set_tag("worker_name", name)

View File

@ -40,6 +40,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.login import LoginRestServlet
from synapse.rest.client.v1.room import (
JoinedRoomMemberListRestServlet,
PublicRoomListRestServlet,
@ -47,6 +48,7 @@ from synapse.rest.client.v1.room import (
RoomMemberListRestServlet,
RoomStateRestServlet,
)
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
@ -92,6 +94,8 @@ class ClientReaderServer(HomeServer):
JoinedRoomMemberListRestServlet(self).register(resource)
RoomStateRestServlet(self).register(resource)
RoomEventContextServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,

View File

@ -40,6 +40,7 @@ from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
@ -63,6 +64,7 @@ class FederationReaderSlavedStore(
SlavedReceiptsStore,
SlavedEventStore,
SlavedKeyStore,
SlavedRegistrationStore,
RoomStore,
DirectoryStore,
SlavedTransactionStore,

View File

@ -56,7 +56,7 @@ class KeyConfig(Config):
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
logger.warn("Config is missing missing macaroon_secret_key")
logger.warn("Config is missing macaroon_secret_key")
seed = bytes(self.signing_key[0])
self.macaroon_secret_key = hashlib.sha256(seed).digest()

View File

@ -13,7 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from ._base import Config, ConfigError
MISSING_SENTRY = (
"""Missing sentry-sdk library. This is required to enable sentry
integration.
"""
)
class MetricsConfig(Config):
@ -23,12 +29,34 @@ class MetricsConfig(Config):
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
self.sentry_enabled = "sentry" in config
if self.sentry_enabled:
try:
import sentry_sdk # noqa F401
except ImportError:
raise ConfigError(MISSING_SENTRY)
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
raise ConfigError(
"sentry.dsn field is required when sentry integration is enabled",
)
def default_config(self, report_stats=None, **kwargs):
res = """\
## Metrics ###
# Enable collection and rendering of performance metrics
enable_metrics: False
# Enable sentry integration
# NOTE: While attempts are made to ensure that the logs don't contain
# any sensitive information, this cannot be guaranteed. By enabling
# this option the sentry server may therefore receive sensitive
# information, and it in turn may then diseminate sensitive information
# through insecure notification channels if so configured.
#sentry:
# dsn: "..."
"""
if report_stats is None:

View File

@ -42,6 +42,7 @@ class TlsConfig(Config):
self.acme_port = acme_config.get("port", 80)
self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0'])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.acme_domain = acme_config.get("domain", config.get("server_name"))
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
@ -229,6 +230,20 @@ class TlsConfig(Config):
#
# reprovision_threshold: 30
# The domain that the certificate should be for. Normally this
# should be the same as your Matrix domain (i.e., 'server_name'), but,
# by putting a file at 'https://<server_name>/.well-known/matrix/server',
# you can delegate incoming traffic to another server. If you do that,
# you should give the target of the delegation here.
#
# For example: if your 'server_name' is 'example.com', but
# 'https://example.com/.well-known/matrix/server' delegates to
# 'matrix.example.com', you should put 'matrix.example.com' here.
#
# If not set, defaults to your 'server_name'.
#
# domain: matrix.example.com
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
# make HTTPS requests to this server will check that the TLS

View File

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from zope.interface import implementer
@ -105,9 +107,7 @@ class ClientTLSOptions(object):
self._hostnameBytes = _idnaBytes(hostname)
self._sendSNI = True
ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback)
)
ctx.set_info_callback(_tolerateErrors(self._identityVerifyingInfoCallback))
def clientConnectionForTLS(self, tlsProtocol):
context = self._ctx
@ -128,10 +128,8 @@ class ClientTLSOptionsFactory(object):
def __init__(self, config):
# We don't use config options yet
pass
self._options = CertificateOptions(verify=False)
def get_options(self, host):
return ClientTLSOptions(
host,
CertificateOptions(verify=False).getContext()
)
# Use _makeContext so that we get a fresh OpenSSL CTX each time.
return ClientTLSOptions(host, self._options._makeContext())

View File

@ -35,7 +35,7 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.api.errors import Codes, RequestSendFailed, SynapseError
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
LoggingContext,
@ -656,7 +656,7 @@ def _handle_key_deferred(verify_request):
try:
with PreserveLoggingContext():
_, key_id, verify_key = yield verify_request.deferred
except IOError as e:
except (IOError, RequestSendFailed) as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e),

View File

@ -42,7 +42,7 @@ from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.errors import RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.logcontext import run_in_background
@ -191,6 +191,11 @@ class GroupAttestionRenewer(object):
yield self.store.update_attestation_renewal(
group_id, user_id, attestation
)
except RequestSendFailed as e:
logger.warning(
"Failed to renew attestation of %r in %r: %s",
user_id, group_id, e,
)
except Exception:
logger.exception("Error renewing attestation of %r in %r",
user_id, group_id)

View File

@ -167,4 +167,4 @@ class BaseHandler(object):
ratelimit=False,
)
except Exception as e:
logger.warn("Error kicking guest user: %s" % (e,))
logger.exception("Error kicking guest user: %s" % (e,))

View File

@ -56,6 +56,7 @@ class AcmeHandler(object):
def __init__(self, hs):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
@defer.inlineCallbacks
def start_listening(self):
@ -123,15 +124,15 @@ class AcmeHandler(object):
@defer.inlineCallbacks
def provision_certificate(self):
logger.warning("Reprovisioning %s", self.hs.hostname)
logger.warning("Reprovisioning %s", self._acme_domain)
try:
yield self._issuer.issue_cert(self.hs.hostname)
yield self._issuer.issue_cert(self._acme_domain)
except Exception:
logger.exception("Fail!")
raise
logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
cert_chain = self._store.certs[self.hs.hostname]
logger.warning("Reprovisioned %s, saving.", self._acme_domain)
cert_chain = self._store.certs[self._acme_domain]
try:
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:

View File

@ -20,7 +20,11 @@ from twisted.internet import defer
from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.api.errors import FederationDeniedError
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
)
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
@ -504,13 +508,13 @@ class DeviceListEduUpdater(object):
origin = get_domain_from_id(user_id)
try:
result = yield self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
except (
NotRetryingDestination, RequestSendFailed, HttpResponseException,
):
# TODO: Remember that we are now out of sync and try again
# later
logger.warn(
"Failed to handle device list update for %s,"
" we're not retrying the remote",
user_id,
"Failed to handle device list update for %s", user_id,
)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list

View File

@ -20,7 +20,7 @@ from six import iteritems
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
@ -46,13 +46,19 @@ def _create_rerouter(func_name):
# when the remote end responds with things like 403 Not
# In Group, we can communicate that to the client instead
# of a 500.
def h(failure):
def http_response_errback(failure):
failure.trap(HttpResponseException)
e = failure.value
if e.code == 403:
raise e.to_synapse_error()
return failure
d.addErrback(h)
def request_failed_errback(failure):
failure.trap(RequestSendFailed)
raise SynapseError(502, "Failed to contact group server")
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d
return f

View File

@ -27,6 +27,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.http.client import CaptchaServerHttpClient
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import ReplicationRegisterServlet
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from synapse.util.threepids import check_3pid_allowed
@ -61,6 +63,14 @@ class RegistrationHandler(BaseHandler):
)
self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = (
RegisterDeviceReplicationServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
@ -155,7 +165,7 @@ class RegistrationHandler(BaseHandler):
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
password_hash = yield self.auth_handler().hash(password)
password_hash = yield self._auth_handler.hash(password)
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@ -185,7 +195,7 @@ class RegistrationHandler(BaseHandler):
token = None
if generate_token:
token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
@ -217,7 +227,7 @@ class RegistrationHandler(BaseHandler):
if default_display_name is None:
default_display_name = localpart
try:
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
@ -316,7 +326,7 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
password_hash="",
appservice_id=service_id,
@ -494,7 +504,7 @@ class RegistrationHandler(BaseHandler):
token = self.macaroon_gen.generate_access_token(user_id)
if need_register:
yield self.store.register(
yield self._register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
@ -512,9 +522,6 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token))
def auth_handler(self):
return self.hs.get_auth_handler()
@defer.inlineCallbacks
def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if
@ -573,3 +580,94 @@ class RegistrationHandler(BaseHandler):
action="join",
ratelimit=False,
)
def _register_with_store(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_displayname=None, admin=False,
user_type=None):
"""Register user in the datastore.
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
Returns:
Deferred
"""
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
appservice_id=appservice_id,
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
)
else:
return self.store.register(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
appservice_id=appservice_id,
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
)
@defer.inlineCallbacks
def register_device(self, user_id, device_id, initial_display_name,
is_guest=False):
"""Register a device for a user and generate an access token.
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
a new one.
initial_display_name (str|None): An optional display name for the
device.
is_guest (bool): Whether this is a guest account
Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
"""
if self.hs.config.worker_app:
r = yield self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
)
defer.returnValue((r["device_id"], r["access_token"]))
else:
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
defer.returnValue((device_id, access_token))

View File

@ -311,6 +311,28 @@ class RoomCreationHandler(BaseHandler):
creation_content=creation_content,
)
# Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)]),
)
# map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events(
old_room_member_state_ids.values(),
)
for k, old_event in iteritems(old_room_member_state_events):
# Only transfer ban events
if ("membership" in old_event.content and
old_event.content["membership"] == "ban"):
yield self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event['state_key']),
new_room_id,
"ban",
ratelimit=False,
content=old_event.content,
)
# XXX invites/joins
# XXX 3pid invites

View File

@ -106,10 +106,10 @@ def wrap_json_request_handler(h):
# trace.
f = failure.Failure()
logger.error(
"Failed handle request via %r: %r: %s",
h,
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
f.getTraceback().rstrip(),
exc_info=(f.type, f.value, f.getTracebackObject()),
)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection

View File

@ -86,6 +86,7 @@ CONDITIONAL_REQUIREMENTS = {
"saml2": ["pysaml2>=4.5.0"],
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0", "parameterized"],
"sentry": ["sentry-sdk>=0.7.2"],
}

View File

@ -14,7 +14,7 @@
# limitations under the License.
from synapse.http.server import JsonResource
from synapse.replication.http import federation, membership, send_event
from synapse.replication.http import federation, login, membership, register, send_event
REPLICATION_PREFIX = "/_synapse/replication"
@ -28,3 +28,5 @@ class ReplicationRestResource(JsonResource):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)

View File

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"""Ensure a device is registered, generating a new access token for the
device.
Used during registration and login.
"""
NAME = "device_check_registered"
PATH_ARGS = ("user_id",)
def __init__(self, hs):
super(RegisterDeviceReplicationServlet, self).__init__(hs)
self.registration_handler = hs.get_handlers().registration_handler
@staticmethod
def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
generated.
initial_display_name (str|None)
is_guest (bool)
"""
return {
"device_id": device_id,
"initial_display_name": initial_display_name,
"is_guest": is_guest,
}
@defer.inlineCallbacks
def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest,
)
defer.returnValue((200, {
"device_id": device_id,
"access_token": access_token,
}))
def register_servlets(hs, http_server):
RegisterDeviceReplicationServlet(hs).register(http_server)

View File

@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationRegisterServlet(ReplicationEndpoint):
"""Register a new user
"""
NAME = "register_user"
PATH_ARGS = ("user_id",)
def __init__(self, hs):
super(ReplicationRegisterServlet, self).__init__(hs)
self.store = hs.get_datastore()
@staticmethod
def _serialize_payload(
user_id, token, password_hash, was_guest, make_guest, appservice_id,
create_profile_with_displayname, admin, user_type,
):
"""
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
"""
return {
"token": token,
"password_hash": password_hash,
"was_guest": was_guest,
"make_guest": make_guest,
"appservice_id": appservice_id,
"create_profile_with_displayname": create_profile_with_displayname,
"admin": admin,
"user_type": user_type,
}
@defer.inlineCallbacks
def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
yield self.store.register(
user_id=user_id,
token=content["token"],
password_hash=content["password_hash"],
was_guest=content["was_guest"],
make_guest=content["make_guest"],
appservice_id=content["appservice_id"],
create_profile_with_displayname=content["create_profile_with_displayname"],
admin=content["admin"],
user_type=content["user_type"],
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReplicationRegisterServlet(hs).register(http_server)

View File

@ -17,7 +17,7 @@ import logging
import six
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
@ -54,12 +54,17 @@ class BaseSlavedStore(SQLBaseStore):
if stream_name == "caches":
self._cache_id_gen.advance(token)
for row in rows:
try:
getattr(self, row.cache_func).invalidate(tuple(row.keys))
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
pass
if row.cache_func == _CURRENT_STATE_CACHE_NAME:
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
try:
getattr(self, row.cache_func).invalidate(tuple(row.keys))
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
pass
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)

View File

@ -94,7 +94,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.handlers = hs.get_handlers()
self._well_known_builder = WellKnownBuilder(hs)
@ -220,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission,
)
device_id = yield self._register_device(
canonical_user_id, login_submission,
)
access_token = yield auth_handler.get_access_token_for_user_id(
canonical_user_id, device_id,
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
canonical_user_id, device_id, initial_display_name,
)
result = {
@ -246,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name,
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
@ -286,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet):
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
device_id = yield self._register_device(
registered_user_id, login_submission
)
access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id,
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
registered_user_id, device_id, initial_display_name,
)
result = {
@ -299,12 +300,16 @@ class LoginRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
}
else:
# TODO: we should probably check that the register isn't going
# to fonx/change our user_id before registering the device
device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
registered_user_id, device_id, initial_display_name,
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
@ -313,26 +318,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue(result)
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) login_submission: dictionary supplied to /login call, from
which we pull device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get(
"initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
class CasRedirectServlet(RestServlet):
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")

View File

@ -190,7 +190,6 @@ class RegisterRestServlet(RestServlet):
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@interactive_auth_handler
@ -633,12 +632,10 @@ class RegisterRestServlet(RestServlet):
"home_server": self.hs.hostname,
}
if not params.get("inhibit_login", False):
device_id = yield self._register_device(user_id, params)
access_token = (
yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=False,
)
result.update({
@ -647,26 +644,6 @@ class RegisterRestServlet(RestServlet):
})
defer.returnValue(result)
def _register_device(self, user_id, params):
"""Register a device for a user.
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
# register the user's device
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
@defer.inlineCallbacks
def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access:
@ -680,13 +657,10 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True,
)
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
defer.returnValue((200, {
"user_id": user_id,
"device_id": device_id,

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
import sys
import threading
@ -28,6 +29,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.stringutils import exception_to_unicode
@ -64,6 +66,10 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"event_search": "event_search_event_id_idx",
}
# This is a special cache name we use to batch multiple invalidations of caches
# based on the current state when notifying workers over replication.
_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
@ -1184,6 +1190,56 @@ class SQLBaseStore(object):
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
"""Special case invalidation of caches based on current state.
We special case this so that we can batch the cache invalidations into a
single replication poke.
Args:
txn
room_id (str): Room where state changed
members_changed (iterable[str]): The user_ids of members that have changed
"""
txn.call_after(self._invalidate_state_caches, room_id, members_changed)
keys = itertools.chain([room_id], members_changed)
self._send_invalidation_to_replication(
txn, _CURRENT_STATE_CACHE_NAME, keys,
)
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
Args:
room_id (str): Room where state changed
members_changed (iterable[str]): The user_ids of members that have
changed
"""
for member in members_changed:
self.get_rooms_for_user_with_stream_ordering.invalidate((member,))
for host in set(get_domain_from_id(u) for u in members_changed):
self.is_host_joined.invalidate((room_id, host))
self.was_host_joined.invalidate((room_id, host))
self.get_users_in_room.invalidate((room_id,))
self.get_room_summary.invalidate((room_id,))
self.get_current_state_ids.invalidate((room_id,))
def _send_invalidation_to_replication(self, txn, cache_name, keys):
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
Args:
txn
cache_name (str)
keys (iterable[str])
"""
if isinstance(self.database_engine, PostgresEngine):
# get_next() returns a context manager which is designed to wrap
@ -1201,7 +1257,7 @@ class SQLBaseStore(object):
table="cache_invalidation_stream",
values={
"stream_id": stream_id,
"cache_func": cache_func.__name__,
"cache_func": cache_name,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
}

View File

@ -979,30 +979,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
if ev_type == EventTypes.Member
)
for member in members_changed:
self._invalidate_cache_and_stream(
txn, self.get_rooms_for_user_with_stream_ordering, (member,)
)
for host in set(get_domain_from_id(u) for u in members_changed):
self._invalidate_cache_and_stream(
txn, self.is_host_joined, (room_id, host)
)
self._invalidate_cache_and_stream(
txn, self.was_host_joined, (room_id, host)
)
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_room_summary, (room_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_current_state_ids, (room_id,)
)
self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order):

View File

@ -139,6 +139,162 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def count_daily_user_type(self):
"""
Counts 1) native non guest users
2) native guests users
3) bridged users
who registered on the homeserver in the past 24 hours
"""
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
SELECT user_type, COALESCE(count(*), 0) AS count FROM (
SELECT
CASE
WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM users
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
results = {'native': 0, 'guest': 0, 'bridged': 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
txn.execute("""
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
""")
count, = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, and we aim for them to be as small as
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
regex = re.compile(r"^@(\d+):")
found = set()
for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
for i in range(len(found) + 1):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
{
"medium": medium,
"address": address
},
["guest_access_token"], True, 'get_3pid_guest_access_token'
)
if ret:
defer.returnValue(ret["guest_access_token"])
defer.returnValue(None)
@defer.inlineCallbacks
def get_user_id_by_threepid(self, medium, address):
"""Returns user id from threepid
Args:
medium (str): threepid medium e.g. email
address (str): threepid address e.g. me@example.com
Returns:
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
medium, address
)
defer.returnValue(user_id)
def get_user_id_by_threepid_txn(self, txn, medium, address):
"""Returns user id from threepid
Args:
txn (cursor):
medium (str): threepid medium e.g. email
address (str): threepid address e.g. me@example.com
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
ret = self._simple_select_one_txn(
txn,
"user_threepids",
{
"medium": medium,
"address": address
},
['user_id'], True
)
if ret:
return ret['user_id']
return None
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
@ -326,20 +482,6 @@ class RegistrationStore(RegistrationWorkerStore,
)
txn.call_after(self.is_guest.invalidate, (user_id,))
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
@ -512,47 +654,6 @@ class RegistrationStore(RegistrationWorkerStore,
)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_user_id_by_threepid(self, medium, address):
"""Returns user id from threepid
Args:
medium (str): threepid medium e.g. email
address (str): threepid address e.g. me@example.com
Returns:
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
medium, address
)
defer.returnValue(user_id)
def get_user_id_by_threepid_txn(self, txn, medium, address):
"""Returns user id from threepid
Args:
txn (cursor):
medium (str): threepid medium e.g. email
address (str): threepid address e.g. me@example.com
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
ret = self._simple_select_one_txn(
txn,
"user_threepids",
{
"medium": medium,
"address": address
},
['user_id'], True
)
if ret:
return ret['user_id']
return None
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
@ -564,107 +665,6 @@ class RegistrationStore(RegistrationWorkerStore,
desc="user_delete_threepids",
)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def count_daily_user_type(self):
"""
Counts 1) native non guest users
2) native guests users
3) bridged users
who registered on the homeserver in the past 24 hours
"""
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
SELECT user_type, COALESCE(count(*), 0) AS count FROM (
SELECT
CASE
WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM users
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
results = {'native': 0, 'guest': 0, 'bridged': 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
txn.execute("""
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
""")
count, = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, and we aim for them to be as small as
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
regex = re.compile(r"^@(\d+):")
found = set()
for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
for i in range(len(found) + 1):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
{
"medium": medium,
"address": address
},
["guest_access_token"], True, 'get_3pid_guest_access_token'
)
if ret:
defer.returnValue(ret["guest_access_token"])
defer.returnValue(None)
@defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
self, medium, address, access_token, inviter_user_id

View File

@ -1,10 +1,7 @@
import json
from mock import Mock
from twisted.python import failure
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.constants import LoginType
from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha.register import register_servlets
from tests import unittest
@ -18,50 +15,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/register"
self.appservice = None
self.auth = Mock(
get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
)
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None),
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[])
# do the dance to hook it up to the hs global
self.handlers = Mock(
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler,
)
self.hs = self.setup_test_homeserver()
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
self.hs.config.enable_registration_captcha = False
return self.hs
def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
self.appservice = {"id": "1234"}
self.registration_handler.appservice_register = Mock(return_value=user_id)
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
request_data = json.dumps({"username": "kermit"})
user_id = "@as_user_kermit:test"
as_token = "i_am_an_app_service"
appservice = ApplicationService(
as_token, self.hs.config.hostname,
id="1234",
namespaces={
"users": [{"regex": r"@as_user.*", "exclusive": True}],
},
)
self.hs.get_datastore().services_cache.append(appservice)
request_data = json.dumps({"username": "as_user_kermit"})
request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@ -71,7 +46,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
self.assertDictContainsSubset(det_data, channel.json_body)
@ -103,39 +77,30 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
user_id = "@kermit:test"
device_id = "frogfone"
request_data = json.dumps(
{"username": "kermit", "password": "monkey", "device_id": device_id}
)
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
self.device_handler.check_device_registered = Mock(return_value=device_id)
params = {
"username": "kermit",
"password": "monkey",
"device_id": device_id,
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request)
det_data = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None
)
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request)
@ -144,16 +109,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
def test_POST_guest_registration(self):
user_id = "a@b"
self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
self.registration_handler.register = Mock(return_value=(user_id, None))
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
det_data = {
"user_id": user_id,
"home_server": self.hs.hostname,
"device_id": "guest_device",
}