From c588b9b9e421855a1025e1d8c355818a40508c44 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 7 Dec 2018 13:10:07 +0100 Subject: [PATCH] Factor SSO success handling out of CAS login (#4264) This is mostly factoring out the post-CAS-login code to somewhere we can reuse it for other SSO flows, but it also fixes the userid mapping while we're at it. --- changelog.d/4264.bugfix | 1 + synapse/handlers/auth.py | 13 +++- synapse/rest/client/v1/login.py | 105 +++++++++++++++++++++++--------- synapse/types.py | 66 ++++++++++++++++++++ tests/test_types.py | 31 +++++++++- 5 files changed, 184 insertions(+), 32 deletions(-) create mode 100644 changelog.d/4264.bugfix diff --git a/changelog.d/4264.bugfix b/changelog.d/4264.bugfix new file mode 100644 index 0000000000..b914026932 --- /dev/null +++ b/changelog.d/4264.bugfix @@ -0,0 +1 @@ +Fix CAS login when username is not valid in an MXID diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c6e89db4bc..2abd9af94f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -563,10 +563,10 @@ class AuthHandler(BaseHandler): insensitively, but return None if there are multiple inexact matches. Args: - (str) user_id: complete @user:id + (unicode|bytes) user_id: complete @user:id Returns: - defer.Deferred: (str) canonical_user_id, or None if zero or + defer.Deferred: (unicode) canonical_user_id, or None if zero or multiple matches """ res = yield self._find_user_id_and_pwd_hash(user_id) @@ -954,6 +954,15 @@ class MacaroonGenerator(object): return macaroon.serialize() def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): + """ + + Args: + user_id (unicode): + duration_in_ms (int): + + Returns: + unicode + """ macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 011e84e8b1..b7c5b58b01 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError from synapse.http.server import finish_request -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_string, +) +from synapse.types import UserID, map_username_to_mxid_localpart from synapse.util.msisdn import phone_number_to_msisdn from .base import ClientV1RestServlet, client_path_patterns @@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes - self.auth_handler = hs.get_auth_handler() - self.handlers = hs.get_handlers() - self.macaroon_gen = hs.get_macaroon_generator() + self._sso_auth_handler = SSOAuthHandler(hs) @defer.inlineCallbacks def on_GET(self, request): - client_redirect_url = request.args[b"redirectUrl"][0] + client_redirect_url = parse_string(request, "redirectUrl", required=True) http_client = self.hs.get_simple_http_client() uri = self.cas_server_url + "/proxyValidate" args = { - "ticket": request.args[b"ticket"][0].decode('ascii'), + "ticket": parse_string(request, "ticket", required=True), "service": self.cas_service_url } try: @@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet): result = yield self.handle_cas_response(request, body, client_redirect_url) defer.returnValue(result) - @defer.inlineCallbacks def handle_cas_response(self, request, cas_response_body, client_redirect_url): user, attributes = self.parse_cas_response(cas_response_body) @@ -396,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet): if required_value != actual_value: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - user_id = UserID(user, self.hs.hostname).to_string() - auth_handler = self.auth_handler - registered_user_id = yield auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id, _ = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - - login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id + return self._sso_auth_handler.on_successful_auth( + user, request, client_redirect_url, ) - redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, - login_token) - request.redirect(redirect_url) - finish_request(request) - - def add_login_token_to_redirect_url(self, url, token): - url_parts = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({"loginToken": token}) - url_parts[4] = urllib.parse.urlencode(query).encode('ascii') - return urllib.parse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): user = None @@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet): return user, attributes +class SSOAuthHandler(object): + """ + Utility class for Resources and Servlets which handle the response from a SSO + service + + Args: + hs (synapse.server.HomeServer) + """ + def __init__(self, hs): + self._hostname = hs.hostname + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_handlers().registration_handler + self._macaroon_gen = hs.get_macaroon_generator() + + @defer.inlineCallbacks + def on_successful_auth( + self, username, request, client_redirect_url, + ): + """Called once the user has successfully authenticated with the SSO. + + Registers the user if necessary, and then returns a redirect (with + a login token) to the client. + + Args: + username (unicode|bytes): the remote user id. We'll map this onto + something sane for a MXID localpath. + + request (SynapseRequest): the incoming request from the browser. We'll + respond to it with a redirect. + + client_redirect_url (unicode): the redirect_url the client gave us when + it first started the process. + + Returns: + Deferred[none]: Completes once we have handled the request. + """ + localpart = map_username_to_mxid_localpart(username) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = yield self._auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id, _ = ( + yield self._registration_handler.register( + localpart=localpart, + generate_token=False, + ) + ) + + login_token = self._macaroon_gen.generate_short_term_login_token( + registered_user_id + ) + redirect_url = self._add_login_token_to_redirect_url( + client_redirect_url, login_token + ) + request.redirect(redirect_url) + finish_request(request) + + @staticmethod + def _add_login_token_to_redirect_url(url, token): + url_parts = list(urllib.parse.urlparse(url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({"loginToken": token}) + url_parts[4] = urllib.parse.urlencode(query) + return urllib.parse.urlunparse(url_parts) + + def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.cas_enabled: diff --git a/synapse/types.py b/synapse/types.py index 41afb27a74..d8cb64addb 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re import string from collections import namedtuple @@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart): return any(c not in mxid_localpart_allowed_characters for c in localpart) +UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") + +# the following is a pattern which matches '=', and bytes which are not allowed in a mxid +# localpart. +# +# It works by: +# * building a string containing the allowed characters (excluding '=') +# * escaping every special character with a backslash (to stop '-' being interpreted as a +# range operator) +# * wrapping it in a '[^...]' regex +# * converting the whole lot to a 'bytes' sequence, so that we can use it to match +# bytes rather than strings +# +NON_MXID_CHARACTER_PATTERN = re.compile( + ("[^%s]" % ( + re.escape("".join(mxid_localpart_allowed_characters - {"="}),), + )).encode("ascii"), +) + + +def map_username_to_mxid_localpart(username, case_sensitive=False): + """Map a username onto a string suitable for a MXID + + This follows the algorithm laid out at + https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. + + Args: + username (unicode|bytes): username to be mapped + case_sensitive (bool): true if TEST and test should be mapped + onto different mxids + + Returns: + unicode: string suitable for a mxid localpart + """ + if not isinstance(username, bytes): + username = username.encode('utf-8') + + # first we sort out upper-case characters + if case_sensitive: + def f1(m): + return b"_" + m.group().lower() + + username = UPPER_CASE_PATTERN.sub(f1, username) + else: + username = username.lower() + + # then we sort out non-ascii characters + def f2(m): + g = m.group()[0] + if isinstance(g, str): + # on python 2, we need to do a ord(). On python 3, the + # byte itself will do. + g = ord(g) + return b"=%02x" % (g,) + + username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) + + # we also do the =-escaping to mxids starting with an underscore. + username = re.sub(b'^_', b'=5f', username) + + # we should now only have ascii bytes left, so can decode back to a + # unicode. + return username.decode('ascii') + + class StreamToken( namedtuple("Token", ( "room_key", diff --git a/tests/test_types.py b/tests/test_types.py index 0f5c8bfaf9..d314a7ff58 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.api.errors import SynapseError -from synapse.types import GroupID, RoomAlias, UserID +from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart from tests import unittest from tests.utils import TestHomeServer @@ -79,3 +79,32 @@ class GroupIDTestCase(unittest.TestCase): except SynapseError as exc: self.assertEqual(400, exc.code) self.assertEqual("M_UNKNOWN", exc.errcode) + + +class MapUsernameTestCase(unittest.TestCase): + def testPassThrough(self): + self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234") + + def testUpperCase(self): + self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234") + self.assertEqual( + map_username_to_mxid_localpart("tEST_1234", case_sensitive=True), + "t_e_s_t__1234", + ) + + def testSymbols(self): + self.assertEqual( + map_username_to_mxid_localpart("test=$?_1234"), + "test=3d=24=3f_1234", + ) + + def testLeadingUnderscore(self): + self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234") + + def testNonAscii(self): + # this should work with either a unicode or a bytes + self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast") + self.assertEqual( + map_username_to_mxid_localpart(u'têst'.encode('utf-8')), + "t=c3=aast", + )