Synchronise account metadata onto another server. (#4145)

* implement shadow registration via AS (untested)
* shadow support for 3pid binding/unbinding (untested)
pull/4385/head dinsic_2019-01-11
Matthew Hodgson 2019-01-11 15:50:28 +00:00 committed by Michael Kaye
parent 9cc95fd0a5
commit cf68593544
11 changed files with 254 additions and 62 deletions

View File

@ -189,6 +189,7 @@ class Auth(object):
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
user_id, app_service = yield self._get_appservice_user_id(request) user_id, app_service = yield self._get_appservice_user_id(request)
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue( defer.returnValue(
@ -238,39 +239,40 @@ class Auth(object):
errcode=Codes.MISSING_TOKEN errcode=Codes.MISSING_TOKEN
) )
@defer.inlineCallbacks
def _get_appservice_user_id(self, request): def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request( self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
) )
if app_service is None: if app_service is None:
defer.returnValue((None, None)) return(None, None)
if app_service.ip_range_whitelist: if app_service.ip_range_whitelist:
ip_address = IPAddress(self.hs.get_ip_from_request(request)) ip_address = IPAddress(self.hs.get_ip_from_request(request))
if ip_address not in app_service.ip_range_whitelist: if ip_address not in app_service.ip_range_whitelist:
defer.returnValue((None, None)) return(None, None)
if b"user_id" not in request.args: if b"user_id" not in request.args:
defer.returnValue((app_service.sender, app_service)) return(app_service.sender, app_service)
user_id = request.args[b"user_id"][0].decode('utf8') user_id = request.args[b"user_id"][0].decode('utf8')
if app_service.sender == user_id: if app_service.sender == user_id:
defer.returnValue((app_service.sender, app_service)) return(app_service.sender, app_service)
if not app_service.is_interested_in_user(user_id): if not app_service.is_interested_in_user(user_id):
raise AuthError( raise AuthError(
403, 403,
"Application service cannot masquerade as this user." "Application service cannot masquerade as this user."
) )
if not (yield self.store.get_user_by_id(user_id)): # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
raise AuthError( # if not (yield self.store.get_user_by_id(user_id)):
403, # raise AuthError(
"Application service has not registered this user" # 403,
) # "Application service has not registered this user"
defer.returnValue((user_id, app_service)) # )
return(user_id, app_service)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"): def get_user_by_access_token(self, token, rights="access"):
@ -514,24 +516,9 @@ class Auth(object):
defer.returnValue(user_info) defer.returnValue(user_info)
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
try: (user_id, app_service) = self._get_appservice_user_id(request)
token = self.get_access_token_from_request( request.authenticated_entity = app_service.sender
request, self.TOKEN_NOT_FOUND_HTTP_STATUS return app_service
)
service = self.store.get_app_service_by_token(token)
if not service:
logger.warn("Unrecognised appservice access token.")
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
request.authenticated_entity = service.sender
return defer.succeed(service)
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
)
def is_server_admin(self, user): def is_server_admin(self, user):
""" Check if the given user is a local server admin. """ Check if the given user is a local server admin.

View File

@ -265,7 +265,7 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id): def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exlusive_user_regexes(self): def get_exclusive_user_regexes(self):
"""Get the list of regexes used to determine if a user is exclusively """Get the list of regexes used to determine if a user is exclusively
registered by the AS registered by the AS
""" """

View File

@ -64,6 +64,8 @@ class RegistrationConfig(Config):
if not isinstance(self.replicate_user_profiles_to, list): if not isinstance(self.replicate_user_profiles_to, list):
self.replicate_user_profiles_to = [self.replicate_user_profiles_to, ] self.replicate_user_profiles_to = [self.replicate_user_profiles_to, ]
self.shadow_server = config.get("shadow_server", None)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
@ -141,6 +143,14 @@ class RegistrationConfig(Config):
# cross-homeserver user directories. # cross-homeserver user directories.
# replicate_user_profiles_to: example.com # replicate_user_profiles_to: example.com
# If specified, attempt to replay registrations, profile changes & 3pid
# bindings on the given target homeserver via the AS API. The HS is authed
# via a given AS token.
# shadow_server:
# hs_url: https://shadow.example.com
# hs: shadow.example.com
# as_token: 12u394refgbdhivsia
# If enabled, don't let users set their own display names/avatars # If enabled, don't let users set their own display names/avatars
# other than for the very first time (unless they are a server admin). # other than for the very first time (unless they are a server admin).
# Useful when provisioning users based on the contents of a 3rd party # Useful when provisioning users based on the contents of a 3rd party

View File

@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
self.room_creation_handler = self.hs.get_room_creation_handler() self.room_creation_handler = self.hs.get_room_creation_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self.http_client = hs.get_simple_http_client()
self._next_generated_user_id = None self._next_generated_user_id = None
@ -273,7 +274,9 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token): def appservice_register(self, user_localpart, as_token, password, display_name):
# FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname) user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token) service = self.store.get_app_service_by_token(as_token)
@ -291,16 +294,26 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service user_id, allowed_appservice=service
) )
password_hash = ""
if password:
password_hash = yield self.auth_handler().hash(password)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
password_hash="", password_hash=password_hash,
appservice_id=service_id, appservice_id=service_id,
) )
yield self.profile_handler.set_displayname( yield self.profile_handler.set_displayname(
user, None, user.localpart, by_admin=True, user, None, display_name or user.localpart, by_admin=True,
) )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(user_localpart)
yield self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -425,6 +438,39 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
@defer.inlineCallbacks
def shadow_register(self, localpart, display_name, auth_result, params):
"""Invokes the current registration on another server, using
shared secret registration, passing in any auth_results from
other registration UI auth flows (e.g. validated 3pids)
Useful for setting up shadow/backup accounts on a parallel deployment.
"""
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.post_json_get_json(
"%s/_matrix/client/r0/register?access_token=%s" % (
shadow_hs_url, as_token,
),
{
# XXX: auth_result is an unspecified extension for shadow registration
'auth_result': auth_result,
# XXX: another unspecified extension for shadow registration to ensure
# that the displayname is correctly set by the masters erver
'display_name': display_name,
'username': localpart,
'password': params.get("password"),
'bind_email': params.get("bind_email"),
'bind_msisdn': params.get("bind_msisdn"),
'device_id': params.get("device_id"),
'initial_device_display_name': params.get("initial_device_display_name"),
'inhibit_login': True,
'access_token': as_token,
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_user_id(self, reseed=False): def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None: if reseed or self._next_generated_user_id is None:

View File

@ -157,8 +157,9 @@ class SimpleHttpClient(object):
data=query_bytes data=query_bytes
) )
body = yield make_deferred_yieldable(treq.json_content(response))
if 200 <= response.code < 300: if 200 <= response.code < 300:
body = yield make_deferred_yieldable(treq.json_content(response))
defer.returnValue(body) defer.returnValue(body)
else: else:
raise HttpResponseException(response.code, response.phrase, body) raise HttpResponseException(response.code, response.phrase, body)

View File

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -21,6 +23,8 @@ from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
@ -28,6 +32,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs) super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -59,11 +64,30 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
yield self.profile_handler.set_displayname( yield self.profile_handler.set_displayname(
user, requester, new_name, is_admin) user, requester, new_name, is_admin)
if self.hs.config.shadow_server:
shadow_user = UserID(
user.localpart, self.hs.config.shadow_server.get("hs")
)
self.shadow_displayname(shadow_user.to_string(), content)
defer.returnValue((200, {})) defer.returnValue((200, {}))
def on_OPTIONS(self, request, user_id): def on_OPTIONS(self, request, user_id):
return (200, {}) return (200, {})
@defer.inlineCallbacks
def shadow_displayname(self, user_id, body):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.put_json(
"%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s" % (
shadow_hs_url, user_id, as_token, user_id
),
body
)
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
@ -71,6 +95,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs) super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -101,11 +126,30 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
yield self.profile_handler.set_avatar_url( yield self.profile_handler.set_avatar_url(
user, requester, new_name, is_admin) user, requester, new_name, is_admin)
if self.hs.config.shadow_server:
shadow_user = UserID(
user.localpart, self.hs.config.shadow_server.get("hs")
)
self.shadow_avatar_url(shadow_user.to_string(), content)
defer.returnValue((200, {})) defer.returnValue((200, {}))
def on_OPTIONS(self, request, user_id): def on_OPTIONS(self, request, user_id):
return (200, {}) return (200, {})
@defer.inlineCallbacks
def shadow_avatar_url(self, user_id, body):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.put_json(
"%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s" % (
shadow_hs_url, user_id, as_token, user_id
),
body
)
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")

View File

@ -29,6 +29,7 @@ from synapse.http.servlet import (
) )
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from synapse.types import UserID
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_v2_patterns, interactive_auth_handler
@ -117,6 +118,7 @@ class PasswordRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
self.http_client = hs.get_simple_http_client()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
@ -135,9 +137,13 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request): if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth( # blindly trust ASes without UI-authing them
requester, body, self.hs.get_ip_from_request(request), if requester.app_service:
) params = body
else:
params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
)
user_id = requester.user.to_string() user_id = requester.user.to_string()
else: else:
requester = None requester = None
@ -173,11 +179,30 @@ class PasswordRestServlet(RestServlet):
user_id, new_password, requester user_id, new_password, requester
) )
if self.hs.config.shadow_server:
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
self.shadow_password(params, shadow_user.to_string())
defer.returnValue((200, {})) defer.returnValue((200, {}))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
@defer.inlineCallbacks
def shadow_password(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s" % (
shadow_hs_url, as_token, user_id,
),
body
)
class DeactivateAccountRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$") PATTERNS = client_v2_patterns("/account/deactivate$")
@ -307,7 +332,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore() self.datastore = hs.get_datastore()
self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -326,25 +352,33 @@ class ThreepidRestServlet(RestServlet):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
threePidCreds = body.get('threePidCreds')
threePidCreds = body.get('three_pid_creds', threePidCreds)
if threePidCreds is None:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) # skip validation if this is a shadow 3PID from an AS
if not requester.app_service:
threePidCreds = body.get('threePidCreds')
threePidCreds = body.get('three_pid_creds', threePidCreds)
if threePidCreds is None:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
if not threepid: threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
raise SynapseError(
400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
)
for reqd in ['medium', 'address', 'validated_at']: if not threepid:
if reqd not in threepid: raise SynapseError(
logger.warn("Couldn't add 3pid: invalid response from ID server") 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
raise SynapseError(500, "Invalid response from ID Server") )
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.warn("Couldn't add 3pid: invalid response from ID server")
raise SynapseError(500, "Invalid response from ID Server")
else:
# XXX: ASes pass in a validated threepid directly to bypass the IS.
# This makes the API entirely change shape when we have an AS token;
# it really should be an entirely separate API - perhaps
# /account/3pid/replicate or something.
threepid = body.get('threepid')
yield self.auth_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,
@ -353,7 +387,7 @@ class ThreepidRestServlet(RestServlet):
threepid['validated_at'], threepid['validated_at'],
) )
if 'bind' in body and body['bind']: if not requester.app_service and ('bind' in body and body['bind']):
logger.debug( logger.debug(
"Binding threepid %s to %s", "Binding threepid %s to %s",
threepid, user_id threepid, user_id
@ -362,8 +396,27 @@ class ThreepidRestServlet(RestServlet):
threePidCreds, user_id threePidCreds, user_id
) )
if self.hs.config.shadow_server:
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
self.shadow_3pid({'threepid': threepid}, shadow_user.to_string())
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks
def shadow_3pid(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" % (
shadow_hs_url, as_token, user_id,
),
body
)
class ThreepidDeleteRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=()) PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=())
@ -373,6 +426,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -396,6 +450,12 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid") logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid") raise SynapseError(500, "Failed to remove threepid")
if self.hs.config.shadow_server:
shadow_user = UserID(
requester.user.localpart, self.hs.config.shadow_server.get("hs")
)
self.shadow_3pid_delete(body, shadow_user.to_string())
if ret: if ret:
id_server_unbind_result = "success" id_server_unbind_result = "success"
else: else:
@ -405,6 +465,19 @@ class ThreepidDeleteRestServlet(RestServlet):
"id_server_unbind_result": id_server_unbind_result, "id_server_unbind_result": id_server_unbind_result,
})) }))
@defer.inlineCallbacks
def shadow_3pid_delete(self, body, user_id):
# TODO: retries
shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.post_json_get_json(
"%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s" % (
shadow_hs_url, as_token, user_id
),
body
)
class WhoamiRestServlet(RestServlet): class WhoamiRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/whoami$") PATTERNS = client_v2_patterns("/account/whoami$")

View File

@ -229,7 +229,7 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Invalid username") raise SynapseError(400, "Invalid username")
desired_username = body['username'] desired_username = body['username']
desired_display_name = None desired_display_name = body.get('display_name')
appservice = None appservice = None
if self.auth.has_access_token(request): if self.auth.has_access_token(request):
@ -254,7 +254,8 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, string_types): if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration( result = yield self._do_appservice_registration(
desired_username, access_token, body desired_username, desired_password, desired_display_name,
access_token, body
) )
defer.returnValue((200, result)) # we throw for non 200 responses defer.returnValue((200, result)) # we throw for non 200 responses
return return
@ -474,7 +475,6 @@ class RegisterRestServlet(RestServlet):
pass pass
guest_access_token = params.get("guest_access_token", None) guest_access_token = params.get("guest_access_token", None)
new_password = params.get("password", None)
# XXX: don't we need to validate these for length etc like we did on # XXX: don't we need to validate these for length etc like we did on
# the ones from the JSON body earlier on in the method? # the ones from the JSON body earlier on in the method?
@ -488,7 +488,7 @@ class RegisterRestServlet(RestServlet):
(registered_user_id, _) = yield self.registration_handler.register( (registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username, localpart=desired_username,
password=new_password, password=params.get("password", None),
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
generate_token=False, generate_token=False,
display_name=desired_display_name, display_name=desired_display_name,
@ -499,6 +499,14 @@ class RegisterRestServlet(RestServlet):
if is_threepid_reserved(self.hs.config, threepid): if is_threepid_reserved(self.hs.config, threepid):
yield self.store.upsert_monthly_active_user(registered_user_id) yield self.store.upsert_monthly_active_user(registered_user_id)
if self.hs.config.shadow_server:
yield self.registration_handler.shadow_register(
localpart=desired_username,
display_name=desired_display_name,
auth_result=auth_result,
params=params,
)
# remember that we've now registered that user account, and with # remember that we've now registered that user account, and with
# what user ID (since the user may not have specified) # what user ID (since the user may not have specified)
self.auth_handler.set_session_data( self.auth_handler.set_session_data(
@ -532,11 +540,33 @@ class RegisterRestServlet(RestServlet):
return 200, {} return 200, {}
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token, body): def _do_appservice_registration(
self, username, password, display_name, as_token, body
):
# FIXME: appservice_register() is horribly duplicated with register()
# and they should probably just be combined together with a config flag.
user_id = yield self.registration_handler.appservice_register( user_id = yield self.registration_handler.appservice_register(
username, as_token username, as_token, password, display_name
) )
defer.returnValue((yield self._create_registration_details(user_id, body))) result = yield self._create_registration_details(user_id, body)
auth_result = body.get('auth_result')
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
yield self._register_email_threepid(
user_id, threepid, result["access_token"],
body.get("bind_email")
)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(
user_id, threepid, result["access_token"],
body.get("bind_msisdn")
)
defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, body): def _do_shared_secret_registration(self, username, password, body):

View File

@ -35,7 +35,7 @@ def _make_exclusive_regex(services_cache):
exclusive_user_regexes = [ exclusive_user_regexes = [
regex.pattern regex.pattern
for service in services_cache for service in services_cache
for regex in service.get_exlusive_user_regexes() for regex in service.get_exclusive_user_regexes()
] ]
if exclusive_user_regexes: if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)

View File

@ -63,7 +63,6 @@ class TestMauLimit(unittest.TestCase):
self.hs.config.server_notices_mxid_display_name = None self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room" self.hs.config.server_notices_room_name = "Test Server Notice Room"
self.hs.config.register_mxid_from_3pid = None
self.resource = JsonResource(self.hs) self.resource = JsonResource(self.hs)
register.register_servlets(self.hs, self.resource) register.register_servlets(self.hs, self.resource)

View File

@ -137,6 +137,8 @@ def default_config(name):
config.admin_contact = None config.admin_contact = None
config.rc_messages_per_second = 10000 config.rc_messages_per_second = 10000
config.rc_message_burst_count = 10000 config.rc_message_burst_count = 10000
config.register_mxid_from_3pid = None
config.shadow_server = None
config.use_frozen_dicts = False config.use_frozen_dicts = False