Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

matrix-org-hotfixes-identity
Neil Johnson 2019-03-27 09:18:35 +00:00
commit 1048e2ca6a
15 changed files with 171 additions and 45 deletions

1
changelog.d/4793.feature Normal file
View File

@ -0,0 +1 @@
Synapse is now permissive about trailing slashes on some of its federation endpoints, allowing zero or more to be present.

1
changelog.d/4931.feature Normal file
View File

@ -0,0 +1 @@
Add ability for password providers to login/register a user via 3PID (email, phone).

1
changelog.d/4944.feature Normal file
View File

@ -0,0 +1 @@
The user directory has been rewritten to make it faster, with less chance of falling behind on a large server.

View File

@ -75,6 +75,20 @@ Password auth provider classes may optionally provide the following methods.
result from the ``/login`` call (including ``access_token``, ``device_id``, result from the ``/login`` call (including ``access_token``, ``device_id``,
etc.) etc.)
``someprovider.check_3pid_auth``\(*medium*, *address*, *password*)
This method, if implemented, is called when a user attempts to register or
log in with a third party identifier, such as email. It is passed the
medium (ex. "email"), an address (ex. "jdoe@example.com") and the user's
password.
The method should return a Twisted ``Deferred`` object, which resolves to
a ``str`` containing the user's (canonical) User ID if authentication was
successful, and ``None`` if not.
As with ``check_auth``, the ``Deferred`` may alternatively resolve to a
``(user_id, callback)`` tuple.
``someprovider.check_password``\(*user_id*, *password*) ``someprovider.check_password``\(*user_id*, *password*)
This method provides a simpler interface than ``get_supported_login_types`` This method provides a simpler interface than ``get_supported_login_types``

View File

@ -621,13 +621,13 @@ class Auth(object):
Returns: Returns:
True if the the sender is allowed to redact the target event if the True if the the sender is allowed to redact the target event if the
target event was created by them. target event was created by them.
False if the sender is allowed to redact the target event with no False if the sender is allowed to redact the target event with no
further checks. further checks.
Raises: Raises:
AuthError if the event sender is definitely not allowed to redact AuthError if the event sender is definitely not allowed to redact
the target event. the target event.
""" """
return event_auth.check_redaction(room_version, event, auth_events) return event_auth.check_redaction(room_version, event, auth_events)
@ -743,9 +743,9 @@ class Auth(object):
Returns: Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of Deferred[tuple[str, str|None]]: Resolves to the current membership of
the user in the room and the membership event ID of the user. If the user in the room and the membership event ID of the user. If
the user is not in the room and never has been, then the user is not in the room and never has been, then
`(Membership.JOIN, None)` is returned. `(Membership.JOIN, None)` is returned.
""" """
try: try:
@ -777,13 +777,13 @@ class Auth(object):
Args: Args:
user_id(str|None): If present, checks for presence against existing user_id(str|None): If present, checks for presence against existing
MAU cohort MAU cohort
threepid(dict|None): If present, checks for presence against configured threepid(dict|None): If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and threepid is on the reserved list. user_id and
threepid should never be set at the same time. threepid should never be set at the same time.
""" """
# Never fail an auth check for the server notices users or support user # Never fail an auth check for the server notices users or support user

View File

@ -173,7 +173,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback. # generated by the json_data_callback.
json_data = transaction.get_dict() json_data = transaction.get_dict()
path = _create_v1_path("/send/%s/", transaction.transaction_id) path = _create_v1_path("/send/%s", transaction.transaction_id)
response = yield self.client.put_json( response = yield self.client.put_json(
transaction.destination, transaction.destination,

View File

@ -312,7 +312,7 @@ class BaseFederationServlet(object):
class FederationSendServlet(BaseFederationServlet): class FederationSendServlet(BaseFederationServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/" PATH = "/send/(?P<transaction_id>[^/]*)/?"
def __init__(self, handler, server_name, **kwargs): def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__( super(FederationSendServlet, self).__init__(
@ -378,7 +378,7 @@ class FederationSendServlet(BaseFederationServlet):
class FederationEventServlet(BaseFederationServlet): class FederationEventServlet(BaseFederationServlet):
PATH = "/event/(?P<event_id>[^/]*)/" PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair. # This is when someone asks for a data item for a given server data_id pair.
def on_GET(self, origin, content, query, event_id): def on_GET(self, origin, content, query, event_id):
@ -386,7 +386,7 @@ class FederationEventServlet(BaseFederationServlet):
class FederationStateServlet(BaseFederationServlet): class FederationStateServlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/" PATH = "/state/(?P<context>[^/]*)/?"
# This is when someone asks for all data for a given context. # This is when someone asks for all data for a given context.
def on_GET(self, origin, content, query, context): def on_GET(self, origin, content, query, context):
@ -398,7 +398,7 @@ class FederationStateServlet(BaseFederationServlet):
class FederationStateIdsServlet(BaseFederationServlet): class FederationStateIdsServlet(BaseFederationServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/" PATH = "/state_ids/(?P<room_id>[^/]*)/?"
def on_GET(self, origin, content, query, room_id): def on_GET(self, origin, content, query, room_id):
return self.handler.on_state_ids_request( return self.handler.on_state_ids_request(
@ -409,7 +409,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/" PATH = "/backfill/(?P<context>[^/]*)/?"
def on_GET(self, origin, content, query, context): def on_GET(self, origin, content, query, context):
versions = [x.decode('ascii') for x in query[b"v"]] versions = [x.decode('ascii') for x in query[b"v"]]
@ -1080,7 +1080,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group """Get all categories for a group
""" """
PATH = ( PATH = (
"/groups/(?P<group_id>[^/]*)/categories/" "/groups/(?P<group_id>[^/]*)/categories/?"
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1150,7 +1150,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group """Get roles in a group
""" """
PATH = ( PATH = (
"/groups/(?P<group_id>[^/]*)/roles/" "/groups/(?P<group_id>[^/]*)/roles/?"
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -745,6 +745,42 @@ class AuthHandler(BaseHandler):
errcode=Codes.FORBIDDEN errcode=Codes.FORBIDDEN
) )
@defer.inlineCallbacks
def check_password_provider_3pid(self, medium, address, password):
"""Check if a password provider is able to validate a thirdparty login
Args:
medium (str): The medium of the 3pid (ex. email).
address (str): The address of the 3pid (ex. jdoe@example.com).
password (str): The password of the user.
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
callback)`. If authentication is successful, `user_id` is a `str`
containing the authenticated, canonical user ID. `callback` is
then either a function to be later run after the server has
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
if hasattr(provider, "check_3pid_auth"):
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = yield provider.check_3pid_auth(
medium, address, password,
)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
defer.returnValue(result)
defer.returnValue((None, None))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
@ -756,7 +792,8 @@ class AuthHandler(BaseHandler):
user_id (unicode): complete @user:id user_id (unicode): complete @user:id
password (unicode): the provided password password (unicode): the provided password
Returns: Returns:
(unicode) the canonical_user_id, or None if unknown user / bad password Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
Raises: Raises:
LimitExceededError if the ratelimiter's login requests count for this LimitExceededError if the ratelimiter's login requests count for this

View File

@ -147,8 +147,14 @@ class BaseProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False): def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
"""target_user is the user whose displayname is to be changed; """Set the displayname of a user
auth_user is the user attempting to make this change."""
Args:
target_user (UserID): the user whose displayname is to be changed.
requester (Requester): The user attempting to make this change.
new_displayname (str): The displayname to give this user.
by_admin (bool): Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")

View File

@ -171,7 +171,7 @@ class RegistrationHandler(BaseHandler):
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'. will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the regitration. address (str|None): the IP address used to perform the registration.
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -623,7 +623,7 @@ class RegistrationHandler(BaseHandler):
admin (boolean): is an admin user? admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration. address (str|None): the IP address used to perform the registration.
Returns: Returns:
Deferred Deferred
@ -721,9 +721,9 @@ class RegistrationHandler(BaseHandler):
access_token (str|None): The access token of the newly logged in access_token (str|None): The access token of the newly logged in
device, or None if `inhibit_login` enabled. device, or None if `inhibit_login` enabled.
bind_email (bool): Whether to bind the email with the identity bind_email (bool): Whether to bind the email with the identity
server server.
bind_msisdn (bool): Whether to bind the msisdn with the identity bind_msisdn (bool): Whether to bind the msisdn with the identity
server server.
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
yield self._post_registration_client( yield self._post_registration_client(
@ -765,7 +765,7 @@ class RegistrationHandler(BaseHandler):
"""A user consented to the terms on registration """A user consented to the terms on registration
Args: Args:
user_id (str): The user ID that consented user_id (str): The user ID that consented.
consent_version (str): version of the policy the user has consent_version (str): version of the policy the user has
consented to. consented to.
""" """

View File

@ -231,6 +231,7 @@ class MatrixFederationHttpClient(object):
# Retry with a trailing slash if we received a 400 with # Retry with a trailing slash if we received a 400 with
# 'M_UNRECOGNIZED' which some endpoints can return when omitting a # 'M_UNRECOGNIZED' which some endpoints can return when omitting a
# trailing slash on Synapse <= v0.99.3. # trailing slash on Synapse <= v0.99.3.
logger.info("Retrying request with trailing slash")
request.path += "/" request.path += "/"
response = yield self._send_request( response = yield self._send_request(

View File

@ -73,14 +73,26 @@ class ModuleApi(object):
""" """
return self._auth_handler.check_user_exists(user_id) return self._auth_handler.check_user_exists(user_id)
def register(self, localpart): @defer.inlineCallbacks
"""Registers a new user with given localpart def register(self, localpart, displayname=None):
"""Registers a new user with given localpart and optional
displayname.
Args:
localpart (str): The localpart of the new user.
displayname (str|None): The displayname of the new user. If None,
the user's displayname will default to `localpart`.
Returns: Returns:
Deferred: a 2-tuple of (user_id, access_token) Deferred: a 2-tuple of (user_id, access_token)
""" """
# Register the user
reg = self.hs.get_registration_handler() reg = self.hs.get_registration_handler()
return reg.register(localpart=localpart) user_id, access_token = yield reg.register(
localpart=localpart, default_display_name=displayname,
)
defer.returnValue((user_id, access_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def invalidate_access_token(self, access_token): def invalidate_access_token(self, access_token):

View File

@ -201,6 +201,24 @@ class LoginRestServlet(ClientV1RestServlet):
# We store all email addreses as lowercase in the DB. # We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py) # (See add_threepid in synapse/handlers/auth.py)
address = address.lower() address = address.lower()
# Check for login providers that support 3pid login types
canonical_user_id, callback_3pid = (
yield self.auth_handler.check_password_provider_3pid(
medium,
address,
login_submission["password"],
)
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback_3pid,
)
defer.returnValue(result)
# No password providers were able to handle this 3pid
# Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
medium, address, medium, address,
) )
@ -223,20 +241,43 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier: if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key") raise SynapseError(400, "User identifier is missing 'user' key")
auth_handler = self.auth_handler canonical_user_id, callback = yield self.auth_handler.validate_login(
canonical_user_id, callback = yield auth_handler.validate_login(
identifier["user"], identifier["user"],
login_submission, login_submission,
) )
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback,
)
defer.returnValue(result)
@defer.inlineCallbacks
def _register_device_with_callback(
self,
user_id,
login_submission,
callback=None,
):
""" Registers a device with a given user_id. Optionally run a callback
function after registration has completed.
Args:
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
Returns:
result (Dict[str,str]): Dictionary of account information after
successful registration.
"""
device_id = login_submission.get("device_id") device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name") initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device( device_id, access_token = yield self.registration_handler.register_device(
canonical_user_id, device_id, initial_display_name, user_id, device_id, initial_display_name,
) )
result = { result = {
"user_id": canonical_user_id, "user_id": user_id,
"access_token": access_token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,

View File

@ -135,7 +135,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def _populate_user_directory_process_rooms(self, progress, batch_size): def _populate_user_directory_process_rooms(self, progress, batch_size):
"""
Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.
"""
state = self.hs.get_state_handler() state = self.hs.get_state_handler()
# If we don't have progress filed, delete everything. # If we don't have progress filed, delete everything.
@ -143,13 +148,14 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
yield self.delete_all_from_user_dir() yield self.delete_all_from_user_dir()
def _get_next_batch(txn): def _get_next_batch(txn):
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """ sql = """
SELECT room_id FROM %s SELECT room_id, events FROM %s
ORDER BY events DESC ORDER BY events DESC
LIMIT %s LIMIT 250
""" % ( """ % (
TEMP_TABLE + "_rooms", TEMP_TABLE + "_rooms",
str(batch_size),
) )
txn.execute(sql) txn.execute(sql)
rooms_to_work_on = txn.fetchall() rooms_to_work_on = txn.fetchall()
@ -157,8 +163,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
if not rooms_to_work_on: if not rooms_to_work_on:
return None return None
rooms_to_work_on = [x[0] for x in rooms_to_work_on]
# Get how many are left to process, so we can give status on how # Get how many are left to process, so we can give status on how
# far we are in processing # far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
@ -180,7 +184,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
% (len(rooms_to_work_on), progress["remaining"]) % (len(rooms_to_work_on), progress["remaining"])
) )
for room_id in rooms_to_work_on: processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
is_in_room = yield self.is_host_joined(room_id, self.server_name) is_in_room = yield self.is_host_joined(room_id, self.server_name)
if is_in_room: if is_in_room:
@ -247,7 +253,13 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
progress, progress,
) )
defer.returnValue(len(rooms_to_work_on)) processed_event_count += event_count
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
defer.returnValue(processed_event_count)
defer.returnValue(processed_event_count)
@defer.inlineCallbacks @defer.inlineCallbacks
def _populate_user_directory_process_users(self, progress, batch_size): def _populate_user_directory_process_users(self, progress, batch_size):

View File

@ -180,7 +180,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
put_json = self.hs.get_http_client().put_json put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with( put_json.assert_called_once_with(
"farm", "farm",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction( data=_expect_edu_transaction(
"m.typing", "m.typing",
content={ content={
@ -202,7 +202,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
(request, channel) = self.make_request( (request, channel) = self.make_request(
"PUT", "PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json( _make_edu_transaction_json(
"m.typing", "m.typing",
content={ content={
@ -258,7 +258,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
put_json = self.hs.get_http_client().put_json put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with( put_json.assert_called_once_with(
"farm", "farm",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction( data=_expect_edu_transaction(
"m.typing", "m.typing",
content={ content={