Add support for handling avatar with SSO login (#13917)

This commit adds support for handling a provided avatar picture URL
when logging in via SSO.

Signed-off-by: Ashish Kumar <ashfame@users.noreply.github.com>

Fixes #9357.
pull/14574/head
Ashish Kumar 2022-11-25 19:16:50 +04:00 committed by GitHub
parent 39cde585bf
commit 09de2aecb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 275 additions and 2 deletions

View File

@ -0,0 +1 @@
Adds support for handling avatar in SSO login. Contributed by @ashfame.

View File

@ -2968,10 +2968,17 @@ Options for each entry include:
For the default provider, the following settings are available: For the default provider, the following settings are available:
* subject_claim: name of the claim containing a unique identifier * `subject_claim`: name of the claim containing a unique identifier
for the user. Defaults to 'sub', which OpenID Connect for the user. Defaults to 'sub', which OpenID Connect
compliant providers should provide. compliant providers should provide.
* `picture_claim`: name of the claim containing an url for the user's profile picture.
Defaults to 'picture', which OpenID Connect compliant providers should provide
and has to refer to a direct image file such as PNG, JPEG, or GIF image file.
Currently only supported in monolithic (single-process) server configurations
where the media repository runs within the Synapse process.
* `localpart_template`: Jinja2 template for the localpart of the MXID. * `localpart_template`: Jinja2 template for the localpart of the MXID.
If this is not set, the user will be prompted to choose their If this is not set, the user will be prompted to choose their
own username (see the documentation for the `sso_auth_account_details.html` own username (see the documentation for the `sso_auth_account_details.html`

View File

@ -119,6 +119,9 @@ disallow_untyped_defs = True
[mypy-tests.storage.test_profile] [mypy-tests.storage.test_profile]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True
[mypy-tests.storage.test_user_directory] [mypy-tests.storage.test_user_directory]
disallow_untyped_defs = True disallow_untyped_defs = True
@ -137,7 +140,6 @@ disallow_untyped_defs = False
[mypy-tests.utils] [mypy-tests.utils]
disallow_untyped_defs = True disallow_untyped_defs = True
;; Dependencies without annotations ;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available. ;; Before ignoring a module, check to see if type stubs are available.
;; The `typeshed` project maintains stubs here: ;; The `typeshed` project maintains stubs here:

View File

@ -1435,6 +1435,7 @@ class UserAttributeDict(TypedDict):
localpart: Optional[str] localpart: Optional[str]
confirm_localpart: bool confirm_localpart: bool
display_name: Optional[str] display_name: Optional[str]
picture: Optional[str] # may be omitted by older `OidcMappingProviders`
emails: List[str] emails: List[str]
@ -1520,6 +1521,7 @@ env.filters.update(
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig: class JinjaOidcMappingConfig:
subject_claim: str subject_claim: str
picture_claim: str
localpart_template: Optional[Template] localpart_template: Optional[Template]
display_name_template: Optional[Template] display_name_template: Optional[Template]
email_template: Optional[Template] email_template: Optional[Template]
@ -1539,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@staticmethod @staticmethod
def parse_config(config: dict) -> JinjaOidcMappingConfig: def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub") subject_claim = config.get("subject_claim", "sub")
picture_claim = config.get("picture_claim", "picture")
def parse_template_config(option_name: str) -> Optional[Template]: def parse_template_config(option_name: str) -> Optional[Template]:
if option_name not in config: if option_name not in config:
@ -1572,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return JinjaOidcMappingConfig( return JinjaOidcMappingConfig(
subject_claim=subject_claim, subject_claim=subject_claim,
picture_claim=picture_claim,
localpart_template=localpart_template, localpart_template=localpart_template,
display_name_template=display_name_template, display_name_template=display_name_template,
email_template=email_template, email_template=email_template,
@ -1611,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if email: if email:
emails.append(email) emails.append(email)
picture = userinfo.get("picture")
return UserAttributeDict( return UserAttributeDict(
localpart=localpart, localpart=localpart,
display_name=display_name, display_name=display_name,
emails=emails, emails=emails,
picture=picture,
confirm_localpart=self._config.confirm_localpart, confirm_localpart=self._config.confirm_localpart,
) )

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc import abc
import hashlib
import io
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -138,6 +140,7 @@ class UserAttributes:
localpart: Optional[str] localpart: Optional[str]
confirm_localpart: bool = False confirm_localpart: bool = False
display_name: Optional[str] = None display_name: Optional[str] = None
picture: Optional[str] = None
emails: Collection[str] = attr.Factory(list) emails: Collection[str] = attr.Factory(list)
@ -196,6 +199,10 @@ class SsoHandler:
self._error_template = hs.config.sso.sso_error_template self._error_template = hs.config.sso.sso_error_template
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
self._profile_handler = hs.get_profile_handler() self._profile_handler = hs.get_profile_handler()
self._media_repo = (
hs.get_media_repository() if hs.config.media.can_load_media_repo else None
)
self._http_client = hs.get_proxied_blacklisted_http_client()
# The following template is shown after a successful user interactive # The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window. # authentication session. It tells the user they can close the window.
@ -495,6 +502,8 @@ class SsoHandler:
await self._profile_handler.set_displayname( await self._profile_handler.set_displayname(
user_id_obj, requester, attributes.display_name, True user_id_obj, requester, attributes.display_name, True
) )
if attributes.picture:
await self.set_avatar(user_id, attributes.picture)
await self._auth_handler.complete_sso_login( await self._auth_handler.complete_sso_login(
user_id, user_id,
@ -703,8 +712,110 @@ class SsoHandler:
await self._store.record_user_external_id( await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id auth_provider_id, remote_user_id, registered_user_id
) )
# Set avatar, if available
if attributes.picture:
await self.set_avatar(registered_user_id, attributes.picture)
return registered_user_id return registered_user_id
async def set_avatar(self, user_id: str, picture_https_url: str) -> bool:
"""Set avatar of the user.
This downloads the image file from the URL provided, stores that in
the media repository and then sets the avatar on the user's profile.
It can detect if the same image is being saved again and bails early by storing
the hash of the file in the `upload_name` of the avatar image.
Currently, it only supports server configurations which run the media repository
within the same process.
It silently fails and logs a warning by raising an exception and catching it
internally if:
* it is unable to fetch the image itself (non 200 status code) or
* the image supplied is bigger than max allowed size or
* the image type is not one of the allowed image types.
Args:
user_id: matrix user ID in the form @localpart:domain as a string.
picture_https_url: HTTPS url for the picture image file.
Returns: `True` if the user's avatar has been successfully set to the image at
`picture_https_url`.
"""
if self._media_repo is None:
logger.info(
"failed to set user avatar because out-of-process media repositories "
"are not supported yet "
)
return False
try:
uid = UserID.from_string(user_id)
def is_allowed_mime_type(content_type: str) -> bool:
if (
self._profile_handler.allowed_avatar_mimetypes
and content_type
not in self._profile_handler.allowed_avatar_mimetypes
):
return False
return True
# download picture, enforcing size limit & mime type check
picture = io.BytesIO()
content_length, headers, uri, code = await self._http_client.get_file(
url=picture_https_url,
output_stream=picture,
max_size=self._profile_handler.max_avatar_size,
is_allowed_content_type=is_allowed_mime_type,
)
if code != 200:
raise Exception(
"GET request to download sso avatar image returned {}".format(code)
)
# upload name includes hash of the image file's content so that we can
# easily check if it requires an update or not, the next time user logs in
upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest()
# bail if user already has the same avatar
profile = await self._profile_handler.get_profile(user_id)
if profile["avatar_url"] is not None:
server_name = profile["avatar_url"].split("/")[-2]
media_id = profile["avatar_url"].split("/")[-1]
if server_name == self._server_name:
media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]:
logger.info("skipping saving the user avatar")
return True
# store it in media repository
avatar_mxc_url = await self._media_repo.create_content(
media_type=headers[b"Content-Type"][0].decode("utf-8"),
upload_name=upload_name,
content=picture,
content_length=content_length,
auth_user=uid,
)
# save it as user avatar
await self._profile_handler.set_avatar_url(
uid,
create_requester(uid),
str(avatar_mxc_url),
)
logger.info("successfully saved the user avatar")
return True
except Exception:
logger.warning("failed to save the user avatar")
return False
async def complete_sso_ui_auth_request( async def complete_sso_ui_auth_request(
self, self,
auth_provider_id: str, auth_provider_id: str,

145
tests/handlers/test_sso.py Normal file
View File

@ -0,0 +1,145 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import BinaryIO, Callable, Dict, List, Optional, Tuple
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http_headers import Headers
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import RawHeaders
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.test_utils import SMALL_PNG, FakeResponse
class TestSSOHandler(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_file"])
self.http_client.get_file.side_effect = mock_get_file
self.http_client.user_agent = b"Synapse Test"
hs = self.setup_test_homeserver(
proxied_blacklisted_http_client=self.http_client
)
return hs
async def test_set_avatar(self) -> None:
"""Tests successfully setting the avatar of a newly created user"""
handler = self.hs.get_sso_handler()
# Create a new user to set avatar for
reg_handler = self.hs.get_registration_handler()
user_id = self.get_success(reg_handler.register_user(approved=True))
self.assertTrue(
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
)
# Ensure avatar is set on this newly created user,
# so no need to compare for the exact image
profile_handler = self.hs.get_profile_handler()
profile = self.get_success(profile_handler.get_profile(user_id))
self.assertIsNot(profile["avatar_url"], None)
@unittest.override_config({"max_avatar_size": 1})
async def test_set_avatar_too_big_image(self) -> None:
"""Tests that saving an avatar fails when it is too big"""
handler = self.hs.get_sso_handler()
# any random user works since image check is supposed to fail
user_id = "@sso-user:test"
self.assertFalse(
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]})
async def test_set_avatar_incorrect_mime_type(self) -> None:
"""Tests that saving an avatar fails when its mime type is not allowed"""
handler = self.hs.get_sso_handler()
# any random user works since image check is supposed to fail
user_id = "@sso-user:test"
self.assertFalse(
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
)
async def test_skip_saving_avatar_when_not_changed(self) -> None:
"""Tests whether saving of avatar correctly skips if the avatar hasn't
changed"""
handler = self.hs.get_sso_handler()
# Create a new user to set avatar for
reg_handler = self.hs.get_registration_handler()
user_id = self.get_success(reg_handler.register_user(approved=True))
# set avatar for the first time, should be a success
self.assertTrue(
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
)
# get avatar picture for comparison after another attempt
profile_handler = self.hs.get_profile_handler()
profile = self.get_success(profile_handler.get_profile(user_id))
url_to_match = profile["avatar_url"]
# set same avatar for the second time, should be a success
self.assertTrue(
self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
)
# compare avatar picture's url from previous step
profile = self.get_success(profile_handler.get_profile(user_id))
self.assertEqual(profile["avatar_url"], url_to_match)
async def mock_get_file(
url: str,
output_stream: BinaryIO,
max_size: Optional[int] = None,
headers: Optional[RawHeaders] = None,
is_allowed_content_type: Optional[Callable[[str], bool]] = None,
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
fake_response = FakeResponse(code=404)
if url == "http://my.server/me.png":
fake_response = FakeResponse(
code=200,
headers=Headers(
{"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]}
),
body=SMALL_PNG,
)
if max_size is not None and max_size < len(SMALL_PNG):
raise SynapseError(
HTTPStatus.BAD_GATEWAY,
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
if is_allowed_content_type and not is_allowed_content_type("image/png"):
raise SynapseError(
HTTPStatus.BAD_GATEWAY,
(
"Requested file's content type not allowed for this operation: %s"
% "image/png"
),
)
output_stream.write(fake_response.body)
return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200