Merge remote-tracking branch 'origin/develop' into dbkr/password_reset_case_insensitive
						commit
						bcb1245a2d
					
				|  | @ -24,10 +24,10 @@ homeserver*.yaml | |||
| .coverage | ||||
| htmlcov | ||||
| 
 | ||||
| demo/*.db | ||||
| demo/*.log | ||||
| demo/*.log.* | ||||
| demo/*.pid | ||||
| demo/*/*.db | ||||
| demo/*/*.log | ||||
| demo/*/*.log.* | ||||
| demo/*/*.pid | ||||
| demo/media_store.* | ||||
| demo/etc | ||||
| 
 | ||||
|  |  | |||
|  | @ -197,7 +197,7 @@ class PusherServer(HomeServer): | |||
|                     yield start_pusher(user_id, app_id, pushkey) | ||||
| 
 | ||||
|             stream = results.get("events") | ||||
|             if stream: | ||||
|             if stream and stream["rows"]: | ||||
|                 min_stream_id = stream["rows"][0][0] | ||||
|                 max_stream_id = stream["position"] | ||||
|                 preserve_fn(pusher_pool.on_new_notifications)( | ||||
|  | @ -205,7 +205,7 @@ class PusherServer(HomeServer): | |||
|                 ) | ||||
| 
 | ||||
|             stream = results.get("receipts") | ||||
|             if stream: | ||||
|             if stream and stream["rows"]: | ||||
|                 rows = stream["rows"] | ||||
|                 affected_room_ids = set(row[1] for row in rows) | ||||
|                 min_stream_id = rows[0][0] | ||||
|  |  | |||
|  | @ -30,7 +30,7 @@ from .saml2 import SAML2Config | |||
| from .cas import CasConfig | ||||
| from .password import PasswordConfig | ||||
| from .jwt import JWTConfig | ||||
| from .ldap import LDAPConfig | ||||
| from .password_auth_providers import PasswordAuthProviderConfig | ||||
| from .emailconfig import EmailConfig | ||||
| from .workers import WorkerConfig | ||||
| 
 | ||||
|  | @ -39,8 +39,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, | |||
|                        RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, | ||||
|                        VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, | ||||
|                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig, | ||||
|                        JWTConfig, LDAPConfig, PasswordConfig, EmailConfig, | ||||
|                        WorkerConfig,): | ||||
|                        JWTConfig, PasswordConfig, EmailConfig, | ||||
|                        WorkerConfig, PasswordAuthProviderConfig,): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,100 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 Niklas Riekenbrauck | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import Config, ConfigError | ||||
| 
 | ||||
| 
 | ||||
| MISSING_LDAP3 = ( | ||||
|     "Missing ldap3 library. This is required for LDAP Authentication." | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class LDAPMode(object): | ||||
|     SIMPLE = "simple", | ||||
|     SEARCH = "search", | ||||
| 
 | ||||
|     LIST = (SIMPLE, SEARCH) | ||||
| 
 | ||||
| 
 | ||||
| class LDAPConfig(Config): | ||||
|     def read_config(self, config): | ||||
|         ldap_config = config.get("ldap_config", {}) | ||||
| 
 | ||||
|         self.ldap_enabled = ldap_config.get("enabled", False) | ||||
| 
 | ||||
|         if self.ldap_enabled: | ||||
|             # verify dependencies are available | ||||
|             try: | ||||
|                 import ldap3 | ||||
|                 ldap3  # to stop unused lint | ||||
|             except ImportError: | ||||
|                 raise ConfigError(MISSING_LDAP3) | ||||
| 
 | ||||
|             self.ldap_mode = LDAPMode.SIMPLE | ||||
| 
 | ||||
|             # verify config sanity | ||||
|             self.require_keys(ldap_config, [ | ||||
|                 "uri", | ||||
|                 "base", | ||||
|                 "attributes", | ||||
|             ]) | ||||
| 
 | ||||
|             self.ldap_uri = ldap_config["uri"] | ||||
|             self.ldap_start_tls = ldap_config.get("start_tls", False) | ||||
|             self.ldap_base = ldap_config["base"] | ||||
|             self.ldap_attributes = ldap_config["attributes"] | ||||
| 
 | ||||
|             if "bind_dn" in ldap_config: | ||||
|                 self.ldap_mode = LDAPMode.SEARCH | ||||
|                 self.require_keys(ldap_config, [ | ||||
|                     "bind_dn", | ||||
|                     "bind_password", | ||||
|                 ]) | ||||
| 
 | ||||
|                 self.ldap_bind_dn = ldap_config["bind_dn"] | ||||
|                 self.ldap_bind_password = ldap_config["bind_password"] | ||||
|                 self.ldap_filter = ldap_config.get("filter", None) | ||||
| 
 | ||||
|             # verify attribute lookup | ||||
|             self.require_keys(ldap_config['attributes'], [ | ||||
|                 "uid", | ||||
|                 "name", | ||||
|                 "mail", | ||||
|             ]) | ||||
| 
 | ||||
|     def require_keys(self, config, required): | ||||
|         missing = [key for key in required if key not in config] | ||||
|         if missing: | ||||
|             raise ConfigError( | ||||
|                 "LDAP enabled but missing required config values: {}".format( | ||||
|                     ", ".join(missing) | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|     def default_config(self, **kwargs): | ||||
|         return """\ | ||||
|         # ldap_config: | ||||
|         #   enabled: true | ||||
|         #   uri: "ldap://ldap.example.com:389" | ||||
|         #   start_tls: true | ||||
|         #   base: "ou=users,dc=example,dc=com" | ||||
|         #   attributes: | ||||
|         #      uid: "cn" | ||||
|         #      mail: "email" | ||||
|         #      name: "givenName" | ||||
|         #   #bind_dn: | ||||
|         #   #bind_password: | ||||
|         #   #filter: "(objectClass=posixAccount)" | ||||
|         """ | ||||
|  | @ -0,0 +1,61 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2016 Openmarket | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import Config | ||||
| 
 | ||||
| import importlib | ||||
| 
 | ||||
| 
 | ||||
| class PasswordAuthProviderConfig(Config): | ||||
|     def read_config(self, config): | ||||
|         self.password_providers = [] | ||||
| 
 | ||||
|         # We want to be backwards compatible with the old `ldap_config` | ||||
|         # param. | ||||
|         ldap_config = config.get("ldap_config", {}) | ||||
|         self.ldap_enabled = ldap_config.get("enabled", False) | ||||
|         if self.ldap_enabled: | ||||
|             from synapse.util.ldap_auth_provider import LdapAuthProvider | ||||
|             parsed_config = LdapAuthProvider.parse_config(ldap_config) | ||||
|             self.password_providers.append((LdapAuthProvider, parsed_config)) | ||||
| 
 | ||||
|         providers = config.get("password_providers", []) | ||||
|         for provider in providers: | ||||
|             # We need to import the module, and then pick the class out of | ||||
|             # that, so we split based on the last dot. | ||||
|             module, clz = provider['module'].rsplit(".", 1) | ||||
|             module = importlib.import_module(module) | ||||
|             provider_class = getattr(module, clz) | ||||
| 
 | ||||
|             provider_config = provider_class.parse_config(provider["config"]) | ||||
|             self.password_providers.append((provider_class, provider_config)) | ||||
| 
 | ||||
|     def default_config(self, **kwargs): | ||||
|         return """\ | ||||
|         # password_providers: | ||||
|         #     - module: "synapse.util.ldap_auth_provider.LdapAuthProvider" | ||||
|         #       config: | ||||
|         #         enabled: true | ||||
|         #         uri: "ldap://ldap.example.com:389" | ||||
|         #         start_tls: true | ||||
|         #         base: "ou=users,dc=example,dc=com" | ||||
|         #         attributes: | ||||
|         #            uid: "cn" | ||||
|         #            mail: "email" | ||||
|         #            name: "givenName" | ||||
|         #         #bind_dn: | ||||
|         #         #bind_password: | ||||
|         #         #filter: "(objectClass=posixAccount)" | ||||
|         """ | ||||
|  | @ -19,6 +19,9 @@ from OpenSSL import crypto | |||
| import subprocess | ||||
| import os | ||||
| 
 | ||||
| from hashlib import sha256 | ||||
| from unpaddedbase64 import encode_base64 | ||||
| 
 | ||||
| GENERATE_DH_PARAMS = False | ||||
| 
 | ||||
| 
 | ||||
|  | @ -42,6 +45,19 @@ class TlsConfig(Config): | |||
|             config.get("tls_dh_params_path"), "tls_dh_params" | ||||
|         ) | ||||
| 
 | ||||
|         self.tls_fingerprints = config["tls_fingerprints"] | ||||
| 
 | ||||
|         # Check that our own certificate is included in the list of fingerprints | ||||
|         # and include it if it is not. | ||||
|         x509_certificate_bytes = crypto.dump_certificate( | ||||
|             crypto.FILETYPE_ASN1, | ||||
|             self.tls_certificate | ||||
|         ) | ||||
|         sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) | ||||
|         sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) | ||||
|         if sha256_fingerprint not in sha256_fingerprints: | ||||
|             self.tls_fingerprints.append({u"sha256": sha256_fingerprint}) | ||||
| 
 | ||||
|         # This config option applies to non-federation HTTP clients | ||||
|         # (e.g. for talking to recaptcha, identity servers, and such) | ||||
|         # It should never be used in production, and is intended for | ||||
|  | @ -73,6 +89,28 @@ class TlsConfig(Config): | |||
| 
 | ||||
|         # Don't bind to the https port | ||||
|         no_tls: False | ||||
| 
 | ||||
|         # List of allowed TLS fingerprints for this server to publish along | ||||
|         # with the signing keys for this server. Other matrix servers that | ||||
|         # make HTTPS requests to this server will check that the TLS | ||||
|         # certificates returned by this server match one of the fingerprints. | ||||
|         # | ||||
|         # Synapse automatically adds its the fingerprint of its own certificate | ||||
|         # to the list. So if federation traffic is handle directly by synapse | ||||
|         # then no modification to the list is required. | ||||
|         # | ||||
|         # If synapse is run behind a load balancer that handles the TLS then it | ||||
|         # will be necessary to add the fingerprints of the certificates used by | ||||
|         # the loadbalancers to this list if they are different to the one | ||||
|         # synapse is using. | ||||
|         # | ||||
|         # Homeservers are permitted to cache the list of TLS fingerprints | ||||
|         # returned in the key responses up to the "valid_until_ts" returned in | ||||
|         # key. It may be necessary to publish the fingerprints of a new | ||||
|         # certificate and wait until the "valid_until_ts" of the previous key | ||||
|         # responses have passed before deploying it. | ||||
|         tls_fingerprints: [] | ||||
|         # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] | ||||
|         """ % locals() | ||||
| 
 | ||||
|     def read_tls_certificate(self, cert_path): | ||||
|  |  | |||
|  | @ -20,7 +20,6 @@ from synapse.api.constants import LoginType | |||
| from synapse.types import UserID | ||||
| from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.config.ldap import LDAPMode | ||||
| 
 | ||||
| from twisted.web.client import PartialDownloadError | ||||
| 
 | ||||
|  | @ -29,13 +28,6 @@ import bcrypt | |||
| import pymacaroons | ||||
| import simplejson | ||||
| 
 | ||||
| try: | ||||
|     import ldap3 | ||||
|     import ldap3.core.exceptions | ||||
| except ImportError: | ||||
|     ldap3 = None | ||||
|     pass | ||||
| 
 | ||||
| import synapse.util.stringutils as stringutils | ||||
| 
 | ||||
| 
 | ||||
|  | @ -60,21 +52,14 @@ class AuthHandler(BaseHandler): | |||
|         self.bcrypt_rounds = hs.config.bcrypt_rounds | ||||
|         self.sessions = {} | ||||
| 
 | ||||
|         self.ldap_enabled = hs.config.ldap_enabled | ||||
|         if self.ldap_enabled: | ||||
|             if not ldap3: | ||||
|                 raise RuntimeError( | ||||
|                     'Missing ldap3 library. This is required for LDAP Authentication.' | ||||
|                 ) | ||||
|             self.ldap_mode = hs.config.ldap_mode | ||||
|             self.ldap_uri = hs.config.ldap_uri | ||||
|             self.ldap_start_tls = hs.config.ldap_start_tls | ||||
|             self.ldap_base = hs.config.ldap_base | ||||
|             self.ldap_attributes = hs.config.ldap_attributes | ||||
|             if self.ldap_mode == LDAPMode.SEARCH: | ||||
|                 self.ldap_bind_dn = hs.config.ldap_bind_dn | ||||
|                 self.ldap_bind_password = hs.config.ldap_bind_password | ||||
|                 self.ldap_filter = hs.config.ldap_filter | ||||
|         account_handler = _AccountHandler( | ||||
|             hs, check_user_exists=self.check_user_exists | ||||
|         ) | ||||
| 
 | ||||
|         self.password_providers = [ | ||||
|             module(config=config, account_handler=account_handler) | ||||
|             for module, config in hs.config.password_providers | ||||
|         ] | ||||
| 
 | ||||
|         self.hs = hs  # FIXME better possibility to access registrationHandler later? | ||||
|         self.device_handler = hs.get_device_handler() | ||||
|  | @ -497,9 +482,10 @@ class AuthHandler(BaseHandler): | |||
|         Raises: | ||||
|             LoginError if login fails | ||||
|         """ | ||||
|         valid_ldap = yield self._check_ldap_password(user_id, password) | ||||
|         if valid_ldap: | ||||
|             defer.returnValue(user_id) | ||||
|         for provider in self.password_providers: | ||||
|             is_valid = yield provider.check_password(user_id, password) | ||||
|             if is_valid: | ||||
|                 defer.returnValue(user_id) | ||||
| 
 | ||||
|         canonical_user_id = yield self._check_local_password(user_id, password) | ||||
| 
 | ||||
|  | @ -536,275 +522,6 @@ class AuthHandler(BaseHandler): | |||
|             defer.returnValue(None) | ||||
|         defer.returnValue(user_id) | ||||
| 
 | ||||
|     def _ldap_simple_bind(self, server, localpart, password): | ||||
|         """ Attempt a simple bind with the credentials | ||||
|             given by the user against the LDAP server. | ||||
| 
 | ||||
|             Returns True, LDAP3Connection | ||||
|                 if the bind was successful | ||||
|             Returns False, None | ||||
|                 if an error occured | ||||
|         """ | ||||
| 
 | ||||
|         try: | ||||
|             # bind with the the local users ldap credentials | ||||
|             bind_dn = "{prop}={value},{base}".format( | ||||
|                 prop=self.ldap_attributes['uid'], | ||||
|                 value=localpart, | ||||
|                 base=self.ldap_base | ||||
|             ) | ||||
|             conn = ldap3.Connection(server, bind_dn, password) | ||||
|             logger.debug( | ||||
|                 "Established LDAP connection in simple bind mode: %s", | ||||
|                 conn | ||||
|             ) | ||||
| 
 | ||||
|             if self.ldap_start_tls: | ||||
|                 conn.start_tls() | ||||
|                 logger.debug( | ||||
|                     "Upgraded LDAP connection in simple bind mode through StartTLS: %s", | ||||
|                     conn | ||||
|                 ) | ||||
| 
 | ||||
|             if conn.bind(): | ||||
|                 # GOOD: bind okay | ||||
|                 logger.debug("LDAP Bind successful in simple bind mode.") | ||||
|                 return True, conn | ||||
| 
 | ||||
|             # BAD: bind failed | ||||
|             logger.info( | ||||
|                 "Binding against LDAP failed for '%s' failed: %s", | ||||
|                 localpart, conn.result['description'] | ||||
|             ) | ||||
|             conn.unbind() | ||||
|             return False, None | ||||
| 
 | ||||
|         except ldap3.core.exceptions.LDAPException as e: | ||||
|             logger.warn("Error during LDAP authentication: %s", e) | ||||
|             return False, None | ||||
| 
 | ||||
|     def _ldap_authenticated_search(self, server, localpart, password): | ||||
|         """ Attempt to login with the preconfigured bind_dn | ||||
|             and then continue searching and filtering within | ||||
|             the base_dn | ||||
| 
 | ||||
|             Returns (True, LDAP3Connection) | ||||
|                 if a single matching DN within the base was found | ||||
|                 that matched the filter expression, and with which | ||||
|                 a successful bind was achieved | ||||
| 
 | ||||
|                 The LDAP3Connection returned is the instance that was used to | ||||
|                 verify the password not the one using the configured bind_dn. | ||||
|             Returns (False, None) | ||||
|                 if an error occured | ||||
|         """ | ||||
| 
 | ||||
|         try: | ||||
|             conn = ldap3.Connection( | ||||
|                 server, | ||||
|                 self.ldap_bind_dn, | ||||
|                 self.ldap_bind_password | ||||
|             ) | ||||
|             logger.debug( | ||||
|                 "Established LDAP connection in search mode: %s", | ||||
|                 conn | ||||
|             ) | ||||
| 
 | ||||
|             if self.ldap_start_tls: | ||||
|                 conn.start_tls() | ||||
|                 logger.debug( | ||||
|                     "Upgraded LDAP connection in search mode through StartTLS: %s", | ||||
|                     conn | ||||
|                 ) | ||||
| 
 | ||||
|             if not conn.bind(): | ||||
|                 logger.warn( | ||||
|                     "Binding against LDAP with `bind_dn` failed: %s", | ||||
|                     conn.result['description'] | ||||
|                 ) | ||||
|                 conn.unbind() | ||||
|                 return False, None | ||||
| 
 | ||||
|             # construct search_filter like (uid=localpart) | ||||
|             query = "({prop}={value})".format( | ||||
|                 prop=self.ldap_attributes['uid'], | ||||
|                 value=localpart | ||||
|             ) | ||||
|             if self.ldap_filter: | ||||
|                 # combine with the AND expression | ||||
|                 query = "(&{query}{filter})".format( | ||||
|                     query=query, | ||||
|                     filter=self.ldap_filter | ||||
|                 ) | ||||
|             logger.debug( | ||||
|                 "LDAP search filter: %s", | ||||
|                 query | ||||
|             ) | ||||
|             conn.search( | ||||
|                 search_base=self.ldap_base, | ||||
|                 search_filter=query | ||||
|             ) | ||||
| 
 | ||||
|             if len(conn.response) == 1: | ||||
|                 # GOOD: found exactly one result | ||||
|                 user_dn = conn.response[0]['dn'] | ||||
|                 logger.debug('LDAP search found dn: %s', user_dn) | ||||
| 
 | ||||
|                 # unbind and simple bind with user_dn to verify the password | ||||
|                 # Note: do not use rebind(), for some reason it did not verify | ||||
|                 #       the password for me! | ||||
|                 conn.unbind() | ||||
|                 return self._ldap_simple_bind(server, localpart, password) | ||||
|             else: | ||||
|                 # BAD: found 0 or > 1 results, abort! | ||||
|                 if len(conn.response) == 0: | ||||
|                     logger.info( | ||||
|                         "LDAP search returned no results for '%s'", | ||||
|                         localpart | ||||
|                     ) | ||||
|                 else: | ||||
|                     logger.info( | ||||
|                         "LDAP search returned too many (%s) results for '%s'", | ||||
|                         len(conn.response), localpart | ||||
|                     ) | ||||
|                 conn.unbind() | ||||
|                 return False, None | ||||
| 
 | ||||
|         except ldap3.core.exceptions.LDAPException as e: | ||||
|             logger.warn("Error during LDAP authentication: %s", e) | ||||
|             return False, None | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _check_ldap_password(self, user_id, password): | ||||
|         """ Attempt to authenticate a user against an LDAP Server | ||||
|             and register an account if none exists. | ||||
| 
 | ||||
|             Returns: | ||||
|                 True if authentication against LDAP was successful | ||||
|         """ | ||||
| 
 | ||||
|         if not ldap3 or not self.ldap_enabled: | ||||
|             defer.returnValue(False) | ||||
| 
 | ||||
|         localpart = UserID.from_string(user_id).localpart | ||||
| 
 | ||||
|         try: | ||||
|             server = ldap3.Server(self.ldap_uri) | ||||
|             logger.debug( | ||||
|                 "Attempting LDAP connection with %s", | ||||
|                 self.ldap_uri | ||||
|             ) | ||||
| 
 | ||||
|             if self.ldap_mode == LDAPMode.SIMPLE: | ||||
|                 result, conn = self._ldap_simple_bind( | ||||
|                     server=server, localpart=localpart, password=password | ||||
|                 ) | ||||
|                 logger.debug( | ||||
|                     'LDAP authentication method simple bind returned: %s (conn: %s)', | ||||
|                     result, | ||||
|                     conn | ||||
|                 ) | ||||
|                 if not result: | ||||
|                     defer.returnValue(False) | ||||
|             elif self.ldap_mode == LDAPMode.SEARCH: | ||||
|                 result, conn = self._ldap_authenticated_search( | ||||
|                     server=server, localpart=localpart, password=password | ||||
|                 ) | ||||
|                 logger.debug( | ||||
|                     'LDAP auth method authenticated search returned: %s (conn: %s)', | ||||
|                     result, | ||||
|                     conn | ||||
|                 ) | ||||
|                 if not result: | ||||
|                     defer.returnValue(False) | ||||
|             else: | ||||
|                 raise RuntimeError( | ||||
|                     'Invalid LDAP mode specified: {mode}'.format( | ||||
|                         mode=self.ldap_mode | ||||
|                     ) | ||||
|                 ) | ||||
| 
 | ||||
|             try: | ||||
|                 logger.info( | ||||
|                     "User authenticated against LDAP server: %s", | ||||
|                     conn | ||||
|                 ) | ||||
|             except NameError: | ||||
|                 logger.warn("Authentication method yielded no LDAP connection, aborting!") | ||||
|                 defer.returnValue(False) | ||||
| 
 | ||||
|             # check if user with user_id exists | ||||
|             if (yield self.check_user_exists(user_id)): | ||||
|                 # exists, authentication complete | ||||
|                 conn.unbind() | ||||
|                 defer.returnValue(True) | ||||
| 
 | ||||
|             else: | ||||
|                 # does not exist, fetch metadata for account creation from | ||||
|                 # existing ldap connection | ||||
|                 query = "({prop}={value})".format( | ||||
|                     prop=self.ldap_attributes['uid'], | ||||
|                     value=localpart | ||||
|                 ) | ||||
| 
 | ||||
|                 if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: | ||||
|                     query = "(&{filter}{user_filter})".format( | ||||
|                         filter=query, | ||||
|                         user_filter=self.ldap_filter | ||||
|                     ) | ||||
|                 logger.debug( | ||||
|                     "ldap registration filter: %s", | ||||
|                     query | ||||
|                 ) | ||||
| 
 | ||||
|                 conn.search( | ||||
|                     search_base=self.ldap_base, | ||||
|                     search_filter=query, | ||||
|                     attributes=[ | ||||
|                         self.ldap_attributes['name'], | ||||
|                         self.ldap_attributes['mail'] | ||||
|                     ] | ||||
|                 ) | ||||
| 
 | ||||
|                 if len(conn.response) == 1: | ||||
|                     attrs = conn.response[0]['attributes'] | ||||
|                     mail = attrs[self.ldap_attributes['mail']][0] | ||||
|                     name = attrs[self.ldap_attributes['name']][0] | ||||
| 
 | ||||
|                     # create account | ||||
|                     registration_handler = self.hs.get_handlers().registration_handler | ||||
|                     user_id, access_token = ( | ||||
|                         yield registration_handler.register(localpart=localpart) | ||||
|                     ) | ||||
| 
 | ||||
|                     # TODO: bind email, set displayname with data from ldap directory | ||||
| 
 | ||||
|                     logger.info( | ||||
|                         "Registration based on LDAP data was successful: %d: %s (%s, %)", | ||||
|                         user_id, | ||||
|                         localpart, | ||||
|                         name, | ||||
|                         mail | ||||
|                     ) | ||||
| 
 | ||||
|                     defer.returnValue(True) | ||||
|                 else: | ||||
|                     if len(conn.response) == 0: | ||||
|                         logger.warn("LDAP registration failed, no result.") | ||||
|                     else: | ||||
|                         logger.warn( | ||||
|                             "LDAP registration failed, too many results (%s)", | ||||
|                             len(conn.response) | ||||
|                         ) | ||||
| 
 | ||||
|                     defer.returnValue(False) | ||||
| 
 | ||||
|             defer.returnValue(False) | ||||
| 
 | ||||
|         except ldap3.core.exceptions.LDAPException as e: | ||||
|             logger.warn("Error during ldap authentication: %s", e) | ||||
|             defer.returnValue(False) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def issue_access_token(self, user_id, device_id=None): | ||||
|         access_token = self.generate_access_token(user_id) | ||||
|  | @ -942,3 +659,30 @@ class AuthHandler(BaseHandler): | |||
|                                  stored_hash.encode('utf-8')) == stored_hash | ||||
|         else: | ||||
|             return False | ||||
| 
 | ||||
| 
 | ||||
| class _AccountHandler(object): | ||||
|     """A proxy object that gets passed to password auth providers so they | ||||
|     can register new users etc if necessary. | ||||
|     """ | ||||
|     def __init__(self, hs, check_user_exists): | ||||
|         self.hs = hs | ||||
| 
 | ||||
|         self._check_user_exists = check_user_exists | ||||
| 
 | ||||
|     def check_user_exists(self, user_id): | ||||
|         """Check if user exissts. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred(bool) | ||||
|         """ | ||||
|         return self._check_user_exists(user_id) | ||||
| 
 | ||||
|     def register(self, localpart): | ||||
|         """Registers a new user with given localpart | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: a 2-tuple of (user_id, access_token) | ||||
|         """ | ||||
|         reg = self.hs.get_handlers().registration_handler | ||||
|         return reg.register(localpart=localpart) | ||||
|  |  | |||
|  | @ -328,7 +328,7 @@ class Mailer(object): | |||
|         return messagevars | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def make_summary_text(self, notifs_by_room, state_by_room, | ||||
|     def make_summary_text(self, notifs_by_room, room_state_ids, | ||||
|                           notif_events, user_id, reason): | ||||
|         if len(notifs_by_room) == 1: | ||||
|             # Only one room has new stuff | ||||
|  | @ -338,14 +338,18 @@ class Mailer(object): | |||
|             # want the generated-from-names one here otherwise we'll | ||||
|             # end up with, "new message from Bob in the Bob room" | ||||
|             room_name = yield calculate_room_name( | ||||
|                 self.store, state_by_room[room_id], user_id, fallback_to_members=False | ||||
|                 self.store, room_state_ids[room_id], user_id, fallback_to_members=False | ||||
|             ) | ||||
| 
 | ||||
|             my_member_event = state_by_room[room_id][("m.room.member", user_id)] | ||||
|             my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)] | ||||
|             my_member_event = yield self.store.get_event(my_member_event_id) | ||||
|             if my_member_event.content["membership"] == "invite": | ||||
|                 inviter_member_event = state_by_room[room_id][ | ||||
|                 inviter_member_event_id = room_state_ids[room_id][ | ||||
|                     ("m.room.member", my_member_event.sender) | ||||
|                 ] | ||||
|                 inviter_member_event = yield self.store.get_event( | ||||
|                     inviter_member_event_id | ||||
|                 ) | ||||
|                 inviter_name = name_from_member_event(inviter_member_event) | ||||
| 
 | ||||
|                 if room_name is None: | ||||
|  | @ -364,8 +368,11 @@ class Mailer(object): | |||
|             if len(notifs_by_room[room_id]) == 1: | ||||
|                 # There is just the one notification, so give some detail | ||||
|                 event = notif_events[notifs_by_room[room_id][0]["event_id"]] | ||||
|                 if ("m.room.member", event.sender) in state_by_room[room_id]: | ||||
|                     state_event = state_by_room[room_id][("m.room.member", event.sender)] | ||||
|                 if ("m.room.member", event.sender) in room_state_ids[room_id]: | ||||
|                     state_event_id = room_state_ids[room_id][ | ||||
|                         ("m.room.member", event.sender) | ||||
|                     ] | ||||
|                     state_event = yield self.get_event(state_event_id) | ||||
|                     sender_name = name_from_member_event(state_event) | ||||
| 
 | ||||
|                 if sender_name is not None and room_name is not None: | ||||
|  | @ -395,11 +402,13 @@ class Mailer(object): | |||
|                         for n in notifs_by_room[room_id] | ||||
|                     ])) | ||||
| 
 | ||||
|                     member_events = yield self.store.get_events([ | ||||
|                         room_state_ids[room_id][("m.room.member", s)] | ||||
|                         for s in sender_ids | ||||
|                     ]) | ||||
| 
 | ||||
|                     defer.returnValue(MESSAGES_FROM_PERSON % { | ||||
|                         "person": descriptor_from_member_events([ | ||||
|                             state_by_room[room_id][("m.room.member", s)] | ||||
|                             for s in sender_ids | ||||
|                         ]), | ||||
|                         "person": descriptor_from_member_events(member_events.values()), | ||||
|                         "app": self.app_name, | ||||
|                     }) | ||||
|         else: | ||||
|  | @ -419,11 +428,13 @@ class Mailer(object): | |||
|                     for n in notifs_by_room[reason['room_id']] | ||||
|                 ])) | ||||
| 
 | ||||
|                 member_events = yield self.store.get_events([ | ||||
|                     room_state_ids[room_id][("m.room.member", s)] | ||||
|                     for s in sender_ids | ||||
|                 ]) | ||||
| 
 | ||||
|                 defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % { | ||||
|                     "person": descriptor_from_member_events([ | ||||
|                         state_by_room[reason['room_id']][("m.room.member", s)] | ||||
|                         for s in sender_ids | ||||
|                     ]), | ||||
|                     "person": descriptor_from_member_events(member_events.values()), | ||||
|                     "app": self.app_name, | ||||
|                 }) | ||||
| 
 | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string | |||
| from synapse.http.server import request_handler, finish_request | ||||
| from synapse.replication.pusher_resource import PusherResource | ||||
| from synapse.replication.presence_resource import PresenceResource | ||||
| from synapse.api.errors import SynapseError | ||||
| 
 | ||||
| from twisted.web.resource import Resource | ||||
| from twisted.web.server import NOT_DONE_YET | ||||
|  | @ -166,7 +167,8 @@ class ReplicationResource(Resource): | |||
|         def replicate(): | ||||
|             return self.replicate(request_streams, limit) | ||||
| 
 | ||||
|         result = yield self.notifier.wait_for_replication(replicate, timeout) | ||||
|         writer = yield self.notifier.wait_for_replication(replicate, timeout) | ||||
|         result = writer.finish() | ||||
| 
 | ||||
|         for stream_name, stream_content in result.items(): | ||||
|             logger.info( | ||||
|  | @ -186,6 +188,9 @@ class ReplicationResource(Resource): | |||
|         current_token = yield self.current_replication_token() | ||||
|         logger.debug("Replicating up to %r", current_token) | ||||
| 
 | ||||
|         if limit == 0: | ||||
|             raise SynapseError(400, "Limit cannot be 0") | ||||
| 
 | ||||
|         yield self.account_data(writer, current_token, limit, request_streams) | ||||
|         yield self.events(writer, current_token, limit, request_streams) | ||||
|         # TODO: implement limit | ||||
|  | @ -200,7 +205,7 @@ class ReplicationResource(Resource): | |||
|         self.streams(writer, current_token, request_streams) | ||||
| 
 | ||||
|         logger.debug("Replicated %d rows", writer.total) | ||||
|         defer.returnValue(writer.finish()) | ||||
|         defer.returnValue(writer) | ||||
| 
 | ||||
|     def streams(self, writer, current_token, request_streams): | ||||
|         request_token = request_streams.get("streams") | ||||
|  | @ -237,27 +242,48 @@ class ReplicationResource(Resource): | |||
|                 request_events = current_token.events | ||||
|             if request_backfill is None: | ||||
|                 request_backfill = current_token.backfill | ||||
| 
 | ||||
|             no_new_tokens = ( | ||||
|                 request_events == current_token.events | ||||
|                 and request_backfill == current_token.backfill | ||||
|             ) | ||||
|             if no_new_tokens: | ||||
|                 return | ||||
| 
 | ||||
|             res = yield self.store.get_all_new_events( | ||||
|                 request_backfill, request_events, | ||||
|                 current_token.backfill, current_token.events, | ||||
|                 limit | ||||
|             ) | ||||
|             writer.write_header_and_rows("events", res.new_forward_events, ( | ||||
|                 "position", "internal", "json", "state_group" | ||||
|             )) | ||||
|             writer.write_header_and_rows("backfill", res.new_backfill_events, ( | ||||
|                 "position", "internal", "json", "state_group" | ||||
|             )) | ||||
| 
 | ||||
|             upto_events_token = _position_from_rows( | ||||
|                 res.new_forward_events, current_token.events | ||||
|             ) | ||||
| 
 | ||||
|             upto_backfill_token = _position_from_rows( | ||||
|                 res.new_backfill_events, current_token.backfill | ||||
|             ) | ||||
| 
 | ||||
|             if request_events != upto_events_token: | ||||
|                 writer.write_header_and_rows("events", res.new_forward_events, ( | ||||
|                     "position", "internal", "json", "state_group" | ||||
|                 ), position=upto_events_token) | ||||
| 
 | ||||
|             if request_backfill != upto_backfill_token: | ||||
|                 writer.write_header_and_rows("backfill", res.new_backfill_events, ( | ||||
|                     "position", "internal", "json", "state_group", | ||||
|                 ), position=upto_backfill_token) | ||||
| 
 | ||||
|             writer.write_header_and_rows( | ||||
|                 "forward_ex_outliers", res.forward_ex_outliers, | ||||
|                 ("position", "event_id", "state_group") | ||||
|                 ("position", "event_id", "state_group"), | ||||
|             ) | ||||
|             writer.write_header_and_rows( | ||||
|                 "backward_ex_outliers", res.backward_ex_outliers, | ||||
|                 ("position", "event_id", "state_group") | ||||
|                 ("position", "event_id", "state_group"), | ||||
|             ) | ||||
|             writer.write_header_and_rows( | ||||
|                 "state_resets", res.state_resets, ("position",) | ||||
|                 "state_resets", res.state_resets, ("position",), | ||||
|             ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | @ -266,15 +292,16 @@ class ReplicationResource(Resource): | |||
| 
 | ||||
|         request_presence = request_streams.get("presence") | ||||
| 
 | ||||
|         if request_presence is not None: | ||||
|         if request_presence is not None and request_presence != current_position: | ||||
|             presence_rows = yield self.presence_handler.get_all_presence_updates( | ||||
|                 request_presence, current_position | ||||
|             ) | ||||
|             upto_token = _position_from_rows(presence_rows, current_position) | ||||
|             writer.write_header_and_rows("presence", presence_rows, ( | ||||
|                 "position", "user_id", "state", "last_active_ts", | ||||
|                 "last_federation_update_ts", "last_user_sync_ts", | ||||
|                 "status_msg", "currently_active", | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def typing(self, writer, current_token, request_streams): | ||||
|  | @ -282,7 +309,7 @@ class ReplicationResource(Resource): | |||
| 
 | ||||
|         request_typing = request_streams.get("typing") | ||||
| 
 | ||||
|         if request_typing is not None: | ||||
|         if request_typing is not None and request_typing != current_position: | ||||
|             # If they have a higher token than current max, we can assume that | ||||
|             # they had been talking to a previous instance of the master. Since | ||||
|             # we reset the token on restart, the best (but hacky) thing we can | ||||
|  | @ -293,9 +320,10 @@ class ReplicationResource(Resource): | |||
|             typing_rows = yield self.typing_handler.get_all_typing_updates( | ||||
|                 request_typing, current_position | ||||
|             ) | ||||
|             upto_token = _position_from_rows(typing_rows, current_position) | ||||
|             writer.write_header_and_rows("typing", typing_rows, ( | ||||
|                 "position", "room_id", "typing" | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def receipts(self, writer, current_token, limit, request_streams): | ||||
|  | @ -303,13 +331,14 @@ class ReplicationResource(Resource): | |||
| 
 | ||||
|         request_receipts = request_streams.get("receipts") | ||||
| 
 | ||||
|         if request_receipts is not None: | ||||
|         if request_receipts is not None and request_receipts != current_position: | ||||
|             receipts_rows = yield self.store.get_all_updated_receipts( | ||||
|                 request_receipts, current_position, limit | ||||
|             ) | ||||
|             upto_token = _position_from_rows(receipts_rows, current_position) | ||||
|             writer.write_header_and_rows("receipts", receipts_rows, ( | ||||
|                 "position", "room_id", "receipt_type", "user_id", "event_id", "data" | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def account_data(self, writer, current_token, limit, request_streams): | ||||
|  | @ -324,23 +353,36 @@ class ReplicationResource(Resource): | |||
|                 user_account_data = current_position | ||||
|             if room_account_data is None: | ||||
|                 room_account_data = current_position | ||||
| 
 | ||||
|             no_new_tokens = ( | ||||
|                 user_account_data == current_position | ||||
|                 and room_account_data == current_position | ||||
|             ) | ||||
|             if no_new_tokens: | ||||
|                 return | ||||
| 
 | ||||
|             user_rows, room_rows = yield self.store.get_all_updated_account_data( | ||||
|                 user_account_data, room_account_data, current_position, limit | ||||
|             ) | ||||
| 
 | ||||
|             upto_users_token = _position_from_rows(user_rows, current_position) | ||||
|             upto_rooms_token = _position_from_rows(room_rows, current_position) | ||||
| 
 | ||||
|             writer.write_header_and_rows("user_account_data", user_rows, ( | ||||
|                 "position", "user_id", "type", "content" | ||||
|             )) | ||||
|             ), position=upto_users_token) | ||||
|             writer.write_header_and_rows("room_account_data", room_rows, ( | ||||
|                 "position", "user_id", "room_id", "type", "content" | ||||
|             )) | ||||
|             ), position=upto_rooms_token) | ||||
| 
 | ||||
|         if tag_account_data is not None: | ||||
|             tag_rows = yield self.store.get_all_updated_tags( | ||||
|                 tag_account_data, current_position, limit | ||||
|             ) | ||||
|             upto_tag_token = _position_from_rows(tag_rows, current_position) | ||||
|             writer.write_header_and_rows("tag_account_data", tag_rows, ( | ||||
|                 "position", "user_id", "room_id", "tags" | ||||
|             )) | ||||
|             ), position=upto_tag_token) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def push_rules(self, writer, current_token, limit, request_streams): | ||||
|  | @ -348,14 +390,15 @@ class ReplicationResource(Resource): | |||
| 
 | ||||
|         push_rules = request_streams.get("push_rules") | ||||
| 
 | ||||
|         if push_rules is not None: | ||||
|         if push_rules is not None and push_rules != current_position: | ||||
|             rows = yield self.store.get_all_push_rule_updates( | ||||
|                 push_rules, current_position, limit | ||||
|             ) | ||||
|             upto_token = _position_from_rows(rows, current_position) | ||||
|             writer.write_header_and_rows("push_rules", rows, ( | ||||
|                 "position", "event_stream_ordering", "user_id", "rule_id", "op", | ||||
|                 "priority_class", "priority", "conditions", "actions" | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def pushers(self, writer, current_token, limit, request_streams): | ||||
|  | @ -363,18 +406,19 @@ class ReplicationResource(Resource): | |||
| 
 | ||||
|         pushers = request_streams.get("pushers") | ||||
| 
 | ||||
|         if pushers is not None: | ||||
|         if pushers is not None and pushers != current_position: | ||||
|             updated, deleted = yield self.store.get_all_updated_pushers( | ||||
|                 pushers, current_position, limit | ||||
|             ) | ||||
|             upto_token = _position_from_rows(updated, current_position) | ||||
|             writer.write_header_and_rows("pushers", updated, ( | ||||
|                 "position", "user_id", "access_token", "profile_tag", "kind", | ||||
|                 "app_id", "app_display_name", "device_display_name", "pushkey", | ||||
|                 "ts", "lang", "data" | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
|             writer.write_header_and_rows("deleted_pushers", deleted, ( | ||||
|                 "position", "user_id", "app_id", "pushkey" | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def caches(self, writer, current_token, limit, request_streams): | ||||
|  | @ -382,13 +426,14 @@ class ReplicationResource(Resource): | |||
| 
 | ||||
|         caches = request_streams.get("caches") | ||||
| 
 | ||||
|         if caches is not None: | ||||
|         if caches is not None and caches != current_position: | ||||
|             updated_caches = yield self.store.get_all_updated_caches( | ||||
|                 caches, current_position, limit | ||||
|             ) | ||||
|             upto_token = _position_from_rows(updated_caches, current_position) | ||||
|             writer.write_header_and_rows("caches", updated_caches, ( | ||||
|                 "position", "cache_func", "keys", "invalidation_ts" | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def to_device(self, writer, current_token, limit, request_streams): | ||||
|  | @ -396,13 +441,14 @@ class ReplicationResource(Resource): | |||
| 
 | ||||
|         to_device = request_streams.get("to_device") | ||||
| 
 | ||||
|         if to_device is not None: | ||||
|         if to_device is not None and to_device != current_position: | ||||
|             to_device_rows = yield self.store.get_all_new_device_messages( | ||||
|                 to_device, current_position, limit | ||||
|             ) | ||||
|             upto_token = _position_from_rows(to_device_rows, current_position) | ||||
|             writer.write_header_and_rows("to_device", to_device_rows, ( | ||||
|                 "position", "user_id", "device_id", "message_json" | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def public_rooms(self, writer, current_token, limit, request_streams): | ||||
|  | @ -410,13 +456,14 @@ class ReplicationResource(Resource): | |||
| 
 | ||||
|         public_rooms = request_streams.get("public_rooms") | ||||
| 
 | ||||
|         if public_rooms is not None: | ||||
|         if public_rooms is not None and public_rooms != current_position: | ||||
|             public_rooms_rows = yield self.store.get_all_new_public_rooms( | ||||
|                 public_rooms, current_position, limit | ||||
|             ) | ||||
|             upto_token = _position_from_rows(public_rooms_rows, current_position) | ||||
|             writer.write_header_and_rows("public_rooms", public_rooms_rows, ( | ||||
|                 "position", "room_id", "visibility" | ||||
|             )) | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
| 
 | ||||
| class _Writer(object): | ||||
|  | @ -426,11 +473,11 @@ class _Writer(object): | |||
|         self.total = 0 | ||||
| 
 | ||||
|     def write_header_and_rows(self, name, rows, fields, position=None): | ||||
|         if not rows: | ||||
|             return | ||||
| 
 | ||||
|         if position is None: | ||||
|             position = rows[-1][0] | ||||
|             if rows: | ||||
|                 position = rows[-1][0] | ||||
|             else: | ||||
|                 return | ||||
| 
 | ||||
|         self.streams[name] = { | ||||
|             "position": position if type(position) is int else str(position), | ||||
|  | @ -440,6 +487,9 @@ class _Writer(object): | |||
| 
 | ||||
|         self.total += len(rows) | ||||
| 
 | ||||
|     def __nonzero__(self): | ||||
|         return bool(self.total) | ||||
| 
 | ||||
|     def finish(self): | ||||
|         return self.streams | ||||
| 
 | ||||
|  | @ -461,3 +511,20 @@ class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( | |||
| 
 | ||||
|     def __str__(self): | ||||
|         return "_".join(str(value) for value in self) | ||||
| 
 | ||||
| 
 | ||||
| def _position_from_rows(rows, current_position): | ||||
|     """Calculates a position to return for a stream. Ideally we want to return the | ||||
|     position of the last row, as that will be the most correct. However, if there | ||||
|     are no rows we fall back to using the current position to stop us from | ||||
|     repeatedly hitting the storage layer unncessarily thinking there are updates. | ||||
|     (Not all advances of the token correspond to an actual update) | ||||
| 
 | ||||
|     We can't just always return the current position, as we often limit the | ||||
|     number of rows we replicate, and so the stream may lag. The assumption is | ||||
|     that if the storage layer returns no new rows then we are not lagging and | ||||
|     we are at the `current_position`. | ||||
|     """ | ||||
|     if rows: | ||||
|         return rows[-1][0] | ||||
|     return current_position | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ import logging | |||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api import constants, errors | ||||
| from synapse.http import servlet | ||||
| from ._base import client_v2_patterns | ||||
| 
 | ||||
|  | @ -58,6 +59,7 @@ class DeviceRestServlet(servlet.RestServlet): | |||
|         self.hs = hs | ||||
|         self.auth = hs.get_auth() | ||||
|         self.device_handler = hs.get_device_handler() | ||||
|         self.auth_handler = hs.get_auth_handler() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, device_id): | ||||
|  | @ -70,11 +72,24 @@ class DeviceRestServlet(servlet.RestServlet): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_DELETE(self, request, device_id): | ||||
|         # XXX: it's not completely obvious we want to expose this endpoint. | ||||
|         # It allows the client to delete access tokens, which feels like a | ||||
|         # thing which merits extra auth. But if we want to do the interactive- | ||||
|         # auth dance, we should really make it possible to delete more than one | ||||
|         # device at a time. | ||||
|         try: | ||||
|             body = servlet.parse_json_object_from_request(request) | ||||
| 
 | ||||
|         except errors.SynapseError as e: | ||||
|             if e.errcode == errors.Codes.NOT_JSON: | ||||
|                 # deal with older clients which didn't pass a JSON dict | ||||
|                 # the same as those that pass an empty dict | ||||
|                 body = {} | ||||
|             else: | ||||
|                 raise | ||||
| 
 | ||||
|         authed, result, params, _ = yield self.auth_handler.check_auth([ | ||||
|             [constants.LoginType.PASSWORD], | ||||
|         ], body, self.hs.get_ip_from_request(request)) | ||||
| 
 | ||||
|         if not authed: | ||||
|             defer.returnValue((401, result)) | ||||
| 
 | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         yield self.device_handler.delete_device( | ||||
|             requester.user.to_string(), | ||||
|  |  | |||
|  | @ -19,8 +19,6 @@ from synapse.http.server import respond_with_json_bytes | |||
| from signedjson.sign import sign_json | ||||
| from unpaddedbase64 import encode_base64 | ||||
| from canonicaljson import encode_canonical_json | ||||
| from hashlib import sha256 | ||||
| from OpenSSL import crypto | ||||
| import logging | ||||
| 
 | ||||
| 
 | ||||
|  | @ -48,8 +46,12 @@ class LocalKey(Resource): | |||
|                     "expired_ts": # integer posix timestamp when the key expired. | ||||
|                     "key": # base64 encoded NACL verification key. | ||||
|                 } | ||||
|             } | ||||
|             "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert. | ||||
|             }, | ||||
|             "tls_fingerprints": [ # Fingerprints of the TLS certs this server uses. | ||||
|                 { | ||||
|                     "sha256": # base64 encoded sha256 fingerprint of the X509 cert | ||||
|                 }, | ||||
|             ], | ||||
|             "signatures": { | ||||
|                 "this.server.example.com": { | ||||
|                    "algorithm:version": # NACL signature for this server | ||||
|  | @ -90,21 +92,14 @@ class LocalKey(Resource): | |||
|                 u"expired_ts": key.expired, | ||||
|             } | ||||
| 
 | ||||
|         x509_certificate_bytes = crypto.dump_certificate( | ||||
|             crypto.FILETYPE_ASN1, | ||||
|             self.config.tls_certificate | ||||
|         ) | ||||
| 
 | ||||
|         sha256_fingerprint = sha256(x509_certificate_bytes).digest() | ||||
|         tls_fingerprints = self.config.tls_fingerprints | ||||
| 
 | ||||
|         json_object = { | ||||
|             u"valid_until_ts": self.valid_until_ts, | ||||
|             u"server_name": self.config.server_name, | ||||
|             u"verify_keys": verify_keys, | ||||
|             u"old_verify_keys": old_verify_keys, | ||||
|             u"tls_fingerprints": [{ | ||||
|                 u"sha256": encode_base64(sha256_fingerprint), | ||||
|             }] | ||||
|             u"tls_fingerprints": tls_fingerprints, | ||||
|         } | ||||
|         for key in self.config.signing_key: | ||||
|             json_object = sign_json( | ||||
|  |  | |||
|  | @ -320,6 +320,9 @@ class RoomStore(SQLBaseStore): | |||
|             txn.execute(sql, (prev_id, current_id, limit,)) | ||||
|             return txn.fetchall() | ||||
| 
 | ||||
|         if prev_id == current_id: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "get_all_new_public_rooms", get_all_new_public_rooms | ||||
|         ) | ||||
|  |  | |||
|  | @ -0,0 +1,368 @@ | |||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.config._base import ConfigError | ||||
| from synapse.types import UserID | ||||
| 
 | ||||
| import ldap3 | ||||
| import ldap3.core.exceptions | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| try: | ||||
|     import ldap3 | ||||
|     import ldap3.core.exceptions | ||||
| except ImportError: | ||||
|     ldap3 = None | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class LDAPMode(object): | ||||
|     SIMPLE = "simple", | ||||
|     SEARCH = "search", | ||||
| 
 | ||||
|     LIST = (SIMPLE, SEARCH) | ||||
| 
 | ||||
| 
 | ||||
| class LdapAuthProvider(object): | ||||
|     __version__ = "0.1" | ||||
| 
 | ||||
|     def __init__(self, config, account_handler): | ||||
|         self.account_handler = account_handler | ||||
| 
 | ||||
|         if not ldap3: | ||||
|             raise RuntimeError( | ||||
|                 'Missing ldap3 library. This is required for LDAP Authentication.' | ||||
|             ) | ||||
| 
 | ||||
|         self.ldap_mode = config.mode | ||||
|         self.ldap_uri = config.uri | ||||
|         self.ldap_start_tls = config.start_tls | ||||
|         self.ldap_base = config.base | ||||
|         self.ldap_attributes = config.attributes | ||||
|         if self.ldap_mode == LDAPMode.SEARCH: | ||||
|             self.ldap_bind_dn = config.bind_dn | ||||
|             self.ldap_bind_password = config.bind_password | ||||
|             self.ldap_filter = config.filter | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def check_password(self, user_id, password): | ||||
|         """ Attempt to authenticate a user against an LDAP Server | ||||
|             and register an account if none exists. | ||||
| 
 | ||||
|             Returns: | ||||
|                 True if authentication against LDAP was successful | ||||
|         """ | ||||
|         localpart = UserID.from_string(user_id).localpart | ||||
| 
 | ||||
|         try: | ||||
|             server = ldap3.Server(self.ldap_uri) | ||||
|             logger.debug( | ||||
|                 "Attempting LDAP connection with %s", | ||||
|                 self.ldap_uri | ||||
|             ) | ||||
| 
 | ||||
|             if self.ldap_mode == LDAPMode.SIMPLE: | ||||
|                 result, conn = self._ldap_simple_bind( | ||||
|                     server=server, localpart=localpart, password=password | ||||
|                 ) | ||||
|                 logger.debug( | ||||
|                     'LDAP authentication method simple bind returned: %s (conn: %s)', | ||||
|                     result, | ||||
|                     conn | ||||
|                 ) | ||||
|                 if not result: | ||||
|                     defer.returnValue(False) | ||||
|             elif self.ldap_mode == LDAPMode.SEARCH: | ||||
|                 result, conn = self._ldap_authenticated_search( | ||||
|                     server=server, localpart=localpart, password=password | ||||
|                 ) | ||||
|                 logger.debug( | ||||
|                     'LDAP auth method authenticated search returned: %s (conn: %s)', | ||||
|                     result, | ||||
|                     conn | ||||
|                 ) | ||||
|                 if not result: | ||||
|                     defer.returnValue(False) | ||||
|             else: | ||||
|                 raise RuntimeError( | ||||
|                     'Invalid LDAP mode specified: {mode}'.format( | ||||
|                         mode=self.ldap_mode | ||||
|                     ) | ||||
|                 ) | ||||
| 
 | ||||
|             try: | ||||
|                 logger.info( | ||||
|                     "User authenticated against LDAP server: %s", | ||||
|                     conn | ||||
|                 ) | ||||
|             except NameError: | ||||
|                 logger.warn( | ||||
|                     "Authentication method yielded no LDAP connection, aborting!" | ||||
|                 ) | ||||
|                 defer.returnValue(False) | ||||
| 
 | ||||
|             # check if user with user_id exists | ||||
|             if (yield self.account_handler.check_user_exists(user_id)): | ||||
|                 # exists, authentication complete | ||||
|                 conn.unbind() | ||||
|                 defer.returnValue(True) | ||||
| 
 | ||||
|             else: | ||||
|                 # does not exist, fetch metadata for account creation from | ||||
|                 # existing ldap connection | ||||
|                 query = "({prop}={value})".format( | ||||
|                     prop=self.ldap_attributes['uid'], | ||||
|                     value=localpart | ||||
|                 ) | ||||
| 
 | ||||
|                 if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: | ||||
|                     query = "(&{filter}{user_filter})".format( | ||||
|                         filter=query, | ||||
|                         user_filter=self.ldap_filter | ||||
|                     ) | ||||
|                 logger.debug( | ||||
|                     "ldap registration filter: %s", | ||||
|                     query | ||||
|                 ) | ||||
| 
 | ||||
|                 conn.search( | ||||
|                     search_base=self.ldap_base, | ||||
|                     search_filter=query, | ||||
|                     attributes=[ | ||||
|                         self.ldap_attributes['name'], | ||||
|                         self.ldap_attributes['mail'] | ||||
|                     ] | ||||
|                 ) | ||||
| 
 | ||||
|                 if len(conn.response) == 1: | ||||
|                     attrs = conn.response[0]['attributes'] | ||||
|                     mail = attrs[self.ldap_attributes['mail']][0] | ||||
|                     name = attrs[self.ldap_attributes['name']][0] | ||||
| 
 | ||||
|                     # create account | ||||
|                     user_id, access_token = ( | ||||
|                         yield self.account_handler.register(localpart=localpart) | ||||
|                     ) | ||||
| 
 | ||||
|                     # TODO: bind email, set displayname with data from ldap directory | ||||
| 
 | ||||
|                     logger.info( | ||||
|                         "Registration based on LDAP data was successful: %d: %s (%s, %)", | ||||
|                         user_id, | ||||
|                         localpart, | ||||
|                         name, | ||||
|                         mail | ||||
|                     ) | ||||
| 
 | ||||
|                     defer.returnValue(True) | ||||
|                 else: | ||||
|                     if len(conn.response) == 0: | ||||
|                         logger.warn("LDAP registration failed, no result.") | ||||
|                     else: | ||||
|                         logger.warn( | ||||
|                             "LDAP registration failed, too many results (%s)", | ||||
|                             len(conn.response) | ||||
|                         ) | ||||
| 
 | ||||
|                     defer.returnValue(False) | ||||
| 
 | ||||
|             defer.returnValue(False) | ||||
| 
 | ||||
|         except ldap3.core.exceptions.LDAPException as e: | ||||
|             logger.warn("Error during ldap authentication: %s", e) | ||||
|             defer.returnValue(False) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def parse_config(config): | ||||
|         class _LdapConfig(object): | ||||
|             pass | ||||
| 
 | ||||
|         ldap_config = _LdapConfig() | ||||
| 
 | ||||
|         ldap_config.enabled = config.get("enabled", False) | ||||
| 
 | ||||
|         ldap_config.mode = LDAPMode.SIMPLE | ||||
| 
 | ||||
|         # verify config sanity | ||||
|         _require_keys(config, [ | ||||
|             "uri", | ||||
|             "base", | ||||
|             "attributes", | ||||
|         ]) | ||||
| 
 | ||||
|         ldap_config.uri = config["uri"] | ||||
|         ldap_config.start_tls = config.get("start_tls", False) | ||||
|         ldap_config.base = config["base"] | ||||
|         ldap_config.attributes = config["attributes"] | ||||
| 
 | ||||
|         if "bind_dn" in config: | ||||
|             ldap_config.mode = LDAPMode.SEARCH | ||||
|             _require_keys(config, [ | ||||
|                 "bind_dn", | ||||
|                 "bind_password", | ||||
|             ]) | ||||
| 
 | ||||
|             ldap_config.bind_dn = config["bind_dn"] | ||||
|             ldap_config.bind_password = config["bind_password"] | ||||
|             ldap_config.filter = config.get("filter", None) | ||||
| 
 | ||||
|         # verify attribute lookup | ||||
|         _require_keys(config['attributes'], [ | ||||
|             "uid", | ||||
|             "name", | ||||
|             "mail", | ||||
|         ]) | ||||
| 
 | ||||
|         return ldap_config | ||||
| 
 | ||||
|     def _ldap_simple_bind(self, server, localpart, password): | ||||
|         """ Attempt a simple bind with the credentials | ||||
|             given by the user against the LDAP server. | ||||
| 
 | ||||
|             Returns True, LDAP3Connection | ||||
|                 if the bind was successful | ||||
|             Returns False, None | ||||
|                 if an error occured | ||||
|         """ | ||||
| 
 | ||||
|         try: | ||||
|             # bind with the the local users ldap credentials | ||||
|             bind_dn = "{prop}={value},{base}".format( | ||||
|                 prop=self.ldap_attributes['uid'], | ||||
|                 value=localpart, | ||||
|                 base=self.ldap_base | ||||
|             ) | ||||
|             conn = ldap3.Connection(server, bind_dn, password) | ||||
|             logger.debug( | ||||
|                 "Established LDAP connection in simple bind mode: %s", | ||||
|                 conn | ||||
|             ) | ||||
| 
 | ||||
|             if self.ldap_start_tls: | ||||
|                 conn.start_tls() | ||||
|                 logger.debug( | ||||
|                     "Upgraded LDAP connection in simple bind mode through StartTLS: %s", | ||||
|                     conn | ||||
|                 ) | ||||
| 
 | ||||
|             if conn.bind(): | ||||
|                 # GOOD: bind okay | ||||
|                 logger.debug("LDAP Bind successful in simple bind mode.") | ||||
|                 return True, conn | ||||
| 
 | ||||
|             # BAD: bind failed | ||||
|             logger.info( | ||||
|                 "Binding against LDAP failed for '%s' failed: %s", | ||||
|                 localpart, conn.result['description'] | ||||
|             ) | ||||
|             conn.unbind() | ||||
|             return False, None | ||||
| 
 | ||||
|         except ldap3.core.exceptions.LDAPException as e: | ||||
|             logger.warn("Error during LDAP authentication: %s", e) | ||||
|             return False, None | ||||
| 
 | ||||
|     def _ldap_authenticated_search(self, server, localpart, password): | ||||
|         """ Attempt to login with the preconfigured bind_dn | ||||
|             and then continue searching and filtering within | ||||
|             the base_dn | ||||
| 
 | ||||
|             Returns (True, LDAP3Connection) | ||||
|                 if a single matching DN within the base was found | ||||
|                 that matched the filter expression, and with which | ||||
|                 a successful bind was achieved | ||||
| 
 | ||||
|                 The LDAP3Connection returned is the instance that was used to | ||||
|                 verify the password not the one using the configured bind_dn. | ||||
|             Returns (False, None) | ||||
|                 if an error occured | ||||
|         """ | ||||
| 
 | ||||
|         try: | ||||
|             conn = ldap3.Connection( | ||||
|                 server, | ||||
|                 self.ldap_bind_dn, | ||||
|                 self.ldap_bind_password | ||||
|             ) | ||||
|             logger.debug( | ||||
|                 "Established LDAP connection in search mode: %s", | ||||
|                 conn | ||||
|             ) | ||||
| 
 | ||||
|             if self.ldap_start_tls: | ||||
|                 conn.start_tls() | ||||
|                 logger.debug( | ||||
|                     "Upgraded LDAP connection in search mode through StartTLS: %s", | ||||
|                     conn | ||||
|                 ) | ||||
| 
 | ||||
|             if not conn.bind(): | ||||
|                 logger.warn( | ||||
|                     "Binding against LDAP with `bind_dn` failed: %s", | ||||
|                     conn.result['description'] | ||||
|                 ) | ||||
|                 conn.unbind() | ||||
|                 return False, None | ||||
| 
 | ||||
|             # construct search_filter like (uid=localpart) | ||||
|             query = "({prop}={value})".format( | ||||
|                 prop=self.ldap_attributes['uid'], | ||||
|                 value=localpart | ||||
|             ) | ||||
|             if self.ldap_filter: | ||||
|                 # combine with the AND expression | ||||
|                 query = "(&{query}{filter})".format( | ||||
|                     query=query, | ||||
|                     filter=self.ldap_filter | ||||
|                 ) | ||||
|             logger.debug( | ||||
|                 "LDAP search filter: %s", | ||||
|                 query | ||||
|             ) | ||||
|             conn.search( | ||||
|                 search_base=self.ldap_base, | ||||
|                 search_filter=query | ||||
|             ) | ||||
| 
 | ||||
|             if len(conn.response) == 1: | ||||
|                 # GOOD: found exactly one result | ||||
|                 user_dn = conn.response[0]['dn'] | ||||
|                 logger.debug('LDAP search found dn: %s', user_dn) | ||||
| 
 | ||||
|                 # unbind and simple bind with user_dn to verify the password | ||||
|                 # Note: do not use rebind(), for some reason it did not verify | ||||
|                 #       the password for me! | ||||
|                 conn.unbind() | ||||
|                 return self._ldap_simple_bind(server, localpart, password) | ||||
|             else: | ||||
|                 # BAD: found 0 or > 1 results, abort! | ||||
|                 if len(conn.response) == 0: | ||||
|                     logger.info( | ||||
|                         "LDAP search returned no results for '%s'", | ||||
|                         localpart | ||||
|                     ) | ||||
|                 else: | ||||
|                     logger.info( | ||||
|                         "LDAP search returned too many (%s) results for '%s'", | ||||
|                         len(conn.response), localpart | ||||
|                     ) | ||||
|                 conn.unbind() | ||||
|                 return False, None | ||||
| 
 | ||||
|         except ldap3.core.exceptions.LDAPException as e: | ||||
|             logger.warn("Error during LDAP authentication: %s", e) | ||||
|             return False, None | ||||
| 
 | ||||
| 
 | ||||
| def _require_keys(config, required): | ||||
|     missing = [key for key in required if key not in config] | ||||
|     if missing: | ||||
|         raise ConfigError( | ||||
|             "LDAP enabled but missing required config values: {}".format( | ||||
|                 ", ".join(missing) | ||||
|             ) | ||||
|         ) | ||||
|  | @ -42,7 +42,8 @@ class BaseSlavedStoreTestCase(unittest.TestCase): | |||
|     @defer.inlineCallbacks | ||||
|     def replicate(self): | ||||
|         streams = self.slaved_store.stream_positions() | ||||
|         result = yield self.replication.replicate(streams, 100) | ||||
|         writer = yield self.replication.replicate(streams, 100) | ||||
|         result = writer.finish() | ||||
|         yield self.slaved_store.process_replication(result) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  |  | |||
|  | @ -120,7 +120,7 @@ class ReplicationResourceCase(unittest.TestCase): | |||
|             self.hs.clock.advance_time_msec(1) | ||||
|             code, body = yield get | ||||
|             self.assertEquals(code, 200) | ||||
|             self.assertEquals(body, {}) | ||||
|             self.assertEquals(body.get("rows", []), []) | ||||
|         test_timeout.__name__ = "test_timeout_%s" % (stream) | ||||
|         return test_timeout | ||||
| 
 | ||||
|  | @ -195,7 +195,6 @@ class ReplicationResourceCase(unittest.TestCase): | |||
|             self.assertIn("field_names", stream) | ||||
|             field_names = stream["field_names"] | ||||
|             self.assertIn("rows", stream) | ||||
|             self.assertTrue(stream["rows"]) | ||||
|             for row in stream["rows"]: | ||||
|                 self.assertEquals( | ||||
|                     len(row), len(field_names), | ||||
|  |  | |||
|  | @ -37,6 +37,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): | |||
|         config = Mock( | ||||
|             app_service_config_files=self.as_yaml_files, | ||||
|             event_cache_size=1, | ||||
|             password_providers=[], | ||||
|         ) | ||||
|         hs = yield setup_test_homeserver(config=config) | ||||
| 
 | ||||
|  | @ -109,6 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): | |||
|         config = Mock( | ||||
|             app_service_config_files=self.as_yaml_files, | ||||
|             event_cache_size=1, | ||||
|             password_providers=[], | ||||
|         ) | ||||
|         hs = yield setup_test_homeserver(config=config) | ||||
|         self.db_pool = hs.get_db_pool() | ||||
|  | @ -437,7 +439,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): | |||
|         f1 = self._write_config(suffix="1") | ||||
|         f2 = self._write_config(suffix="2") | ||||
| 
 | ||||
|         config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) | ||||
|         config = Mock( | ||||
|             app_service_config_files=[f1, f2], event_cache_size=1, | ||||
|             password_providers=[] | ||||
|         ) | ||||
|         hs = yield setup_test_homeserver(config=config, datastore=Mock()) | ||||
| 
 | ||||
|         ApplicationServiceStore(hs) | ||||
|  | @ -447,7 +452,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): | |||
|         f1 = self._write_config(id="id", suffix="1") | ||||
|         f2 = self._write_config(id="id", suffix="2") | ||||
| 
 | ||||
|         config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) | ||||
|         config = Mock( | ||||
|             app_service_config_files=[f1, f2], event_cache_size=1, | ||||
|             password_providers=[] | ||||
|         ) | ||||
|         hs = yield setup_test_homeserver(config=config, datastore=Mock()) | ||||
| 
 | ||||
|         with self.assertRaises(ConfigError) as cm: | ||||
|  | @ -463,7 +471,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): | |||
|         f1 = self._write_config(as_token="as_token", suffix="1") | ||||
|         f2 = self._write_config(as_token="as_token", suffix="2") | ||||
| 
 | ||||
|         config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) | ||||
|         config = Mock( | ||||
|             app_service_config_files=[f1, f2], event_cache_size=1, | ||||
|             password_providers=[] | ||||
|         ) | ||||
|         hs = yield setup_test_homeserver(config=config, datastore=Mock()) | ||||
| 
 | ||||
|         with self.assertRaises(ConfigError) as cm: | ||||
|  |  | |||
|  | @ -52,6 +52,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): | |||
|         config.server_name = name | ||||
|         config.trusted_third_party_id_servers = [] | ||||
|         config.room_invite_state_types = [] | ||||
|         config.password_providers = [] | ||||
| 
 | ||||
|     config.use_frozen_dicts = True | ||||
|     config.database_config = {"name": "sqlite3"} | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 David Baker
						David Baker