Fix some types in the CAS code.

pull/8784/head
Patrick Cloke 2020-10-22 14:17:50 -04:00
parent 79bfe966e0
commit 37fb198e5f
4 changed files with 18 additions and 7 deletions

1
changelog.d/8784.misc Normal file
View File

@ -0,0 +1 @@
Fix type hints in CAS handler.

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib import urllib
from typing import Dict, Optional, Tuple from typing import TYPE_CHECKING, Dict, Optional, Tuple
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,10 +34,10 @@ class CasHandler:
Utility class for to handle the response from a CAS SSO service. Utility class for to handle the response from a CAS SSO service.
Args: Args:
hs (synapse.server.HomeServer) hs
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self._hostname = hs.hostname self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@ -205,11 +208,16 @@ class CasHandler:
registered_user_id = await self._auth_handler.check_user_exists(user_id) registered_user_id = await self._auth_handler.check_user_exists(user_id)
if session: if session:
# If there's a session then the user must already exist.
assert registered_user_id
await self._auth_handler.complete_sso_ui_auth( await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request, registered_user_id, session, request,
) )
else: else:
# If this not a UI auth request than there must be a redirect URL.
assert client_redirect_url
if not registered_user_id: if not registered_user_id:
# Pull out the user-agent and IP from the request. # Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("") user_agent = request.get_user_agent("")
@ -221,6 +229,8 @@ class CasHandler:
user_agent_ips=(user_agent, ip_address), user_agent_ips=(user_agent, ip_address),
) )
assert registered_user_id
await self._auth_handler.complete_sso_login( await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url registered_user_id, request, client_redirect_url
) )

View File

@ -152,7 +152,7 @@ class RegistrationHandler(BaseHandler):
bind_emails=[], bind_emails=[],
by_admin=False, by_admin=False,
user_agent_ips=None, user_agent_ips=None,
): ) -> str:
"""Registers a new client on the server. """Registers a new client on the server.
Args: Args:

View File

@ -39,7 +39,7 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING: if TYPE_CHECKING:
import synapse.server from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class Saml2SessionData:
class SamlHandler(BaseHandler): class SamlHandler(BaseHandler):
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid self._saml_idp_entityid = hs.config.saml2_idp_entityid