Compare commits
3 Commits
1c1242acba
...
fa4f12102d
| Author | SHA1 | Date |
|---|---|---|
|
|
fa4f12102d | |
|
|
825fb5d0a5 | |
|
|
e8e2ddb60a |
|
|
@ -0,0 +1 @@
|
||||||
|
Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Allow server admins to define and enforce a password policy (MSC2000).
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Refactored the CAS authentication logic to a separate class.
|
||||||
|
|
@ -1482,6 +1482,41 @@ password_config:
|
||||||
#
|
#
|
||||||
#pepper: "EVEN_MORE_SECRET"
|
#pepper: "EVEN_MORE_SECRET"
|
||||||
|
|
||||||
|
# Define and enforce a password policy. Each parameter is optional.
|
||||||
|
# This is an implementation of MSC2000.
|
||||||
|
#
|
||||||
|
policy:
|
||||||
|
# Whether to enforce the password policy.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#enabled: true
|
||||||
|
|
||||||
|
# Minimum accepted length for a password.
|
||||||
|
# Defaults to 0.
|
||||||
|
#
|
||||||
|
#minimum_length: 15
|
||||||
|
|
||||||
|
# Whether a password must contain at least one digit.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#require_digit: true
|
||||||
|
|
||||||
|
# Whether a password must contain at least one symbol.
|
||||||
|
# A symbol is any character that's not a number or a letter.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#require_symbol: true
|
||||||
|
|
||||||
|
# Whether a password must contain at least one lowercase letter.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#require_lowercase: true
|
||||||
|
|
||||||
|
# Whether a password must contain at least one lowercase letter.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#require_uppercase: true
|
||||||
|
|
||||||
|
|
||||||
# Configuration for sending emails from Synapse.
|
# Configuration for sending emails from Synapse.
|
||||||
#
|
#
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,13 @@ class Codes(object):
|
||||||
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
|
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
|
||||||
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
|
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
|
||||||
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
|
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
|
||||||
|
PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
|
||||||
|
PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
|
||||||
|
PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
|
||||||
|
PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
|
||||||
|
PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
|
||||||
|
PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
|
||||||
|
WEAK_PASSWORD = "M_WEAK_PASSWORD"
|
||||||
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
||||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||||
BAD_ALIAS = "M_BAD_ALIAS"
|
BAD_ALIAS = "M_BAD_ALIAS"
|
||||||
|
|
@ -439,6 +446,20 @@ class IncompatibleRoomVersionError(SynapseError):
|
||||||
return cs_error(self.msg, self.errcode, room_version=self._room_version)
|
return cs_error(self.msg, self.errcode, room_version=self._room_version)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordRefusedError(SynapseError):
|
||||||
|
"""A password has been refused, either during password reset/change or registration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
msg="This password doesn't comply with the server's policy",
|
||||||
|
errcode=Codes.WEAK_PASSWORD,
|
||||||
|
):
|
||||||
|
super(PasswordRefusedError, self).__init__(
|
||||||
|
code=400, msg=msg, errcode=errcode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RequestSendFailed(RuntimeError):
|
class RequestSendFailed(RuntimeError):
|
||||||
"""Sending a HTTP request over federation failed due to not being able to
|
"""Sending a HTTP request over federation failed due to not being able to
|
||||||
talk to the remote server for some reason.
|
talk to the remote server for some reason.
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,11 @@ from synapse.config._base import Config, ConfigError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
NON_SQLITE_DATABASE_PATH_WARNING = """\
|
||||||
|
Ignoring 'database_path' setting: not using a sqlite3 database.
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
"""
|
||||||
|
|
||||||
DEFAULT_CONFIG = """\
|
DEFAULT_CONFIG = """\
|
||||||
## Database ##
|
## Database ##
|
||||||
|
|
||||||
|
|
@ -105,6 +110,11 @@ class DatabaseConnectionConfig:
|
||||||
class DatabaseConfig(Config):
|
class DatabaseConfig(Config):
|
||||||
section = "database"
|
section = "database"
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.databases = []
|
||||||
|
|
||||||
def read_config(self, config, **kwargs):
|
def read_config(self, config, **kwargs):
|
||||||
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
|
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
|
||||||
|
|
||||||
|
|
@ -125,12 +135,13 @@ class DatabaseConfig(Config):
|
||||||
|
|
||||||
multi_database_config = config.get("databases")
|
multi_database_config = config.get("databases")
|
||||||
database_config = config.get("database")
|
database_config = config.get("database")
|
||||||
|
database_path = config.get("database_path")
|
||||||
|
|
||||||
if multi_database_config and database_config:
|
if multi_database_config and database_config:
|
||||||
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
|
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
|
||||||
|
|
||||||
if multi_database_config:
|
if multi_database_config:
|
||||||
if config.get("database_path"):
|
if database_path:
|
||||||
raise ConfigError("Can't specify 'database_path' with 'databases'")
|
raise ConfigError("Can't specify 'database_path' with 'databases'")
|
||||||
|
|
||||||
self.databases = [
|
self.databases = [
|
||||||
|
|
@ -138,13 +149,17 @@ class DatabaseConfig(Config):
|
||||||
for name, db_conf in multi_database_config.items()
|
for name, db_conf in multi_database_config.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
else:
|
if database_config:
|
||||||
if database_config is None:
|
|
||||||
database_config = {"name": "sqlite3", "args": {}}
|
|
||||||
|
|
||||||
self.databases = [DatabaseConnectionConfig("master", database_config)]
|
self.databases = [DatabaseConnectionConfig("master", database_config)]
|
||||||
|
|
||||||
self.set_databasepath(config.get("database_path"))
|
if database_path:
|
||||||
|
if self.databases and self.databases[0].name != "sqlite3":
|
||||||
|
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
|
||||||
|
return
|
||||||
|
|
||||||
|
database_config = {"name": "sqlite3", "args": {}}
|
||||||
|
self.databases = [DatabaseConnectionConfig("master", database_config)]
|
||||||
|
self.set_databasepath(database_path)
|
||||||
|
|
||||||
def generate_config_section(self, data_dir_path, **kwargs):
|
def generate_config_section(self, data_dir_path, **kwargs):
|
||||||
return DEFAULT_CONFIG % {
|
return DEFAULT_CONFIG % {
|
||||||
|
|
@ -152,27 +167,37 @@ class DatabaseConfig(Config):
|
||||||
}
|
}
|
||||||
|
|
||||||
def read_arguments(self, args):
|
def read_arguments(self, args):
|
||||||
|
"""
|
||||||
|
Cases for the cli input:
|
||||||
|
- If no databases are configured and no database_path is set, raise.
|
||||||
|
- No databases and only database_path available ==> sqlite3 db.
|
||||||
|
- If there are multiple databases and a database_path raise an error.
|
||||||
|
- If the database set in the config file is sqlite then
|
||||||
|
overwrite with the command line argument.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if args.database_path is None:
|
||||||
|
if not self.databases:
|
||||||
|
raise ConfigError("No database config provided")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(self.databases) == 0:
|
||||||
|
database_config = {"name": "sqlite3", "args": {}}
|
||||||
|
self.databases = [DatabaseConnectionConfig("master", database_config)]
|
||||||
self.set_databasepath(args.database_path)
|
self.set_databasepath(args.database_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.get_single_database().name == "sqlite3":
|
||||||
|
self.set_databasepath(args.database_path)
|
||||||
|
else:
|
||||||
|
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
|
||||||
|
|
||||||
def set_databasepath(self, database_path):
|
def set_databasepath(self, database_path):
|
||||||
if database_path is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if database_path != ":memory:":
|
if database_path != ":memory:":
|
||||||
database_path = self.abspath(database_path)
|
database_path = self.abspath(database_path)
|
||||||
|
|
||||||
# We only support setting a database path if we have a single sqlite3
|
self.databases[0].config["args"]["database"] = database_path
|
||||||
# database.
|
|
||||||
if len(self.databases) != 1:
|
|
||||||
raise ConfigError("Cannot specify 'database_path' with multiple databases")
|
|
||||||
|
|
||||||
database = self.get_single_database()
|
|
||||||
if database.config["name"] != "sqlite3":
|
|
||||||
# We don't raise here as we haven't done so before for this case.
|
|
||||||
logger.warn("Ignoring 'database_path' for non-sqlite3 database")
|
|
||||||
return
|
|
||||||
|
|
||||||
database.config["args"]["database"] = database_path
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_arguments(parser):
|
def add_arguments(parser):
|
||||||
|
|
@ -187,7 +212,7 @@ class DatabaseConfig(Config):
|
||||||
def get_single_database(self) -> DatabaseConnectionConfig:
|
def get_single_database(self) -> DatabaseConnectionConfig:
|
||||||
"""Returns the database if there is only one, useful for e.g. tests
|
"""Returns the database if there is only one, useful for e.g. tests
|
||||||
"""
|
"""
|
||||||
if len(self.databases) != 1:
|
if not self.databases:
|
||||||
raise Exception("More than one database exists")
|
raise Exception("More than one database exists")
|
||||||
|
|
||||||
return self.databases[0]
|
return self.databases[0]
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,10 @@ class PasswordConfig(Config):
|
||||||
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
|
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
|
||||||
self.password_pepper = password_config.get("pepper", "")
|
self.password_pepper = password_config.get("pepper", "")
|
||||||
|
|
||||||
|
# Password policy
|
||||||
|
self.password_policy = password_config.get("policy") or {}
|
||||||
|
self.password_policy_enabled = self.password_policy.get("enabled", False)
|
||||||
|
|
||||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||||
return """\
|
return """\
|
||||||
password_config:
|
password_config:
|
||||||
|
|
@ -48,4 +52,39 @@ class PasswordConfig(Config):
|
||||||
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
||||||
#
|
#
|
||||||
#pepper: "EVEN_MORE_SECRET"
|
#pepper: "EVEN_MORE_SECRET"
|
||||||
|
|
||||||
|
# Define and enforce a password policy. Each parameter is optional.
|
||||||
|
# This is an implementation of MSC2000.
|
||||||
|
#
|
||||||
|
policy:
|
||||||
|
# Whether to enforce the password policy.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#enabled: true
|
||||||
|
|
||||||
|
# Minimum accepted length for a password.
|
||||||
|
# Defaults to 0.
|
||||||
|
#
|
||||||
|
#minimum_length: 15
|
||||||
|
|
||||||
|
# Whether a password must contain at least one digit.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#require_digit: true
|
||||||
|
|
||||||
|
# Whether a password must contain at least one symbol.
|
||||||
|
# A symbol is any character that's not a number or a letter.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#require_symbol: true
|
||||||
|
|
||||||
|
# Whether a password must contain at least one lowercase letter.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#require_lowercase: true
|
||||||
|
|
||||||
|
# Whether a password must contain at least one lowercase letter.
|
||||||
|
# Defaults to 'false'.
|
||||||
|
#
|
||||||
|
#require_uppercase: true
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,204 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from typing import AnyStr, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from six.moves import urllib
|
||||||
|
|
||||||
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
|
from synapse.api.errors import Codes, LoginError
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CasHandler:
|
||||||
|
"""
|
||||||
|
Utility class for to handle the response from a CAS SSO service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self._hostname = hs.hostname
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
self._registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
|
self._cas_server_url = hs.config.cas_server_url
|
||||||
|
self._cas_service_url = hs.config.cas_service_url
|
||||||
|
self._cas_displayname_attribute = hs.config.cas_displayname_attribute
|
||||||
|
self._cas_required_attributes = hs.config.cas_required_attributes
|
||||||
|
|
||||||
|
self._http_client = hs.get_proxied_http_client()
|
||||||
|
|
||||||
|
def _build_service_param(self, client_redirect_url: AnyStr) -> str:
|
||||||
|
return "%s%s?%s" % (
|
||||||
|
self._cas_service_url,
|
||||||
|
"/_matrix/client/r0/login/cas/ticket",
|
||||||
|
urllib.parse.urlencode({"redirectUrl": client_redirect_url}),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_cas_response(
|
||||||
|
self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Retrieves the user and display name from the CAS response and continues with the authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The original client request.
|
||||||
|
cas_response_body: The response from the CAS server.
|
||||||
|
client_redirect_url: The URl to redirect the client to when
|
||||||
|
everything is done.
|
||||||
|
"""
|
||||||
|
user, attributes = self._parse_cas_response(cas_response_body)
|
||||||
|
displayname = attributes.pop(self._cas_displayname_attribute, None)
|
||||||
|
|
||||||
|
for required_attribute, required_value in self._cas_required_attributes.items():
|
||||||
|
# If required attribute was not in CAS Response - Forbidden
|
||||||
|
if required_attribute not in attributes:
|
||||||
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
|
# Also need to check value
|
||||||
|
if required_value is not None:
|
||||||
|
actual_value = attributes[required_attribute]
|
||||||
|
# If required attribute value does not match expected - Forbidden
|
||||||
|
if required_value != actual_value:
|
||||||
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
|
await self._on_successful_auth(user, request, client_redirect_url, displayname)
|
||||||
|
|
||||||
|
def _parse_cas_response(
|
||||||
|
self, cas_response_body: str
|
||||||
|
) -> Tuple[str, Dict[str, Optional[str]]]:
|
||||||
|
"""
|
||||||
|
Retrieve the user and other parameters from the CAS response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cas_response_body: The response from the CAS query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of the user and a mapping of other attributes.
|
||||||
|
"""
|
||||||
|
user = None
|
||||||
|
attributes = {}
|
||||||
|
try:
|
||||||
|
root = ET.fromstring(cas_response_body)
|
||||||
|
if not root.tag.endswith("serviceResponse"):
|
||||||
|
raise Exception("root of CAS response is not serviceResponse")
|
||||||
|
success = root[0].tag.endswith("authenticationSuccess")
|
||||||
|
for child in root[0]:
|
||||||
|
if child.tag.endswith("user"):
|
||||||
|
user = child.text
|
||||||
|
if child.tag.endswith("attributes"):
|
||||||
|
for attribute in child:
|
||||||
|
# ElementTree library expands the namespace in
|
||||||
|
# attribute tags to the full URL of the namespace.
|
||||||
|
# We don't care about namespace here and it will always
|
||||||
|
# be encased in curly braces, so we remove them.
|
||||||
|
tag = attribute.tag
|
||||||
|
if "}" in tag:
|
||||||
|
tag = tag.split("}")[1]
|
||||||
|
attributes[tag] = attribute.text
|
||||||
|
if user is None:
|
||||||
|
raise Exception("CAS response does not contain user")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error parsing CAS response")
|
||||||
|
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||||
|
if not success:
|
||||||
|
raise LoginError(
|
||||||
|
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
||||||
|
)
|
||||||
|
return user, attributes
|
||||||
|
|
||||||
|
async def _on_successful_auth(
|
||||||
|
self,
|
||||||
|
username: str,
|
||||||
|
request: SynapseRequest,
|
||||||
|
client_redirect_url: str,
|
||||||
|
user_display_name: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Called once the user has successfully authenticated with the SSO.
|
||||||
|
|
||||||
|
Registers the user if necessary, and then returns a redirect (with
|
||||||
|
a login token) to the client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: the remote user id. We'll map this onto
|
||||||
|
something sane for a MXID localpath.
|
||||||
|
|
||||||
|
request: the incoming request from the browser. We'll
|
||||||
|
respond to it with a redirect.
|
||||||
|
|
||||||
|
client_redirect_url: the redirect_url the client gave us when
|
||||||
|
it first started the process.
|
||||||
|
|
||||||
|
user_display_name: if set, and we have to register a new user,
|
||||||
|
we will set their displayname to this.
|
||||||
|
"""
|
||||||
|
localpart = map_username_to_mxid_localpart(username)
|
||||||
|
user_id = UserID(localpart, self._hostname).to_string()
|
||||||
|
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||||
|
if not registered_user_id:
|
||||||
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
|
localpart=localpart, default_display_name=user_display_name
|
||||||
|
)
|
||||||
|
|
||||||
|
self._auth_handler.complete_sso_login(
|
||||||
|
registered_user_id, request, client_redirect_url
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle_redirect_request(self, client_redirect_url: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
Generates a URL to the CAS server where the client should be redirected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_redirect_url: The final URL the client should go to after the
|
||||||
|
user has negotiated SSO.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The URL to redirect to.
|
||||||
|
"""
|
||||||
|
args = urllib.parse.urlencode(
|
||||||
|
{"service": self._build_service_param(client_redirect_url)}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii")
|
||||||
|
|
||||||
|
async def handle_ticket_request(
|
||||||
|
self, request: SynapseRequest, client_redirect_url: str, ticket: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Validates a CAS ticket sent by the client for login/registration.
|
||||||
|
|
||||||
|
On a successful request, writes a redirect to the request.
|
||||||
|
"""
|
||||||
|
uri = self._cas_server_url + "/proxyValidate"
|
||||||
|
args = {
|
||||||
|
"ticket": ticket,
|
||||||
|
"service": self._build_service_param(client_redirect_url),
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
body = await self._http_client.get_raw(uri, args)
|
||||||
|
except PartialDownloadError as pde:
|
||||||
|
# Twisted raises this error if the connection is closed,
|
||||||
|
# even if that's being used old-http style to signal end-of-data
|
||||||
|
body = pde.response
|
||||||
|
|
||||||
|
await self._handle_cas_response(request, body, client_redirect_url)
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 New Vector Ltd
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from synapse.api.errors import Codes, PasswordRefusedError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordPolicyHandler(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.policy = hs.config.password_policy
|
||||||
|
self.enabled = hs.config.password_policy_enabled
|
||||||
|
|
||||||
|
# Regexps for the spec'd policy parameters.
|
||||||
|
self.regexp_digit = re.compile("[0-9]")
|
||||||
|
self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
|
||||||
|
self.regexp_uppercase = re.compile("[A-Z]")
|
||||||
|
self.regexp_lowercase = re.compile("[a-z]")
|
||||||
|
|
||||||
|
def validate_password(self, password):
|
||||||
|
"""Checks whether a given password complies with the server's policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
password (str): The password to check against the server's policy.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PasswordRefusedError: The password doesn't comply with the server's policy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
minimum_accepted_length = self.policy.get("minimum_length", 0)
|
||||||
|
if len(password) < minimum_accepted_length:
|
||||||
|
raise PasswordRefusedError(
|
||||||
|
msg=(
|
||||||
|
"The password must be at least %d characters long"
|
||||||
|
% minimum_accepted_length
|
||||||
|
),
|
||||||
|
errcode=Codes.PASSWORD_TOO_SHORT,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.policy.get("require_digit", False)
|
||||||
|
and self.regexp_digit.search(password) is None
|
||||||
|
):
|
||||||
|
raise PasswordRefusedError(
|
||||||
|
msg="The password must include at least one digit",
|
||||||
|
errcode=Codes.PASSWORD_NO_DIGIT,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.policy.get("require_symbol", False)
|
||||||
|
and self.regexp_symbol.search(password) is None
|
||||||
|
):
|
||||||
|
raise PasswordRefusedError(
|
||||||
|
msg="The password must include at least one symbol",
|
||||||
|
errcode=Codes.PASSWORD_NO_SYMBOL,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.policy.get("require_uppercase", False)
|
||||||
|
and self.regexp_uppercase.search(password) is None
|
||||||
|
):
|
||||||
|
raise PasswordRefusedError(
|
||||||
|
msg="The password must include at least one uppercase letter",
|
||||||
|
errcode=Codes.PASSWORD_NO_UPPERCASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.policy.get("require_lowercase", False)
|
||||||
|
and self.regexp_lowercase.search(password) is None
|
||||||
|
):
|
||||||
|
raise PasswordRefusedError(
|
||||||
|
msg="The password must include at least one lowercase letter",
|
||||||
|
errcode=Codes.PASSWORD_NO_LOWERCASE,
|
||||||
|
)
|
||||||
|
|
@ -32,6 +32,7 @@ class SetPasswordHandler(BaseHandler):
|
||||||
super(SetPasswordHandler, self).__init__(hs)
|
super(SetPasswordHandler, self).__init__(hs)
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self._device_handler = hs.get_device_handler()
|
self._device_handler = hs.get_device_handler()
|
||||||
|
self._password_policy_handler = hs.get_password_policy_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_password(
|
def set_password(
|
||||||
|
|
@ -44,6 +45,7 @@ class SetPasswordHandler(BaseHandler):
|
||||||
if not self.hs.config.password_localdb_enabled:
|
if not self.hs.config.password_localdb_enabled:
|
||||||
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
self._password_policy_handler.validate_password(new_password)
|
||||||
password_hash = yield self._auth_handler.hash(new_password)
|
password_hash = yield self._auth_handler.hash(new_password)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
|
||||||
keys,
|
keys,
|
||||||
notifications,
|
notifications,
|
||||||
openid,
|
openid,
|
||||||
|
password_policy,
|
||||||
read_marker,
|
read_marker,
|
||||||
receipts,
|
receipts,
|
||||||
register,
|
register,
|
||||||
|
|
@ -118,6 +119,7 @@ class ClientRestResource(JsonResource):
|
||||||
capabilities.register_servlets(hs, client_resource)
|
capabilities.register_servlets(hs, client_resource)
|
||||||
account_validity.register_servlets(hs, client_resource)
|
account_validity.register_servlets(hs, client_resource)
|
||||||
relations.register_servlets(hs, client_resource)
|
relations.register_servlets(hs, client_resource)
|
||||||
|
password_policy.register_servlets(hs, client_resource)
|
||||||
|
|
||||||
# moving to /_synapse/admin
|
# moving to /_synapse/admin
|
||||||
synapse.rest.admin.register_servlets_for_client_rest_resource(
|
synapse.rest.admin.register_servlets_for_client_rest_resource(
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
|
|
||||||
from six.moves import urllib
|
|
||||||
|
|
||||||
from twisted.web.client import PartialDownloadError
|
|
||||||
|
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
|
|
@ -28,9 +23,10 @@ from synapse.http.servlet import (
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
parse_string,
|
parse_string,
|
||||||
)
|
)
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
from synapse.rest.well_known import WellKnownBuilder
|
from synapse.rest.well_known import WellKnownBuilder
|
||||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
from synapse.types import UserID
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -72,14 +68,6 @@ def login_id_thirdparty_from_phone(identifier):
|
||||||
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
|
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
|
||||||
|
|
||||||
|
|
||||||
def build_service_param(cas_service_url, client_redirect_url):
|
|
||||||
return "%s%s?redirectUrl=%s" % (
|
|
||||||
cas_service_url,
|
|
||||||
"/_matrix/client/r0/login/cas/ticket",
|
|
||||||
urllib.parse.quote(client_redirect_url, safe=""),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoginRestServlet(RestServlet):
|
class LoginRestServlet(RestServlet):
|
||||||
PATTERNS = client_patterns("/login$", v1=True)
|
PATTERNS = client_patterns("/login$", v1=True)
|
||||||
CAS_TYPE = "m.login.cas"
|
CAS_TYPE = "m.login.cas"
|
||||||
|
|
@ -409,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
|
||||||
|
|
||||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request: SynapseRequest):
|
||||||
args = request.args
|
args = request.args
|
||||||
if b"redirectUrl" not in args:
|
if b"redirectUrl" not in args:
|
||||||
return 400, "Redirect URL not specified for SSO auth"
|
return 400, "Redirect URL not specified for SSO auth"
|
||||||
|
|
@ -418,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
|
||||||
request.redirect(sso_url)
|
request.redirect(sso_url)
|
||||||
finish_request(request)
|
finish_request(request)
|
||||||
|
|
||||||
def get_sso_url(self, client_redirect_url):
|
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||||
"""Get the URL to redirect to, to perform SSO auth
|
"""Get the URL to redirect to, to perform SSO auth
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client_redirect_url (bytes): the URL that we should redirect the
|
client_redirect_url: the URL that we should redirect the
|
||||||
client to when everything is done
|
client to when everything is done
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bytes: URL to redirect to
|
URL to redirect to
|
||||||
"""
|
"""
|
||||||
# to be implemented by subclasses
|
# to be implemented by subclasses
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
@ -434,16 +422,10 @@ class BaseSSORedirectServlet(RestServlet):
|
||||||
|
|
||||||
class CasRedirectServlet(BaseSSORedirectServlet):
|
class CasRedirectServlet(BaseSSORedirectServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(CasRedirectServlet, self).__init__()
|
self._cas_handler = hs.get_cas_handler()
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
|
||||||
self.cas_service_url = hs.config.cas_service_url
|
|
||||||
|
|
||||||
def get_sso_url(self, client_redirect_url):
|
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||||
args = urllib.parse.urlencode(
|
return self._cas_handler.handle_redirect_request(client_redirect_url)
|
||||||
{"service": build_service_param(self.cas_service_url, client_redirect_url)}
|
|
||||||
)
|
|
||||||
|
|
||||||
return "%s/login?%s" % (self.cas_server_url, args)
|
|
||||||
|
|
||||||
|
|
||||||
class CasTicketServlet(RestServlet):
|
class CasTicketServlet(RestServlet):
|
||||||
|
|
@ -451,81 +433,15 @@ class CasTicketServlet(RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(CasTicketServlet, self).__init__()
|
super(CasTicketServlet, self).__init__()
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
self._cas_handler = hs.get_cas_handler()
|
||||||
self.cas_service_url = hs.config.cas_service_url
|
|
||||||
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
|
||||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
|
||||||
self._http_client = hs.get_proxied_http_client()
|
|
||||||
|
|
||||||
async def on_GET(self, request):
|
async def on_GET(self, request: SynapseRequest) -> None:
|
||||||
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
||||||
uri = self.cas_server_url + "/proxyValidate"
|
ticket = parse_string(request, "ticket", required=True)
|
||||||
args = {
|
await self._cas_handler.handle_ticket_request(
|
||||||
"ticket": parse_string(request, "ticket", required=True),
|
request, client_redirect_url, ticket
|
||||||
"service": build_service_param(self.cas_service_url, client_redirect_url),
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
body = await self._http_client.get_raw(uri, args)
|
|
||||||
except PartialDownloadError as pde:
|
|
||||||
# Twisted raises this error if the connection is closed,
|
|
||||||
# even if that's being used old-http style to signal end-of-data
|
|
||||||
body = pde.response
|
|
||||||
result = await self.handle_cas_response(request, body, client_redirect_url)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
|
||||||
user, attributes = self.parse_cas_response(cas_response_body)
|
|
||||||
displayname = attributes.pop(self.cas_displayname_attribute, None)
|
|
||||||
|
|
||||||
for required_attribute, required_value in self.cas_required_attributes.items():
|
|
||||||
# If required attribute was not in CAS Response - Forbidden
|
|
||||||
if required_attribute not in attributes:
|
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
# Also need to check value
|
|
||||||
if required_value is not None:
|
|
||||||
actual_value = attributes[required_attribute]
|
|
||||||
# If required attribute value does not match expected - Forbidden
|
|
||||||
if required_value != actual_value:
|
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
return self._sso_auth_handler.on_successful_auth(
|
|
||||||
user, request, client_redirect_url, displayname
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_cas_response(self, cas_response_body):
|
|
||||||
user = None
|
|
||||||
attributes = {}
|
|
||||||
try:
|
|
||||||
root = ET.fromstring(cas_response_body)
|
|
||||||
if not root.tag.endswith("serviceResponse"):
|
|
||||||
raise Exception("root of CAS response is not serviceResponse")
|
|
||||||
success = root[0].tag.endswith("authenticationSuccess")
|
|
||||||
for child in root[0]:
|
|
||||||
if child.tag.endswith("user"):
|
|
||||||
user = child.text
|
|
||||||
if child.tag.endswith("attributes"):
|
|
||||||
for attribute in child:
|
|
||||||
# ElementTree library expands the namespace in
|
|
||||||
# attribute tags to the full URL of the namespace.
|
|
||||||
# We don't care about namespace here and it will always
|
|
||||||
# be encased in curly braces, so we remove them.
|
|
||||||
tag = attribute.tag
|
|
||||||
if "}" in tag:
|
|
||||||
tag = tag.split("}")[1]
|
|
||||||
attributes[tag] = attribute.text
|
|
||||||
if user is None:
|
|
||||||
raise Exception("CAS response does not contain user")
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error parsing CAS response")
|
|
||||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
|
||||||
if not success:
|
|
||||||
raise LoginError(
|
|
||||||
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
|
||||||
)
|
|
||||||
return user, attributes
|
|
||||||
|
|
||||||
|
|
||||||
class SAMLRedirectServlet(BaseSSORedirectServlet):
|
class SAMLRedirectServlet(BaseSSORedirectServlet):
|
||||||
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
||||||
|
|
@ -533,65 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self._saml_handler = hs.get_saml_handler()
|
self._saml_handler = hs.get_saml_handler()
|
||||||
|
|
||||||
def get_sso_url(self, client_redirect_url):
|
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||||
return self._saml_handler.handle_redirect_request(client_redirect_url)
|
return self._saml_handler.handle_redirect_request(client_redirect_url)
|
||||||
|
|
||||||
|
|
||||||
class SSOAuthHandler(object):
|
|
||||||
"""
|
|
||||||
Utility class for Resources and Servlets which handle the response from a SSO
|
|
||||||
service
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hs (synapse.server.HomeServer)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
|
||||||
self._hostname = hs.hostname
|
|
||||||
self._auth_handler = hs.get_auth_handler()
|
|
||||||
self._registration_handler = hs.get_registration_handler()
|
|
||||||
self._macaroon_gen = hs.get_macaroon_generator()
|
|
||||||
|
|
||||||
# cast to tuple for use with str.startswith
|
|
||||||
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
|
||||||
|
|
||||||
async def on_successful_auth(
|
|
||||||
self, username, request, client_redirect_url, user_display_name=None
|
|
||||||
):
|
|
||||||
"""Called once the user has successfully authenticated with the SSO.
|
|
||||||
|
|
||||||
Registers the user if necessary, and then returns a redirect (with
|
|
||||||
a login token) to the client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
username (unicode|bytes): the remote user id. We'll map this onto
|
|
||||||
something sane for a MXID localpath.
|
|
||||||
|
|
||||||
request (SynapseRequest): the incoming request from the browser. We'll
|
|
||||||
respond to it with a redirect.
|
|
||||||
|
|
||||||
client_redirect_url (unicode): the redirect_url the client gave us when
|
|
||||||
it first started the process.
|
|
||||||
|
|
||||||
user_display_name (unicode|None): if set, and we have to register a new user,
|
|
||||||
we will set their displayname to this.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[none]: Completes once we have handled the request.
|
|
||||||
"""
|
|
||||||
localpart = map_username_to_mxid_localpart(username)
|
|
||||||
user_id = UserID(localpart, self._hostname).to_string()
|
|
||||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
|
||||||
if not registered_user_id:
|
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
|
||||||
localpart=localpart, default_display_name=user_display_name
|
|
||||||
)
|
|
||||||
|
|
||||||
self._auth_handler.complete_sso_login(
|
|
||||||
registered_user_id, request, client_redirect_url
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
LoginRestServlet(hs).register(http_server)
|
LoginRestServlet(hs).register(http_server)
|
||||||
if hs.config.cas_enabled:
|
if hs.config.cas_enabled:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from synapse.http.servlet import RestServlet
|
||||||
|
|
||||||
|
from ._base import client_patterns
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordPolicyServlet(RestServlet):
|
||||||
|
PATTERNS = client_patterns("/password_policy$")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
|
super(PasswordPolicyServlet, self).__init__()
|
||||||
|
|
||||||
|
self.policy = hs.config.password_policy
|
||||||
|
self.enabled = hs.config.password_policy_enabled
|
||||||
|
|
||||||
|
def on_GET(self, request):
|
||||||
|
if not self.enabled or not self.policy:
|
||||||
|
return (200, {})
|
||||||
|
|
||||||
|
policy = {}
|
||||||
|
|
||||||
|
for param in [
|
||||||
|
"minimum_length",
|
||||||
|
"require_digit",
|
||||||
|
"require_symbol",
|
||||||
|
"require_lowercase",
|
||||||
|
"require_uppercase",
|
||||||
|
]:
|
||||||
|
if param in self.policy:
|
||||||
|
policy["m.%s" % param] = self.policy[param]
|
||||||
|
|
||||||
|
return (200, policy)
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
PasswordPolicyServlet(hs).register(http_server)
|
||||||
|
|
@ -373,6 +373,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.room_member_handler = hs.get_room_member_handler()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
self.ratelimiter = hs.get_registration_ratelimiter()
|
self.ratelimiter = hs.get_registration_ratelimiter()
|
||||||
|
self.password_policy_handler = hs.get_password_policy_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
self._registration_flows = _calculate_registration_flows(
|
self._registration_flows = _calculate_registration_flows(
|
||||||
|
|
@ -420,6 +421,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
or len(body["password"]) > 512
|
or len(body["password"]) > 512
|
||||||
):
|
):
|
||||||
raise SynapseError(400, "Invalid password")
|
raise SynapseError(400, "Invalid password")
|
||||||
|
self.password_policy_handler.validate_password(body["password"])
|
||||||
|
|
||||||
desired_username = None
|
desired_username = None
|
||||||
if "username" in body:
|
if "username" in body:
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,7 @@ from synapse.handlers.account_validity import AccountValidityHandler
|
||||||
from synapse.handlers.acme import AcmeHandler
|
from synapse.handlers.acme import AcmeHandler
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||||
|
from synapse.handlers.cas_handler import CasHandler
|
||||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||||
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
|
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
|
||||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||||
|
|
@ -66,6 +67,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerH
|
||||||
from synapse.handlers.initial_sync import InitialSyncHandler
|
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||||
from synapse.handlers.message import EventCreationHandler, MessageHandler
|
from synapse.handlers.message import EventCreationHandler, MessageHandler
|
||||||
from synapse.handlers.pagination import PaginationHandler
|
from synapse.handlers.pagination import PaginationHandler
|
||||||
|
from synapse.handlers.password_policy import PasswordPolicyHandler
|
||||||
from synapse.handlers.presence import PresenceHandler
|
from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
|
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
|
||||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||||
|
|
@ -197,8 +199,10 @@ class HomeServer(object):
|
||||||
"sendmail",
|
"sendmail",
|
||||||
"registration_handler",
|
"registration_handler",
|
||||||
"account_validity_handler",
|
"account_validity_handler",
|
||||||
|
"cas_handler",
|
||||||
"saml_handler",
|
"saml_handler",
|
||||||
"event_client_serializer",
|
"event_client_serializer",
|
||||||
|
"password_policy_handler",
|
||||||
"storage",
|
"storage",
|
||||||
"replication_streamer",
|
"replication_streamer",
|
||||||
]
|
]
|
||||||
|
|
@ -527,6 +531,9 @@ class HomeServer(object):
|
||||||
def build_account_validity_handler(self):
|
def build_account_validity_handler(self):
|
||||||
return AccountValidityHandler(self)
|
return AccountValidityHandler(self)
|
||||||
|
|
||||||
|
def build_cas_handler(self):
|
||||||
|
return CasHandler(self)
|
||||||
|
|
||||||
def build_saml_handler(self):
|
def build_saml_handler(self):
|
||||||
from synapse.handlers.saml_handler import SamlHandler
|
from synapse.handlers.saml_handler import SamlHandler
|
||||||
|
|
||||||
|
|
@ -535,6 +542,9 @@ class HomeServer(object):
|
||||||
def build_event_client_serializer(self):
|
def build_event_client_serializer(self):
|
||||||
return EventClientSerializer(self)
|
return EventClientSerializer(self)
|
||||||
|
|
||||||
|
def build_password_policy_handler(self):
|
||||||
|
return PasswordPolicyHandler(self)
|
||||||
|
|
||||||
def build_storage(self) -> Storage:
|
def build_storage(self) -> Storage:
|
||||||
return Storage(self, self.datastores)
|
return Storage(self, self.datastores)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,179 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from synapse.api.constants import LoginType
|
||||||
|
from synapse.api.errors import Codes
|
||||||
|
from synapse.rest import admin
|
||||||
|
from synapse.rest.client.v1 import login
|
||||||
|
from synapse.rest.client.v2_alpha import account, password_policy, register
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordPolicyTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""Tests the password policy feature and its compliance with MSC2000.
|
||||||
|
|
||||||
|
When validating a password, Synapse does the necessary checks in this order:
|
||||||
|
|
||||||
|
1. Password is long enough
|
||||||
|
2. Password contains digit(s)
|
||||||
|
3. Password contains symbol(s)
|
||||||
|
4. Password contains uppercase letter(s)
|
||||||
|
5. Password contains lowercase letter(s)
|
||||||
|
|
||||||
|
For each test below that checks whether a password triggers the right error code,
|
||||||
|
that test provides a password good enough to pass the previous tests, but not the
|
||||||
|
one it is currently testing (nor any test that comes afterward).
|
||||||
|
"""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets_for_client_rest_resource,
|
||||||
|
login.register_servlets,
|
||||||
|
register.register_servlets,
|
||||||
|
password_policy.register_servlets,
|
||||||
|
account.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
self.register_url = "/_matrix/client/r0/register"
|
||||||
|
self.policy = {
|
||||||
|
"enabled": True,
|
||||||
|
"minimum_length": 10,
|
||||||
|
"require_digit": True,
|
||||||
|
"require_symbol": True,
|
||||||
|
"require_lowercase": True,
|
||||||
|
"require_uppercase": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
config = self.default_config()
|
||||||
|
config["password_config"] = {
|
||||||
|
"policy": self.policy,
|
||||||
|
}
|
||||||
|
|
||||||
|
hs = self.setup_test_homeserver(config=config)
|
||||||
|
return hs
|
||||||
|
|
||||||
|
def test_get_policy(self):
|
||||||
|
"""Tests if the /password_policy endpoint returns the configured policy."""
|
||||||
|
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "/_matrix/client/r0/password_policy"
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body,
|
||||||
|
{
|
||||||
|
"m.minimum_length": 10,
|
||||||
|
"m.require_digit": True,
|
||||||
|
"m.require_symbol": True,
|
||||||
|
"m.require_lowercase": True,
|
||||||
|
"m.require_uppercase": True,
|
||||||
|
},
|
||||||
|
channel.result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_password_too_short(self):
|
||||||
|
request_data = json.dumps({"username": "kermit", "password": "shorty"})
|
||||||
|
request, channel = self.make_request("POST", self.register_url, request_data)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_password_no_digit(self):
|
||||||
|
request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
|
||||||
|
request, channel = self.make_request("POST", self.register_url, request_data)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_password_no_symbol(self):
|
||||||
|
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
|
||||||
|
request, channel = self.make_request("POST", self.register_url, request_data)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_password_no_uppercase(self):
|
||||||
|
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
|
||||||
|
request, channel = self.make_request("POST", self.register_url, request_data)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_password_no_lowercase(self):
|
||||||
|
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
|
||||||
|
request, channel = self.make_request("POST", self.register_url, request_data)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_password_compliant(self):
|
||||||
|
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
|
||||||
|
request, channel = self.make_request("POST", self.register_url, request_data)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
# Getting a 401 here means the password has passed validation and the server has
|
||||||
|
# responded with a list of registration flows.
|
||||||
|
self.assertEqual(channel.code, 401, channel.result)
|
||||||
|
|
||||||
|
def test_password_change(self):
|
||||||
|
"""This doesn't test every possible use case, only that hitting /account/password
|
||||||
|
triggers the password validation code.
|
||||||
|
"""
|
||||||
|
compliant_password = "C0mpl!antpassword"
|
||||||
|
not_compliant_password = "notcompliantpassword"
|
||||||
|
|
||||||
|
user_id = self.register_user("kermit", compliant_password)
|
||||||
|
tok = self.login("kermit", compliant_password)
|
||||||
|
|
||||||
|
request_data = json.dumps(
|
||||||
|
{
|
||||||
|
"new_password": not_compliant_password,
|
||||||
|
"auth": {
|
||||||
|
"password": compliant_password,
|
||||||
|
"type": LoginType.PASSWORD,
|
||||||
|
"user": user_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/account/password",
|
||||||
|
request_data,
|
||||||
|
access_token=tok,
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
|
||||||
1
tox.ini
1
tox.ini
|
|
@ -186,6 +186,7 @@ commands = mypy \
|
||||||
synapse/federation/sender \
|
synapse/federation/sender \
|
||||||
synapse/federation/transport \
|
synapse/federation/transport \
|
||||||
synapse/handlers/auth.py \
|
synapse/handlers/auth.py \
|
||||||
|
synapse/handlers/cas_handler.py \
|
||||||
synapse/handlers/directory.py \
|
synapse/handlers/directory.py \
|
||||||
synapse/handlers/presence.py \
|
synapse/handlers/presence.py \
|
||||||
synapse/handlers/sync.py \
|
synapse/handlers/sync.py \
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue