Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

matrix-org-hotfixes-identity
Neil Johnson 2019-03-27 09:18:35 +00:00
commit 1048e2ca6a
15 changed files with 171 additions and 45 deletions

1
changelog.d/4793.feature Normal file
View File

@ -0,0 +1 @@
Synapse is now permissive about trailing slashes on some of its federation endpoints, allowing zero or more to be present.

1
changelog.d/4931.feature Normal file
View File

@ -0,0 +1 @@
Add ability for password providers to login/register a user via 3PID (email, phone).

1
changelog.d/4944.feature Normal file
View File

@ -0,0 +1 @@
The user directory has been rewritten to make it faster, with less chance of falling behind on a large server.

View File

@ -75,6 +75,20 @@ Password auth provider classes may optionally provide the following methods.
result from the ``/login`` call (including ``access_token``, ``device_id``, result from the ``/login`` call (including ``access_token``, ``device_id``,
etc.) etc.)
``someprovider.check_3pid_auth``\(*medium*, *address*, *password*)
This method, if implemented, is called when a user attempts to register or
log in with a third party identifier, such as email. It is passed the
medium (ex. "email"), an address (ex. "jdoe@example.com") and the user's
password.
The method should return a Twisted ``Deferred`` object, which resolves to
a ``str`` containing the user's (canonical) User ID if authentication was
successful, and ``None`` if not.
As with ``check_auth``, the ``Deferred`` may alternatively resolve to a
``(user_id, callback)`` tuple.
``someprovider.check_password``\(*user_id*, *password*) ``someprovider.check_password``\(*user_id*, *password*)
This method provides a simpler interface than ``get_supported_login_types`` This method provides a simpler interface than ``get_supported_login_types``

View File

@ -173,7 +173,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback. # generated by the json_data_callback.
json_data = transaction.get_dict() json_data = transaction.get_dict()
path = _create_v1_path("/send/%s/", transaction.transaction_id) path = _create_v1_path("/send/%s", transaction.transaction_id)
response = yield self.client.put_json( response = yield self.client.put_json(
transaction.destination, transaction.destination,

View File

@ -312,7 +312,7 @@ class BaseFederationServlet(object):
class FederationSendServlet(BaseFederationServlet): class FederationSendServlet(BaseFederationServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/" PATH = "/send/(?P<transaction_id>[^/]*)/?"
def __init__(self, handler, server_name, **kwargs): def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__( super(FederationSendServlet, self).__init__(
@ -378,7 +378,7 @@ class FederationSendServlet(BaseFederationServlet):
class FederationEventServlet(BaseFederationServlet): class FederationEventServlet(BaseFederationServlet):
PATH = "/event/(?P<event_id>[^/]*)/" PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair. # This is when someone asks for a data item for a given server data_id pair.
def on_GET(self, origin, content, query, event_id): def on_GET(self, origin, content, query, event_id):
@ -386,7 +386,7 @@ class FederationEventServlet(BaseFederationServlet):
class FederationStateServlet(BaseFederationServlet): class FederationStateServlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/" PATH = "/state/(?P<context>[^/]*)/?"
# This is when someone asks for all data for a given context. # This is when someone asks for all data for a given context.
def on_GET(self, origin, content, query, context): def on_GET(self, origin, content, query, context):
@ -398,7 +398,7 @@ class FederationStateServlet(BaseFederationServlet):
class FederationStateIdsServlet(BaseFederationServlet): class FederationStateIdsServlet(BaseFederationServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/" PATH = "/state_ids/(?P<room_id>[^/]*)/?"
def on_GET(self, origin, content, query, room_id): def on_GET(self, origin, content, query, room_id):
return self.handler.on_state_ids_request( return self.handler.on_state_ids_request(
@ -409,7 +409,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/" PATH = "/backfill/(?P<context>[^/]*)/?"
def on_GET(self, origin, content, query, context): def on_GET(self, origin, content, query, context):
versions = [x.decode('ascii') for x in query[b"v"]] versions = [x.decode('ascii') for x in query[b"v"]]
@ -1080,7 +1080,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group """Get all categories for a group
""" """
PATH = ( PATH = (
"/groups/(?P<group_id>[^/]*)/categories/" "/groups/(?P<group_id>[^/]*)/categories/?"
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1150,7 +1150,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group """Get roles in a group
""" """
PATH = ( PATH = (
"/groups/(?P<group_id>[^/]*)/roles/" "/groups/(?P<group_id>[^/]*)/roles/?"
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -745,6 +745,42 @@ class AuthHandler(BaseHandler):
errcode=Codes.FORBIDDEN errcode=Codes.FORBIDDEN
) )
@defer.inlineCallbacks
def check_password_provider_3pid(self, medium, address, password):
"""Check if a password provider is able to validate a thirdparty login
Args:
medium (str): The medium of the 3pid (ex. email).
address (str): The address of the 3pid (ex. jdoe@example.com).
password (str): The password of the user.
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
callback)`. If authentication is successful, `user_id` is a `str`
containing the authenticated, canonical user ID. `callback` is
then either a function to be later run after the server has
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
if hasattr(provider, "check_3pid_auth"):
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = yield provider.check_3pid_auth(
medium, address, password,
)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
defer.returnValue(result)
defer.returnValue((None, None))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
@ -756,7 +792,8 @@ class AuthHandler(BaseHandler):
user_id (unicode): complete @user:id user_id (unicode): complete @user:id
password (unicode): the provided password password (unicode): the provided password
Returns: Returns:
(unicode) the canonical_user_id, or None if unknown user / bad password Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
Raises: Raises:
LimitExceededError if the ratelimiter's login requests count for this LimitExceededError if the ratelimiter's login requests count for this

View File

@ -147,8 +147,14 @@ class BaseProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False): def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
"""target_user is the user whose displayname is to be changed; """Set the displayname of a user
auth_user is the user attempting to make this change."""
Args:
target_user (UserID): the user whose displayname is to be changed.
requester (Requester): The user attempting to make this change.
new_displayname (str): The displayname to give this user.
by_admin (bool): Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")

View File

@ -171,7 +171,7 @@ class RegistrationHandler(BaseHandler):
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'. will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the regitration. address (str|None): the IP address used to perform the registration.
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -623,7 +623,7 @@ class RegistrationHandler(BaseHandler):
admin (boolean): is an admin user? admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration. address (str|None): the IP address used to perform the registration.
Returns: Returns:
Deferred Deferred
@ -721,9 +721,9 @@ class RegistrationHandler(BaseHandler):
access_token (str|None): The access token of the newly logged in access_token (str|None): The access token of the newly logged in
device, or None if `inhibit_login` enabled. device, or None if `inhibit_login` enabled.
bind_email (bool): Whether to bind the email with the identity bind_email (bool): Whether to bind the email with the identity
server server.
bind_msisdn (bool): Whether to bind the msisdn with the identity bind_msisdn (bool): Whether to bind the msisdn with the identity
server server.
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
yield self._post_registration_client( yield self._post_registration_client(
@ -765,7 +765,7 @@ class RegistrationHandler(BaseHandler):
"""A user consented to the terms on registration """A user consented to the terms on registration
Args: Args:
user_id (str): The user ID that consented user_id (str): The user ID that consented.
consent_version (str): version of the policy the user has consent_version (str): version of the policy the user has
consented to. consented to.
""" """

View File

@ -231,6 +231,7 @@ class MatrixFederationHttpClient(object):
# Retry with a trailing slash if we received a 400 with # Retry with a trailing slash if we received a 400 with
# 'M_UNRECOGNIZED' which some endpoints can return when omitting a # 'M_UNRECOGNIZED' which some endpoints can return when omitting a
# trailing slash on Synapse <= v0.99.3. # trailing slash on Synapse <= v0.99.3.
logger.info("Retrying request with trailing slash")
request.path += "/" request.path += "/"
response = yield self._send_request( response = yield self._send_request(

View File

@ -73,14 +73,26 @@ class ModuleApi(object):
""" """
return self._auth_handler.check_user_exists(user_id) return self._auth_handler.check_user_exists(user_id)
def register(self, localpart): @defer.inlineCallbacks
"""Registers a new user with given localpart def register(self, localpart, displayname=None):
"""Registers a new user with given localpart and optional
displayname.
Args:
localpart (str): The localpart of the new user.
displayname (str|None): The displayname of the new user. If None,
the user's displayname will default to `localpart`.
Returns: Returns:
Deferred: a 2-tuple of (user_id, access_token) Deferred: a 2-tuple of (user_id, access_token)
""" """
# Register the user
reg = self.hs.get_registration_handler() reg = self.hs.get_registration_handler()
return reg.register(localpart=localpart) user_id, access_token = yield reg.register(
localpart=localpart, default_display_name=displayname,
)
defer.returnValue((user_id, access_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def invalidate_access_token(self, access_token): def invalidate_access_token(self, access_token):

View File

@ -201,6 +201,24 @@ class LoginRestServlet(ClientV1RestServlet):
# We store all email addreses as lowercase in the DB. # We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py) # (See add_threepid in synapse/handlers/auth.py)
address = address.lower() address = address.lower()
# Check for login providers that support 3pid login types
canonical_user_id, callback_3pid = (
yield self.auth_handler.check_password_provider_3pid(
medium,
address,
login_submission["password"],
)
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback_3pid,
)
defer.returnValue(result)
# No password providers were able to handle this 3pid
# Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
medium, address, medium, address,
) )
@ -223,20 +241,43 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier: if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key") raise SynapseError(400, "User identifier is missing 'user' key")
auth_handler = self.auth_handler canonical_user_id, callback = yield self.auth_handler.validate_login(
canonical_user_id, callback = yield auth_handler.validate_login(
identifier["user"], identifier["user"],
login_submission, login_submission,
) )
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback,
)
defer.returnValue(result)
@defer.inlineCallbacks
def _register_device_with_callback(
self,
user_id,
login_submission,
callback=None,
):
""" Registers a device with a given user_id. Optionally run a callback
function after registration has completed.
Args:
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
Returns:
result (Dict[str,str]): Dictionary of account information after
successful registration.
"""
device_id = login_submission.get("device_id") device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name") initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device( device_id, access_token = yield self.registration_handler.register_device(
canonical_user_id, device_id, initial_display_name, user_id, device_id, initial_display_name,
) )
result = { result = {
"user_id": canonical_user_id, "user_id": user_id,
"access_token": access_token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,

View File

@ -135,7 +135,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def _populate_user_directory_process_rooms(self, progress, batch_size): def _populate_user_directory_process_rooms(self, progress, batch_size):
"""
Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.
"""
state = self.hs.get_state_handler() state = self.hs.get_state_handler()
# If we don't have progress filed, delete everything. # If we don't have progress filed, delete everything.
@ -143,13 +148,14 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
yield self.delete_all_from_user_dir() yield self.delete_all_from_user_dir()
def _get_next_batch(txn): def _get_next_batch(txn):
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """ sql = """
SELECT room_id FROM %s SELECT room_id, events FROM %s
ORDER BY events DESC ORDER BY events DESC
LIMIT %s LIMIT 250
""" % ( """ % (
TEMP_TABLE + "_rooms", TEMP_TABLE + "_rooms",
str(batch_size),
) )
txn.execute(sql) txn.execute(sql)
rooms_to_work_on = txn.fetchall() rooms_to_work_on = txn.fetchall()
@ -157,8 +163,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
if not rooms_to_work_on: if not rooms_to_work_on:
return None return None
rooms_to_work_on = [x[0] for x in rooms_to_work_on]
# Get how many are left to process, so we can give status on how # Get how many are left to process, so we can give status on how
# far we are in processing # far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
@ -180,7 +184,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
% (len(rooms_to_work_on), progress["remaining"]) % (len(rooms_to_work_on), progress["remaining"])
) )
for room_id in rooms_to_work_on: processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
is_in_room = yield self.is_host_joined(room_id, self.server_name) is_in_room = yield self.is_host_joined(room_id, self.server_name)
if is_in_room: if is_in_room:
@ -247,7 +253,13 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
progress, progress,
) )
defer.returnValue(len(rooms_to_work_on)) processed_event_count += event_count
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
defer.returnValue(processed_event_count)
defer.returnValue(processed_event_count)
@defer.inlineCallbacks @defer.inlineCallbacks
def _populate_user_directory_process_users(self, progress, batch_size): def _populate_user_directory_process_users(self, progress, batch_size):

View File

@ -180,7 +180,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
put_json = self.hs.get_http_client().put_json put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with( put_json.assert_called_once_with(
"farm", "farm",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction( data=_expect_edu_transaction(
"m.typing", "m.typing",
content={ content={
@ -202,7 +202,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
(request, channel) = self.make_request( (request, channel) = self.make_request(
"PUT", "PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json( _make_edu_transaction_json(
"m.typing", "m.typing",
content={ content={
@ -258,7 +258,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
put_json = self.hs.get_http_client().put_json put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with( put_json.assert_called_once_with(
"farm", "farm",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction( data=_expect_edu_transaction(
"m.typing", "m.typing",
content={ content={