Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
commit
f86962cb6b
13
INSTALL.md
13
INSTALL.md
|
@ -124,12 +124,21 @@ sudo pacman -S base-devel python python-pip \
|
|||
|
||||
#### CentOS/Fedora
|
||||
|
||||
Installing prerequisites on CentOS 7 or Fedora 25:
|
||||
Installing prerequisites on CentOS 8 or Fedora>26:
|
||||
|
||||
```
|
||||
sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
|
||||
libwebp-devel tk-devel redhat-rpm-config \
|
||||
python3-virtualenv libffi-devel openssl-devel
|
||||
sudo dnf groupinstall "Development Tools"
|
||||
```
|
||||
|
||||
Installing prerequisites on CentOS 7 or Fedora<=25:
|
||||
|
||||
```
|
||||
sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
|
||||
lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
|
||||
python-virtualenv libffi-devel openssl-devel
|
||||
python3-virtualenv libffi-devel openssl-devel
|
||||
sudo yum groupinstall "Development Tools"
|
||||
```
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Updated CentOS8 install instructions. Contributed by Richard Kellner.
|
|
@ -0,0 +1 @@
|
|||
Remove the unused query_auth federation endpoint per MSC2451.
|
|
@ -0,0 +1 @@
|
|||
Remove special handling of aliases events from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260) added in v1.10.0rc1.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug that renders UTF-8 text files incorrectly when loaded from media. Contributed by @TheStranjer.
|
|
@ -0,0 +1 @@
|
|||
Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process.
|
|
@ -0,0 +1 @@
|
|||
Add type annotations and comments to the auth handler.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug that would cause Synapse to respond with an error about event visibility if a client tried to request the state of a room at a given token.
|
|
@ -0,0 +1 @@
|
|||
Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process.
|
|
@ -0,0 +1 @@
|
|||
Repair a data-corruption issue which was introduced in Synapse 1.10, and fixed in Synapse 1.11, and which could cause `/sync` to return with 404 errors about missing events and unknown rooms.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug causing account validity renewal emails to be sent even if the feature is turned off in some cases.
|
|
@ -0,0 +1 @@
|
|||
Add an optional parameter to control whether other sessions are logged out when a user's password is modified.
|
|
@ -0,0 +1 @@
|
|||
Improve performance when making HTTPS requests to sygnal, sydent, etc, by sharing the SSL context object between connections.
|
|
@ -0,0 +1 @@
|
|||
Attempt to improve performance of state res v2 algorithm.
|
|
@ -38,6 +38,7 @@ The parameter ``threepids`` is optional.
|
|||
The parameter ``avatar_url`` is optional.
|
||||
The parameter ``admin`` is optional and defaults to 'false'.
|
||||
The parameter ``deactivated`` is optional and defaults to 'false'.
|
||||
The parameter ``password`` is optional. If provided the user's password is updated and all devices are logged out.
|
||||
If the user already exists then optional parameters default to the current value.
|
||||
|
||||
List Accounts
|
||||
|
@ -168,11 +169,14 @@ with a body of:
|
|||
.. code:: json
|
||||
|
||||
{
|
||||
"new_password": "<secret>"
|
||||
"new_password": "<secret>",
|
||||
"logout_devices": true,
|
||||
}
|
||||
|
||||
including an ``access_token`` of a server admin.
|
||||
|
||||
The parameter ``new_password`` is required.
|
||||
The parameter ``logout_devices`` is optional and defaults to ``true``.
|
||||
|
||||
Get whether a user is a server administrator or not
|
||||
===================================================
|
||||
|
|
|
@ -1347,6 +1347,25 @@ saml2_config:
|
|||
#
|
||||
#grandfathered_mxid_source_attribute: upn
|
||||
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
|
||||
# If you *do* uncomment it, you will need to make sure that all the templates
|
||||
# below are in the directory.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
# * HTML page to display to users if something goes wrong during the
|
||||
# authentication process: 'saml_error.html'.
|
||||
#
|
||||
# This template doesn't currently need any variable to render.
|
||||
#
|
||||
# You can see the default templates at:
|
||||
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
|
||||
#
|
||||
#template_dir: "res/templates"
|
||||
|
||||
|
||||
|
||||
# Enable CAS for registration and login.
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
from synapse.util.module_loader import load_module, load_python_module
|
||||
|
@ -160,6 +163,14 @@ class SAML2Config(Config):
|
|||
saml2_config.get("saml_session_lifetime", "5m")
|
||||
)
|
||||
|
||||
template_dir = saml2_config.get("template_dir")
|
||||
if not template_dir:
|
||||
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
|
||||
|
||||
self.saml2_error_html_content = self.read_file(
|
||||
os.path.join(template_dir, "saml_error.html"), "saml2_config.saml_error",
|
||||
)
|
||||
|
||||
def _default_saml_config_dict(
|
||||
self, required_attributes: set, optional_attributes: set
|
||||
):
|
||||
|
@ -325,6 +336,25 @@ class SAML2Config(Config):
|
|||
# The default is 'uid'.
|
||||
#
|
||||
#grandfathered_mxid_source_attribute: upn
|
||||
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
|
||||
# If you *do* uncomment it, you will need to make sure that all the templates
|
||||
# below are in the directory.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
# * HTML page to display to users if something goes wrong during the
|
||||
# authentication process: 'saml_error.html'.
|
||||
#
|
||||
# This template doesn't currently need any variable to render.
|
||||
#
|
||||
# You can see the default templates at:
|
||||
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
|
||||
#
|
||||
#template_dir: "res/templates"
|
||||
""" % {
|
||||
"config_dir_path": config_dir_path
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ class ServerContextFactory(ContextFactory):
|
|||
|
||||
|
||||
@implementer(IPolicyForHTTPS)
|
||||
class ClientTLSOptionsFactory(object):
|
||||
class FederationPolicyForHTTPS(object):
|
||||
"""Factory for Twisted SSLClientConnectionCreators that are used to make connections
|
||||
to remote servers for federation.
|
||||
|
||||
|
@ -103,15 +103,15 @@ class ClientTLSOptionsFactory(object):
|
|||
# let us do).
|
||||
minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
|
||||
|
||||
self._verify_ssl = CertificateOptions(
|
||||
_verify_ssl = CertificateOptions(
|
||||
trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
|
||||
)
|
||||
self._verify_ssl_context = self._verify_ssl.getContext()
|
||||
self._verify_ssl_context.set_info_callback(self._context_info_cb)
|
||||
self._verify_ssl_context = _verify_ssl.getContext()
|
||||
self._verify_ssl_context.set_info_callback(_context_info_cb)
|
||||
|
||||
self._no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS)
|
||||
self._no_verify_ssl_context = self._no_verify_ssl.getContext()
|
||||
self._no_verify_ssl_context.set_info_callback(self._context_info_cb)
|
||||
_no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS)
|
||||
self._no_verify_ssl_context = _no_verify_ssl.getContext()
|
||||
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
|
||||
|
||||
def get_options(self, host: bytes):
|
||||
|
||||
|
@ -136,23 +136,6 @@ class ClientTLSOptionsFactory(object):
|
|||
|
||||
return SSLClientConnectionCreator(host, ssl_context, should_verify)
|
||||
|
||||
@staticmethod
|
||||
def _context_info_cb(ssl_connection, where, ret):
|
||||
"""The 'information callback' for our openssl context object."""
|
||||
# we assume that the app_data on the connection object has been set to
|
||||
# a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator)
|
||||
tls_protocol = ssl_connection.get_app_data()
|
||||
try:
|
||||
# ... we further assume that SSLClientConnectionCreator has set the
|
||||
# '_synapse_tls_verifier' attribute to a ConnectionVerifier object.
|
||||
tls_protocol._synapse_tls_verifier.verify_context_info_cb(
|
||||
ssl_connection, where
|
||||
)
|
||||
except: # noqa: E722, taken from the twisted implementation
|
||||
logger.exception("Error during info_callback")
|
||||
f = Failure()
|
||||
tls_protocol.failVerification(f)
|
||||
|
||||
def creatorForNetloc(self, hostname, port):
|
||||
"""Implements the IPolicyForHTTPS interace so that this can be passed
|
||||
directly to agents.
|
||||
|
@ -160,6 +143,43 @@ class ClientTLSOptionsFactory(object):
|
|||
return self.get_options(hostname)
|
||||
|
||||
|
||||
@implementer(IPolicyForHTTPS)
|
||||
class RegularPolicyForHTTPS(object):
|
||||
"""Factory for Twisted SSLClientConnectionCreators that are used to make connections
|
||||
to remote servers, for other than federation.
|
||||
|
||||
Always uses the same OpenSSL context object, which uses the default OpenSSL CA
|
||||
trust root.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
trust_root = platformTrust()
|
||||
self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
|
||||
self._ssl_context.set_info_callback(_context_info_cb)
|
||||
|
||||
def creatorForNetloc(self, hostname, port):
|
||||
return SSLClientConnectionCreator(hostname, self._ssl_context, True)
|
||||
|
||||
|
||||
def _context_info_cb(ssl_connection, where, ret):
|
||||
"""The 'information callback' for our openssl context objects.
|
||||
|
||||
Note: Once this is set as the info callback on a Context object, the Context should
|
||||
only be used with the SSLClientConnectionCreator.
|
||||
"""
|
||||
# we assume that the app_data on the connection object has been set to
|
||||
# a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator)
|
||||
tls_protocol = ssl_connection.get_app_data()
|
||||
try:
|
||||
# ... we further assume that SSLClientConnectionCreator has set the
|
||||
# '_synapse_tls_verifier' attribute to a ConnectionVerifier object.
|
||||
tls_protocol._synapse_tls_verifier.verify_context_info_cb(ssl_connection, where)
|
||||
except: # noqa: E722, taken from the twisted implementation
|
||||
logger.exception("Error during info_callback")
|
||||
f = Failure()
|
||||
tls_protocol.failVerification(f)
|
||||
|
||||
|
||||
@implementer(IOpenSSLClientConnectionCreator)
|
||||
class SSLClientConnectionCreator(object):
|
||||
"""Creates openssl connection objects for client connections.
|
||||
|
|
|
@ -39,10 +39,8 @@ from synapse.logging.context import (
|
|||
LoggingContext,
|
||||
PreserveLoggingContext,
|
||||
make_deferred_yieldable,
|
||||
preserve_fn,
|
||||
)
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -57,86 +55,6 @@ class FederationBase(object):
|
|||
self.store = hs.get_datastore()
|
||||
self._clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(
|
||||
self,
|
||||
origin: str,
|
||||
pdus: List[EventBase],
|
||||
room_version: str,
|
||||
outlier: bool = False,
|
||||
include_none: bool = False,
|
||||
):
|
||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||
one. If a PDU fails its signature check then we check if we have it in
|
||||
the database and if not then request if from the originating server of
|
||||
that PDU.
|
||||
|
||||
If a PDU fails its content hash check then it is redacted.
|
||||
|
||||
The given list of PDUs are not modified, instead the function returns
|
||||
a new list.
|
||||
|
||||
Args:
|
||||
origin
|
||||
pdu
|
||||
room_version
|
||||
outlier: Whether the events are outliers or not
|
||||
include_none: Whether to include None in the returned list
|
||||
for events that have failed their checks
|
||||
|
||||
Returns:
|
||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||
"""
|
||||
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_check_result(pdu: EventBase, deferred: Deferred):
|
||||
try:
|
||||
res = yield make_deferred_yieldable(deferred)
|
||||
except SynapseError:
|
||||
res = None
|
||||
|
||||
if not res:
|
||||
# Check local db.
|
||||
res = yield self.store.get_event(
|
||||
pdu.event_id, allow_rejected=True, allow_none=True
|
||||
)
|
||||
|
||||
if not res and pdu.origin != origin:
|
||||
try:
|
||||
# This should not exist in the base implementation, until
|
||||
# this is fixed, ignore it for typing. See issue #6997.
|
||||
res = yield defer.ensureDeferred(
|
||||
self.get_pdu( # type: ignore
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
room_version=room_version,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
)
|
||||
)
|
||||
except SynapseError:
|
||||
pass
|
||||
|
||||
if not res:
|
||||
logger.warning(
|
||||
"Failed to find copy of %s with valid signature", pdu.event_id
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
handle = preserve_fn(handle_check_result)
|
||||
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
|
||||
|
||||
valid_pdus = yield make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds2, consumeErrors=True)
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
if include_none:
|
||||
return valid_pdus
|
||||
else:
|
||||
return [p for p in valid_pdus if p]
|
||||
|
||||
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
|
||||
return make_deferred_yieldable(
|
||||
self._check_sigs_and_hashes(room_version, [pdu])[0]
|
||||
|
|
|
@ -33,6 +33,7 @@ from typing import (
|
|||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
|
@ -51,7 +52,7 @@ from synapse.api.room_versions import (
|
|||
)
|
||||
from synapse.events import EventBase, builder
|
||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import unwrapFirstError
|
||||
|
@ -345,6 +346,83 @@ class FederationClient(FederationBase):
|
|||
|
||||
return state_event_ids, auth_event_ids
|
||||
|
||||
async def _check_sigs_and_hash_and_fetch(
|
||||
self,
|
||||
origin: str,
|
||||
pdus: List[EventBase],
|
||||
room_version: str,
|
||||
outlier: bool = False,
|
||||
include_none: bool = False,
|
||||
) -> List[EventBase]:
|
||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||
one. If a PDU fails its signature check then we check if we have it in
|
||||
the database and if not then request if from the originating server of
|
||||
that PDU.
|
||||
|
||||
If a PDU fails its content hash check then it is redacted.
|
||||
|
||||
The given list of PDUs are not modified, instead the function returns
|
||||
a new list.
|
||||
|
||||
Args:
|
||||
origin
|
||||
pdu
|
||||
room_version
|
||||
outlier: Whether the events are outliers or not
|
||||
include_none: Whether to include None in the returned list
|
||||
for events that have failed their checks
|
||||
|
||||
Returns:
|
||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||
"""
|
||||
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_check_result(pdu: EventBase, deferred: Deferred):
|
||||
try:
|
||||
res = yield make_deferred_yieldable(deferred)
|
||||
except SynapseError:
|
||||
res = None
|
||||
|
||||
if not res:
|
||||
# Check local db.
|
||||
res = yield self.store.get_event(
|
||||
pdu.event_id, allow_rejected=True, allow_none=True
|
||||
)
|
||||
|
||||
if not res and pdu.origin != origin:
|
||||
try:
|
||||
res = yield defer.ensureDeferred(
|
||||
self.get_pdu(
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
room_version=room_version, # type: ignore
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
)
|
||||
)
|
||||
except SynapseError:
|
||||
pass
|
||||
|
||||
if not res:
|
||||
logger.warning(
|
||||
"Failed to find copy of %s with valid signature", pdu.event_id
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
handle = preserve_fn(handle_check_result)
|
||||
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
|
||||
|
||||
valid_pdus = await make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds2, consumeErrors=True)
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
if include_none:
|
||||
return valid_pdus
|
||||
else:
|
||||
return [p for p in valid_pdus if p]
|
||||
|
||||
async def get_event_auth(self, destination, room_id, event_id):
|
||||
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
|
||||
|
||||
|
|
|
@ -470,57 +470,6 @@ class FederationServer(FederationBase):
|
|||
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
|
||||
return 200, res
|
||||
|
||||
async def on_query_auth_request(self, origin, content, room_id, event_id):
|
||||
"""
|
||||
Content is a dict with keys::
|
||||
auth_chain (list): A list of events that give the auth chain.
|
||||
missing (list): A list of event_ids indicating what the other
|
||||
side (`origin`) think we're missing.
|
||||
rejects (dict): A mapping from event_id to a 2-tuple of reason
|
||||
string and a proof (or None) of why the event was rejected.
|
||||
The keys of this dict give the list of events the `origin` has
|
||||
rejected.
|
||||
|
||||
Args:
|
||||
origin (str)
|
||||
content (dict)
|
||||
event_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred: Results in `dict` with the same format as `content`
|
||||
"""
|
||||
with (await self._server_linearizer.queue((origin, room_id))):
|
||||
origin_host, _ = parse_server_name(origin)
|
||||
await self.check_server_matches_acl(origin_host, room_id)
|
||||
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
auth_chain = [
|
||||
event_from_pdu_json(e, room_version) for e in content["auth_chain"]
|
||||
]
|
||||
|
||||
signed_auth = await self._check_sigs_and_hash_and_fetch(
|
||||
origin, auth_chain, outlier=True, room_version=room_version.identifier
|
||||
)
|
||||
|
||||
ret = await self.handler.on_query_auth(
|
||||
origin,
|
||||
event_id,
|
||||
room_id,
|
||||
signed_auth,
|
||||
content.get("rejects", []),
|
||||
content.get("missing", []),
|
||||
)
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
send_content = {
|
||||
"auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
|
||||
"rejects": ret.get("rejects", []),
|
||||
"missing": ret.get("missing", []),
|
||||
}
|
||||
|
||||
return 200, send_content
|
||||
|
||||
@log_function
|
||||
def on_query_client_keys(self, origin, content):
|
||||
return self.on_query_request("client_keys", content)
|
||||
|
|
|
@ -643,17 +643,6 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
|
|||
return 200, response
|
||||
|
||||
|
||||
class FederationQueryAuthServlet(BaseFederationServlet):
|
||||
PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
|
||||
|
||||
async def on_POST(self, origin, content, query, context, event_id):
|
||||
new_content = await self.handler.on_query_auth_request(
|
||||
origin, content, context, event_id
|
||||
)
|
||||
|
||||
return 200, new_content
|
||||
|
||||
|
||||
class FederationGetMissingEventsServlet(BaseFederationServlet):
|
||||
# TODO(paul): Why does this path alone end with "/?" optional?
|
||||
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
|
||||
|
@ -1412,7 +1401,6 @@ FEDERATION_SERVLET_CLASSES = (
|
|||
FederationV2SendLeaveServlet,
|
||||
FederationV1InviteServlet,
|
||||
FederationV2InviteServlet,
|
||||
FederationQueryAuthServlet,
|
||||
FederationGetMissingEventsServlet,
|
||||
FederationEventAuthServlet,
|
||||
FederationClientKeysQueryServlet,
|
||||
|
|
|
@ -44,7 +44,11 @@ class AccountValidityHandler(object):
|
|||
|
||||
self._account_validity = self.hs.config.account_validity
|
||||
|
||||
if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
|
||||
if (
|
||||
self._account_validity.enabled
|
||||
and self._account_validity.renew_by_email_enabled
|
||||
and load_jinja2_templates
|
||||
):
|
||||
# Don't do email-specific configuration if renewal by email is disabled.
|
||||
try:
|
||||
app_name = self.hs.config.email_app_name
|
||||
|
|
|
@ -18,10 +18,10 @@ import logging
|
|||
import time
|
||||
import unicodedata
|
||||
import urllib.parse
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import attr
|
||||
import bcrypt
|
||||
import bcrypt # type: ignore[import]
|
||||
import pymacaroons
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest
|
|||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.types import UserID
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
@ -63,11 +63,11 @@ class AuthHandler(BaseHandler):
|
|||
"""
|
||||
super(AuthHandler, self).__init__(hs)
|
||||
|
||||
self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
|
||||
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
||||
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
||||
inst = auth_checker_class(hs)
|
||||
if inst.is_enabled():
|
||||
self.checkers[inst.AUTH_TYPE] = inst
|
||||
self.checkers[inst.AUTH_TYPE] = inst # type: ignore
|
||||
|
||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||
|
||||
|
@ -124,7 +124,9 @@ class AuthHandler(BaseHandler):
|
|||
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_user_via_ui_auth(self, requester, request_body, clientip):
|
||||
def validate_user_via_ui_auth(
|
||||
self, requester: Requester, request_body: Dict[str, Any], clientip: str
|
||||
):
|
||||
"""
|
||||
Checks that the user is who they claim to be, via a UI auth.
|
||||
|
||||
|
@ -133,11 +135,11 @@ class AuthHandler(BaseHandler):
|
|||
that it isn't stolen by re-authenticating them.
|
||||
|
||||
Args:
|
||||
requester (Requester): The user, as given by the access token
|
||||
requester: The user, as given by the access token
|
||||
|
||||
request_body (dict): The body of the request sent by the client
|
||||
request_body: The body of the request sent by the client
|
||||
|
||||
clientip (str): The IP address of the client.
|
||||
clientip: The IP address of the client.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[dict]: the parameters for this request (which may
|
||||
|
@ -208,7 +210,9 @@ class AuthHandler(BaseHandler):
|
|||
return self.checkers.keys()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip):
|
||||
def check_auth(
|
||||
self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
|
||||
):
|
||||
"""
|
||||
Takes a dictionary sent by the client in the login / registration
|
||||
protocol and handles the User-Interactive Auth flow.
|
||||
|
@ -223,14 +227,14 @@ class AuthHandler(BaseHandler):
|
|||
decorator.
|
||||
|
||||
Args:
|
||||
flows (list): A list of login flows. Each flow is an ordered list of
|
||||
strings representing auth-types. At least one full
|
||||
flow must be completed in order for auth to be successful.
|
||||
flows: A list of login flows. Each flow is an ordered list of
|
||||
strings representing auth-types. At least one full
|
||||
flow must be completed in order for auth to be successful.
|
||||
|
||||
clientdict: The dictionary from the client root level, not the
|
||||
'auth' key: this method prompts for auth if none is sent.
|
||||
|
||||
clientip (str): The IP address of the client.
|
||||
clientip: The IP address of the client.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[dict, dict, str]: a deferred tuple of
|
||||
|
@ -250,7 +254,7 @@ class AuthHandler(BaseHandler):
|
|||
"""
|
||||
|
||||
authdict = None
|
||||
sid = None
|
||||
sid = None # type: Optional[str]
|
||||
if clientdict and "auth" in clientdict:
|
||||
authdict = clientdict["auth"]
|
||||
del clientdict["auth"]
|
||||
|
@ -283,9 +287,9 @@ class AuthHandler(BaseHandler):
|
|||
creds = session["creds"]
|
||||
|
||||
# check auth type currently being presented
|
||||
errordict = {}
|
||||
errordict = {} # type: Dict[str, Any]
|
||||
if "type" in authdict:
|
||||
login_type = authdict["type"]
|
||||
login_type = authdict["type"] # type: str
|
||||
try:
|
||||
result = yield self._check_auth_dict(authdict, clientip)
|
||||
if result:
|
||||
|
@ -326,7 +330,7 @@ class AuthHandler(BaseHandler):
|
|||
raise InteractiveAuthIncompleteError(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
||||
def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
|
||||
"""
|
||||
Adds the result of out-of-band authentication into an existing auth
|
||||
session. Currently used for adding the result of fallback auth.
|
||||
|
@ -348,7 +352,7 @@ class AuthHandler(BaseHandler):
|
|||
return True
|
||||
return False
|
||||
|
||||
def get_session_id(self, clientdict):
|
||||
def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
Gets the session ID for a client given the client dictionary
|
||||
|
||||
|
@ -356,7 +360,7 @@ class AuthHandler(BaseHandler):
|
|||
clientdict: The dictionary sent by the client in the request
|
||||
|
||||
Returns:
|
||||
str|None: The string session ID the client sent. If the client did
|
||||
The string session ID the client sent. If the client did
|
||||
not send a session ID, returns None.
|
||||
"""
|
||||
sid = None
|
||||
|
@ -366,40 +370,42 @@ class AuthHandler(BaseHandler):
|
|||
sid = authdict["session"]
|
||||
return sid
|
||||
|
||||
def set_session_data(self, session_id, key, value):
|
||||
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||
"""
|
||||
Store a key-value pair into the sessions data associated with this
|
||||
request. This data is stored server-side and cannot be modified by
|
||||
the client.
|
||||
|
||||
Args:
|
||||
session_id (string): The ID of this session as returned from check_auth
|
||||
key (string): The key to store the data under
|
||||
value (any): The data to store
|
||||
session_id: The ID of this session as returned from check_auth
|
||||
key: The key to store the data under
|
||||
value: The data to store
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
sess.setdefault("serverdict", {})[key] = value
|
||||
self._save_session(sess)
|
||||
|
||||
def get_session_data(self, session_id, key, default=None):
|
||||
def get_session_data(
|
||||
self, session_id: str, key: str, default: Optional[Any] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve data stored with set_session_data
|
||||
|
||||
Args:
|
||||
session_id (string): The ID of this session as returned from check_auth
|
||||
key (string): The key to store the data under
|
||||
default (any): Value to return if the key has not been set
|
||||
session_id: The ID of this session as returned from check_auth
|
||||
key: The key to store the data under
|
||||
default: Value to return if the key has not been set
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
return sess.setdefault("serverdict", {}).get(key, default)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_auth_dict(self, authdict, clientip):
|
||||
def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
|
||||
"""Attempt to validate the auth dict provided by a client
|
||||
|
||||
Args:
|
||||
authdict (object): auth dict provided by the client
|
||||
clientip (str): IP address of the client
|
||||
authdict: auth dict provided by the client
|
||||
clientip: IP address of the client
|
||||
|
||||
Returns:
|
||||
Deferred: result of the stage verification.
|
||||
|
@ -425,10 +431,10 @@ class AuthHandler(BaseHandler):
|
|||
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
|
||||
return canonical_id
|
||||
|
||||
def _get_params_recaptcha(self):
|
||||
def _get_params_recaptcha(self) -> dict:
|
||||
return {"public_key": self.hs.config.recaptcha_public_key}
|
||||
|
||||
def _get_params_terms(self):
|
||||
def _get_params_terms(self) -> dict:
|
||||
return {
|
||||
"policies": {
|
||||
"privacy_policy": {
|
||||
|
@ -445,7 +451,9 @@ class AuthHandler(BaseHandler):
|
|||
}
|
||||
}
|
||||
|
||||
def _auth_dict_for_flows(self, flows, session):
|
||||
def _auth_dict_for_flows(
|
||||
self, flows: List[List[str]], session: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
public_flows = []
|
||||
for f in flows:
|
||||
public_flows.append(f)
|
||||
|
@ -455,7 +463,7 @@ class AuthHandler(BaseHandler):
|
|||
LoginType.TERMS: self._get_params_terms,
|
||||
}
|
||||
|
||||
params = {}
|
||||
params = {} # type: Dict[str, Any]
|
||||
|
||||
for f in public_flows:
|
||||
for stage in f:
|
||||
|
@ -468,7 +476,13 @@ class AuthHandler(BaseHandler):
|
|||
"params": params,
|
||||
}
|
||||
|
||||
def _get_session_info(self, session_id):
|
||||
def _get_session_info(self, session_id: Optional[str]) -> dict:
|
||||
"""
|
||||
Gets or creates a session given a session ID.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
"""
|
||||
if session_id not in self.sessions:
|
||||
session_id = None
|
||||
|
||||
|
@ -481,7 +495,9 @@ class AuthHandler(BaseHandler):
|
|||
return self.sessions[session_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
|
||||
def get_access_token_for_user_id(
|
||||
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
||||
):
|
||||
"""
|
||||
Creates a new access token for the user with the given user ID.
|
||||
|
||||
|
@ -491,11 +507,11 @@ class AuthHandler(BaseHandler):
|
|||
The device will be recorded in the table if it is not there already.
|
||||
|
||||
Args:
|
||||
user_id (str): canonical User ID
|
||||
device_id (str|None): the device ID to associate with the tokens.
|
||||
user_id: canonical User ID
|
||||
device_id: the device ID to associate with the tokens.
|
||||
None to leave the tokens unassociated with a device (deprecated:
|
||||
we should always have a device ID)
|
||||
valid_until_ms (int|None): when the token is valid until. None for
|
||||
valid_until_ms: when the token is valid until. None for
|
||||
no expiry.
|
||||
Returns:
|
||||
The access token for the user's session.
|
||||
|
@ -530,13 +546,13 @@ class AuthHandler(BaseHandler):
|
|||
return access_token
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_user_exists(self, user_id):
|
||||
def check_user_exists(self, user_id: str):
|
||||
"""
|
||||
Checks to see if a user with the given id exists. Will check case
|
||||
insensitively, but return None if there are multiple inexact matches.
|
||||
|
||||
Args:
|
||||
(unicode|bytes) user_id: complete @user:id
|
||||
user_id: complete @user:id
|
||||
|
||||
Returns:
|
||||
defer.Deferred: (unicode) canonical_user_id, or None if zero or
|
||||
|
@ -551,7 +567,7 @@ class AuthHandler(BaseHandler):
|
|||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _find_user_id_and_pwd_hash(self, user_id):
|
||||
def _find_user_id_and_pwd_hash(self, user_id: str):
|
||||
"""Checks to see if a user with the given id exists. Will check case
|
||||
insensitively, but will return None if there are multiple inexact
|
||||
matches.
|
||||
|
@ -581,7 +597,7 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
return result
|
||||
|
||||
def get_supported_login_types(self):
|
||||
def get_supported_login_types(self) -> Iterable[str]:
|
||||
"""Get a the login types supported for the /login API
|
||||
|
||||
By default this is just 'm.login.password' (unless password_enabled is
|
||||
|
@ -589,20 +605,20 @@ class AuthHandler(BaseHandler):
|
|||
other login types.
|
||||
|
||||
Returns:
|
||||
Iterable[str]: login types
|
||||
login types
|
||||
"""
|
||||
return self._supported_login_types
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_login(self, username, login_submission):
|
||||
def validate_login(self, username: str, login_submission: Dict[str, Any]):
|
||||
"""Authenticates the user for the /login API
|
||||
|
||||
Also used by the user-interactive auth flow to validate
|
||||
m.login.password auth types.
|
||||
|
||||
Args:
|
||||
username (str): username supplied by the user
|
||||
login_submission (dict): the whole of the login submission
|
||||
username: username supplied by the user
|
||||
login_submission: the whole of the login submission
|
||||
(including 'type' and other relevant fields)
|
||||
Returns:
|
||||
Deferred[str, func]: canonical user id, and optional callback
|
||||
|
@ -690,13 +706,13 @@ class AuthHandler(BaseHandler):
|
|||
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_password_provider_3pid(self, medium, address, password):
|
||||
def check_password_provider_3pid(self, medium: str, address: str, password: str):
|
||||
"""Check if a password provider is able to validate a thirdparty login
|
||||
|
||||
Args:
|
||||
medium (str): The medium of the 3pid (ex. email).
|
||||
address (str): The address of the 3pid (ex. jdoe@example.com).
|
||||
password (str): The password of the user.
|
||||
medium: The medium of the 3pid (ex. email).
|
||||
address: The address of the 3pid (ex. jdoe@example.com).
|
||||
password: The password of the user.
|
||||
|
||||
Returns:
|
||||
Deferred[(str|None, func|None)]: A tuple of `(user_id,
|
||||
|
@ -724,15 +740,15 @@ class AuthHandler(BaseHandler):
|
|||
return None, None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_local_password(self, user_id, password):
|
||||
def _check_local_password(self, user_id: str, password: str):
|
||||
"""Authenticate a user against the local password database.
|
||||
|
||||
user_id is checked case insensitively, but will return None if there are
|
||||
multiple inexact matches.
|
||||
|
||||
Args:
|
||||
user_id (unicode): complete @user:id
|
||||
password (unicode): the provided password
|
||||
user_id: complete @user:id
|
||||
password: the provided password
|
||||
Returns:
|
||||
Deferred[unicode] the canonical_user_id, or Deferred[None] if
|
||||
unknown user/bad password
|
||||
|
@ -755,7 +771,7 @@ class AuthHandler(BaseHandler):
|
|||
return user_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||
def validate_short_term_login_token_and_get_user_id(self, login_token: str):
|
||||
auth_api = self.hs.get_auth()
|
||||
user_id = None
|
||||
try:
|
||||
|
@ -769,11 +785,11 @@ class AuthHandler(BaseHandler):
|
|||
return user_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_access_token(self, access_token):
|
||||
def delete_access_token(self, access_token: str):
|
||||
"""Invalidate a single access token
|
||||
|
||||
Args:
|
||||
access_token (str): access token to be deleted
|
||||
access_token: access token to be deleted
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
|
@ -798,15 +814,17 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def delete_access_tokens_for_user(
|
||||
self, user_id, except_token_id=None, device_id=None
|
||||
self,
|
||||
user_id: str,
|
||||
except_token_id: Optional[str] = None,
|
||||
device_id: Optional[str] = None,
|
||||
):
|
||||
"""Invalidate access tokens belonging to a user
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user the tokens belong to
|
||||
except_token_id (str|None): access_token ID which should *not* be
|
||||
deleted
|
||||
device_id (str|None): ID of device the tokens are associated with.
|
||||
user_id: ID of user the tokens belong to
|
||||
except_token_id: access_token ID which should *not* be deleted
|
||||
device_id: ID of device the tokens are associated with.
|
||||
If None, tokens associated with any device (or no device) will
|
||||
be deleted
|
||||
Returns:
|
||||
|
@ -830,7 +848,7 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_threepid(self, user_id, medium, address, validated_at):
|
||||
def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
|
||||
# check if medium has a valid value
|
||||
if medium not in ["email", "msisdn"]:
|
||||
raise SynapseError(
|
||||
|
@ -856,19 +874,20 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_threepid(self, user_id, medium, address, id_server=None):
|
||||
def delete_threepid(
|
||||
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
|
||||
):
|
||||
"""Attempts to unbind the 3pid on the identity servers and deletes it
|
||||
from the local database.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
medium (str)
|
||||
address (str)
|
||||
id_server (str|None): Use the given identity server when unbinding
|
||||
user_id: ID of user to remove the 3pid from.
|
||||
medium: The medium of the 3pid being removed: "email" or "msisdn".
|
||||
address: The 3pid address to remove.
|
||||
id_server: Use the given identity server when unbinding
|
||||
any threepids. If None then will attempt to unbind using the
|
||||
identity server specified when binding (if known).
|
||||
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: Returns True if successfully unbound the 3pid on
|
||||
the identity server, False if identity server doesn't support the
|
||||
|
@ -887,17 +906,18 @@ class AuthHandler(BaseHandler):
|
|||
yield self.store.user_delete_threepid(user_id, medium, address)
|
||||
return result
|
||||
|
||||
def _save_session(self, session):
|
||||
def _save_session(self, session: Dict[str, Any]) -> None:
|
||||
"""Update the last used time on the session to now and add it back to the session store."""
|
||||
# TODO: Persistent storage
|
||||
logger.debug("Saving session %s", session)
|
||||
session["last_used"] = self.hs.get_clock().time_msec()
|
||||
self.sessions[session["id"]] = session
|
||||
|
||||
def hash(self, password):
|
||||
def hash(self, password: str):
|
||||
"""Computes a secure hash of password.
|
||||
|
||||
Args:
|
||||
password (unicode): Password to hash.
|
||||
password: Password to hash.
|
||||
|
||||
Returns:
|
||||
Deferred(unicode): Hashed password.
|
||||
|
@ -914,12 +934,12 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
return defer_to_thread(self.hs.get_reactor(), _do_hash)
|
||||
|
||||
def validate_hash(self, password, stored_hash):
|
||||
def validate_hash(self, password: str, stored_hash: bytes):
|
||||
"""Validates that self.hash(password) == stored_hash.
|
||||
|
||||
Args:
|
||||
password (unicode): Password to hash.
|
||||
stored_hash (bytes): Expected hash value.
|
||||
password: Password to hash.
|
||||
stored_hash: Expected hash value.
|
||||
|
||||
Returns:
|
||||
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||
|
@ -1007,7 +1027,9 @@ class MacaroonGenerator(object):
|
|||
|
||||
hs = attr.ib()
|
||||
|
||||
def generate_access_token(self, user_id, extra_caveats=None):
|
||||
def generate_access_token(
|
||||
self, user_id: str, extra_caveats: Optional[List[str]] = None
|
||||
) -> str:
|
||||
extra_caveats = extra_caveats or []
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
|
@ -1020,16 +1042,9 @@ class MacaroonGenerator(object):
|
|||
macaroon.add_first_party_caveat(caveat)
|
||||
return macaroon.serialize()
|
||||
|
||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||
"""
|
||||
|
||||
Args:
|
||||
user_id (unicode):
|
||||
duration_in_ms (int):
|
||||
|
||||
Returns:
|
||||
unicode
|
||||
"""
|
||||
def generate_short_term_login_token(
|
||||
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
|
||||
) -> str:
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = login")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
|
@ -1037,12 +1052,12 @@ class MacaroonGenerator(object):
|
|||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
return macaroon.serialize()
|
||||
|
||||
def generate_delete_pusher_token(self, user_id):
|
||||
def generate_delete_pusher_token(self, user_id: str) -> str:
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = delete_pusher")
|
||||
return macaroon.serialize()
|
||||
|
||||
def _generate_base_macaroon(self, user_id):
|
||||
def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
|
|
|
@ -292,16 +292,6 @@ class RoomCreationHandler(BaseHandler):
|
|||
except AuthError as e:
|
||||
logger.warning("Unable to update PLs in old room: %s", e)
|
||||
|
||||
new_pl_content = copy_power_levels_contents(old_room_pl_state.content)
|
||||
|
||||
# pre-msc2260 rooms may not have the right setting for aliases. If no other
|
||||
# value is set, set it now.
|
||||
events_default = new_pl_content.get("events_default", 0)
|
||||
new_pl_content.setdefault("events", {}).setdefault(
|
||||
EventTypes.Aliases, events_default
|
||||
)
|
||||
|
||||
logger.debug("Setting correct PLs in new room to %s", new_pl_content)
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
|
@ -309,7 +299,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
"state_key": "",
|
||||
"room_id": new_room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
"content": new_pl_content,
|
||||
"content": old_room_pl_state.content,
|
||||
},
|
||||
ratelimit=False,
|
||||
)
|
||||
|
@ -814,10 +804,6 @@ class RoomCreationHandler(BaseHandler):
|
|||
EventTypes.RoomHistoryVisibility: 100,
|
||||
EventTypes.CanonicalAlias: 50,
|
||||
EventTypes.RoomAvatar: 50,
|
||||
# MSC2260: Allow everybody to send alias events by default
|
||||
# This will be reudundant on pre-MSC2260 rooms, since the
|
||||
# aliases event is special-cased.
|
||||
EventTypes.Aliases: 0,
|
||||
EventTypes.Tombstone: 100,
|
||||
EventTypes.ServerACL: 100,
|
||||
},
|
||||
|
|
|
@ -23,6 +23,7 @@ from saml2.client import Saml2Client
|
|||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import (
|
||||
|
@ -73,6 +74,8 @@ class SamlHandler:
|
|||
# a lock on the mappings
|
||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
|
||||
|
||||
self._error_html_content = hs.config.saml2_error_html_content
|
||||
|
||||
def handle_redirect_request(self, client_redirect_url):
|
||||
"""Handle an incoming request to /login/sso/redirect
|
||||
|
||||
|
@ -114,7 +117,22 @@ class SamlHandler:
|
|||
# the dict.
|
||||
self.expire_sessions()
|
||||
|
||||
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
|
||||
try:
|
||||
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
|
||||
except Exception as e:
|
||||
# If decoding the response or mapping it to a user failed, then log the
|
||||
# error and tell the user that something went wrong.
|
||||
logger.error(e)
|
||||
|
||||
request.setResponseCode(400)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(
|
||||
b"Content-Length", b"%d" % (len(self._error_html_content),)
|
||||
)
|
||||
request.write(self._error_html_content.encode("utf8"))
|
||||
finish_request(request)
|
||||
return
|
||||
|
||||
self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||
|
||||
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
|
||||
|
|
|
@ -13,10 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.types import Requester
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -32,14 +34,17 @@ class SetPasswordHandler(BaseHandler):
|
|||
self._device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_password(self, user_id, newpassword, requester=None):
|
||||
def set_password(
|
||||
self,
|
||||
user_id: str,
|
||||
new_password: str,
|
||||
logout_devices: bool,
|
||||
requester: Optional[Requester] = None,
|
||||
):
|
||||
if not self.hs.config.password_localdb_enabled:
|
||||
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
||||
|
||||
password_hash = yield self._auth_handler.hash(newpassword)
|
||||
|
||||
except_device_id = requester.device_id if requester else None
|
||||
except_access_token_id = requester.access_token_id if requester else None
|
||||
password_hash = yield self._auth_handler.hash(new_password)
|
||||
|
||||
try:
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
|
@ -48,14 +53,18 @@ class SetPasswordHandler(BaseHandler):
|
|||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||
raise e
|
||||
|
||||
# we want to log out all of the user's other sessions. First delete
|
||||
# all his other devices.
|
||||
yield self._device_handler.delete_all_devices_for_user(
|
||||
user_id, except_device_id=except_device_id
|
||||
)
|
||||
# Optionally, log out all of the user's other sessions.
|
||||
if logout_devices:
|
||||
except_device_id = requester.device_id if requester else None
|
||||
except_access_token_id = requester.access_token_id if requester else None
|
||||
|
||||
# and now delete any access tokens which weren't associated with
|
||||
# devices (or were associated with this device).
|
||||
yield self._auth_handler.delete_access_tokens_for_user(
|
||||
user_id, except_token_id=except_access_token_id
|
||||
)
|
||||
# First delete all of their other devices.
|
||||
yield self._device_handler.delete_all_devices_for_user(
|
||||
user_id, except_device_id=except_device_id
|
||||
)
|
||||
|
||||
# and now delete any access tokens which weren't associated with
|
||||
# devices (or were associated with this device).
|
||||
yield self._auth_handler.delete_access_tokens_for_user(
|
||||
user_id, except_token_id=except_access_token_id
|
||||
)
|
||||
|
|
|
@ -244,9 +244,6 @@ class SimpleHttpClient(object):
|
|||
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
|
||||
pool.cachedConnectionTimeout = 2 * 60
|
||||
|
||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||
# 'like a browser'
|
||||
self.agent = ProxyAgent(
|
||||
self.reactor,
|
||||
connectTimeout=15,
|
||||
|
|
|
@ -45,7 +45,7 @@ class MatrixFederationAgent(object):
|
|||
Args:
|
||||
reactor (IReactor): twisted reactor to use for underlying requests
|
||||
|
||||
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||
tls_client_options_factory (FederationPolicyForHTTPS|None):
|
||||
factory to use for fetching client tls options, or none to disable TLS.
|
||||
|
||||
_srv_resolver (SrvResolver|None):
|
||||
|
|
|
@ -210,7 +210,7 @@ class LoggingContext(object):
|
|||
class Sentinel(object):
|
||||
"""Sentinel to represent the root context"""
|
||||
|
||||
__slots__ = ["previous_context", "alive", "request", "scope"]
|
||||
__slots__ = ["previous_context", "alive", "request", "scope", "tag"]
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Minimal set for compatibility with LoggingContext
|
||||
|
@ -218,6 +218,7 @@ class LoggingContext(object):
|
|||
self.alive = None
|
||||
self.request = None
|
||||
self.scope = None
|
||||
self.tag = None
|
||||
|
||||
def __str__(self):
|
||||
return "sentinel"
|
||||
|
@ -511,7 +512,7 @@ class PreserveLoggingContext(object):
|
|||
|
||||
__slots__ = ["current_context", "new_context", "has_parent"]
|
||||
|
||||
def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
|
||||
def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
|
||||
if new_context is None:
|
||||
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
|
||||
else:
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<body>
|
||||
<p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p>
|
||||
<p>
|
||||
If you're seeing this page after clicking a link sent to you via email, make
|
||||
If you are seeing this page after clicking a link sent to you via email, make
|
||||
sure you only click the confirmation link once, and that you open the
|
||||
validation link in the same client you're logging in from.
|
||||
</p>
|
||||
|
@ -24,19 +24,22 @@
|
|||
// we just don't print anything specific.
|
||||
let searchStr = "";
|
||||
if (window.location.search) {
|
||||
// For some reason window.location.searchParams isn't always defined when
|
||||
// window.location.search is, so we can't just use it right away.
|
||||
// window.location.searchParams isn't always defined when
|
||||
// window.location.search is, so it's more reliable to parse the latter.
|
||||
searchStr = window.location.search;
|
||||
} else if (window.location.hash) {
|
||||
//
|
||||
// Replace the # with a ? so that URLSearchParams does the right thing and
|
||||
// doesn't parse the first parameter incorrectly.
|
||||
searchStr = window.location.hash.replace("#", "?");
|
||||
}
|
||||
|
||||
// We might end up with no error in the URL, so we need to check if we have one
|
||||
// to print one.
|
||||
let errorDesc = new URLSearchParams(searchStr).get("error_description")
|
||||
|
||||
if (errorDesc) {
|
||||
document.getElementById("errormsg").innerHTML = ` ("${errorDesc}")`;
|
||||
|
||||
document.getElementById("errormsg").innerText = ` ("${errorDesc}")`;
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
</html>
|
||||
|
|
|
@ -221,8 +221,9 @@ class UserRestServletV2(RestServlet):
|
|||
raise SynapseError(400, "Invalid password")
|
||||
else:
|
||||
new_password = body["password"]
|
||||
logout_devices = True
|
||||
await self.set_password_handler.set_password(
|
||||
target_user.to_string(), new_password, requester
|
||||
target_user.to_string(), new_password, logout_devices, requester
|
||||
)
|
||||
|
||||
if "deactivated" in body:
|
||||
|
@ -536,9 +537,10 @@ class ResetPasswordRestServlet(RestServlet):
|
|||
params = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(params, ["new_password"])
|
||||
new_password = params["new_password"]
|
||||
logout_devices = params.get("logout_devices", True)
|
||||
|
||||
await self._set_password_handler.set_password(
|
||||
target_user_id, new_password, requester
|
||||
target_user_id, new_password, logout_devices, requester
|
||||
)
|
||||
return 200, {}
|
||||
|
||||
|
|
|
@ -189,12 +189,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
if event_type == EventTypes.Aliases:
|
||||
# MSC2260
|
||||
raise SynapseError(
|
||||
400, "Cannot send m.room.aliases events via /rooms/{room_id}/state"
|
||||
)
|
||||
|
||||
event_dict = {
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
|
@ -242,12 +236,6 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
if event_type == EventTypes.Aliases:
|
||||
# MSC2260
|
||||
raise SynapseError(
|
||||
400, "Cannot send m.room.aliases events via /rooms/{room_id}/send"
|
||||
)
|
||||
|
||||
event_dict = {
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
|
|
|
@ -265,8 +265,11 @@ class PasswordRestServlet(RestServlet):
|
|||
|
||||
assert_params_in_dict(params, ["new_password"])
|
||||
new_password = params["new_password"]
|
||||
logout_devices = params.get("logout_devices", True)
|
||||
|
||||
await self._set_password_handler.set_password(user_id, new_password, requester)
|
||||
await self._set_password_handler.set_password(
|
||||
user_id, new_password, logout_devices, requester
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
|
|
|
@ -30,6 +30,22 @@ from synapse.util.stringutils import is_ascii
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# list all text content types that will have the charset default to UTF-8 when
|
||||
# none is given
|
||||
TEXT_CONTENT_TYPES = [
|
||||
"text/css",
|
||||
"text/csv",
|
||||
"text/html",
|
||||
"text/calendar",
|
||||
"text/plain",
|
||||
"text/javascript",
|
||||
"application/json",
|
||||
"application/ld+json",
|
||||
"application/rtf",
|
||||
"image/svg+xml",
|
||||
"text/xml",
|
||||
]
|
||||
|
||||
|
||||
def parse_media_id(request):
|
||||
try:
|
||||
|
@ -96,7 +112,14 @@ def add_file_headers(request, media_type, file_size, upload_name):
|
|||
def _quote(x):
|
||||
return urllib.parse.quote(x.encode("utf-8"))
|
||||
|
||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||
# Default to a UTF-8 charset for text content types.
|
||||
# ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'
|
||||
if media_type.lower() in TEXT_CONTENT_TYPES:
|
||||
content_type = media_type + "; charset=UTF-8"
|
||||
else:
|
||||
content_type = media_type
|
||||
|
||||
request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
|
||||
if upload_name:
|
||||
# RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
|
||||
#
|
||||
|
|
|
@ -14,7 +14,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import DirectServeResource, wrap_html_request_handler
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
finish_request,
|
||||
wrap_html_request_handler,
|
||||
)
|
||||
|
||||
|
||||
class SAML2ResponseResource(DirectServeResource):
|
||||
|
@ -24,8 +28,20 @@ class SAML2ResponseResource(DirectServeResource):
|
|||
|
||||
def __init__(self, hs):
|
||||
super().__init__()
|
||||
self._error_html_content = hs.config.saml2_error_html_content
|
||||
self._saml_handler = hs.get_saml_handler()
|
||||
|
||||
async def _async_render_GET(self, request):
|
||||
# We're not expecting any GET request on that resource if everything goes right,
|
||||
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
|
||||
# In this case, just tell the user that something went wrong and they should
|
||||
# try to authenticate again.
|
||||
request.setResponseCode(400)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(self._error_html_content),))
|
||||
request.write(self._error_html_content.encode("utf8"))
|
||||
finish_request(request)
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_POST(self, request):
|
||||
return await self._saml_handler.handle_saml_response(request)
|
||||
|
|
|
@ -26,7 +26,6 @@ import logging
|
|||
import os
|
||||
|
||||
from twisted.mail.smtp import sendmail
|
||||
from twisted.web.client import BrowserLikePolicyForHTTPS
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.filtering import Filtering
|
||||
|
@ -35,6 +34,7 @@ from synapse.appservice.api import ApplicationServiceApi
|
|||
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.crypto.context_factory import RegularPolicyForHTTPS
|
||||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.events.spamcheck import SpamChecker
|
||||
|
@ -310,7 +310,7 @@ class HomeServer(object):
|
|||
return (
|
||||
InsecureInterceptableContextFactory()
|
||||
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||
else BrowserLikePolicyForHTTPS()
|
||||
else RegularPolicyForHTTPS()
|
||||
)
|
||||
|
||||
def build_simple_http_client(self):
|
||||
|
@ -420,7 +420,7 @@ class HomeServer(object):
|
|||
return PusherPool(self)
|
||||
|
||||
def build_http_client(self):
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
|
||||
tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
|
||||
self.config
|
||||
)
|
||||
return MatrixFederationHttpClient(self, tls_client_options_factory)
|
||||
|
|
|
@ -662,28 +662,16 @@ class StateResolutionStore(object):
|
|||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
|
||||
"""Gets the full auth chain for a set of events (including rejected
|
||||
events).
|
||||
|
||||
Includes the given event IDs in the result.
|
||||
|
||||
Note that:
|
||||
1. All events must be state events.
|
||||
2. For v1 rooms this may not have the full auth chain in the
|
||||
presence of rejected events
|
||||
|
||||
Args:
|
||||
event_ids: The event IDs of the events to fetch the auth chain for.
|
||||
Must be state events.
|
||||
ignore_events: Set of events to exclude from the returned auth
|
||||
chain.
|
||||
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
|
||||
This equivalent to fetching the full auth chain for each set of state
|
||||
and returning the events that don't appear in each and every auth
|
||||
chain.
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]: List of event IDs of the auth chain.
|
||||
Deferred[Set[str]]: Set of event IDs.
|
||||
"""
|
||||
|
||||
return self.store.get_auth_chain_ids(
|
||||
event_ids, include_given=True, ignore_events=ignore_events,
|
||||
)
|
||||
return self.store.get_auth_chain_difference(state_sets)
|
||||
|
|
|
@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
|||
Returns:
|
||||
Deferred[set[str]]: Set of event IDs
|
||||
"""
|
||||
common = set(itervalues(state_sets[0])).intersection(
|
||||
*(itervalues(s) for s in state_sets[1:])
|
||||
|
||||
difference = yield state_res_store.get_auth_chain_difference(
|
||||
[set(state_set.values()) for state_set in state_sets]
|
||||
)
|
||||
|
||||
auth_sets = []
|
||||
for state_set in state_sets:
|
||||
auth_ids = {
|
||||
eid
|
||||
for key, eid in iteritems(state_set)
|
||||
if (
|
||||
key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
|
||||
or key
|
||||
in (
|
||||
(EventTypes.PowerLevels, ""),
|
||||
(EventTypes.Create, ""),
|
||||
(EventTypes.JoinRules, ""),
|
||||
)
|
||||
)
|
||||
and eid not in common
|
||||
}
|
||||
|
||||
auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
|
||||
auth_ids.update(auth_chain)
|
||||
|
||||
auth_sets.append(auth_ids)
|
||||
|
||||
intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
|
||||
union = set().union(*auth_sets)
|
||||
|
||||
return union - intersection
|
||||
return difference
|
||||
|
||||
|
||||
def _seperate(state_sets):
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import itertools
|
||||
import logging
|
||||
from typing import List, Optional, Set
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from six.moves.queue import Empty, PriorityQueue
|
||||
|
||||
|
@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
|
||||
return list(results)
|
||||
|
||||
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
|
||||
This equivalent to fetching the full auth chain for each set of state
|
||||
and returning the events that don't appear in each and every auth
|
||||
chain.
|
||||
|
||||
Returns:
|
||||
Deferred[Set[str]]
|
||||
"""
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_auth_chain_difference",
|
||||
self._get_auth_chain_difference_txn,
|
||||
state_sets,
|
||||
)
|
||||
|
||||
def _get_auth_chain_difference_txn(
|
||||
self, txn, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
|
||||
# Algorithm Description
|
||||
# ~~~~~~~~~~~~~~~~~~~~~
|
||||
#
|
||||
# The idea here is to basically walk the auth graph of each state set in
|
||||
# tandem, keeping track of which auth events are reachable by each state
|
||||
# set. If we reach an auth event we've already visited (via a different
|
||||
# state set) then we mark that auth event and all ancestors as reachable
|
||||
# by the state set. This requires that we keep track of the auth chains
|
||||
# in memory.
|
||||
#
|
||||
# Doing it in a such a way means that we can stop early if all auth
|
||||
# events we're currently walking are reachable by all state sets.
|
||||
#
|
||||
# *Note*: We can't stop walking an event's auth chain if it is reachable
|
||||
# by all state sets. This is because other auth chains we're walking
|
||||
# might be reachable only via the original auth chain. For example,
|
||||
# given the following auth chain:
|
||||
#
|
||||
# A -> C -> D -> E
|
||||
# / /
|
||||
# B -´---------´
|
||||
#
|
||||
# and state sets {A} and {B} then walking the auth chains of A and B
|
||||
# would immediately show that C is reachable by both. However, if we
|
||||
# stopped at C then we'd only reach E via the auth chain of B and so E
|
||||
# would errornously get included in the returned difference.
|
||||
#
|
||||
# The other thing that we do is limit the number of auth chains we walk
|
||||
# at once, due to practical limits (i.e. we can only query the database
|
||||
# with a limited set of parameters). We pick the auth chains we walk
|
||||
# each iteration based on their depth, in the hope that events with a
|
||||
# lower depth are likely reachable by those with higher depths.
|
||||
#
|
||||
# We could use any ordering that we believe would give a rough
|
||||
# topological ordering, e.g. origin server timestamp. If the ordering
|
||||
# chosen is not topological then the algorithm still produces the right
|
||||
# result, but perhaps a bit more inefficiently. This is why it is safe
|
||||
# to use "depth" here.
|
||||
|
||||
initial_events = set(state_sets[0]).union(*state_sets[1:])
|
||||
|
||||
# Dict from events in auth chains to which sets *cannot* reach them.
|
||||
# I.e. if the set is empty then all sets can reach the event.
|
||||
event_to_missing_sets = {
|
||||
event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
|
||||
for event_id in initial_events
|
||||
}
|
||||
|
||||
# We need to get the depth of the initial events for sorting purposes.
|
||||
sql = """
|
||||
SELECT depth, event_id FROM events
|
||||
WHERE %s
|
||||
ORDER BY depth ASC
|
||||
"""
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "event_id", initial_events
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
# The sorted list of events whose auth chains we should walk.
|
||||
search = txn.fetchall() # type: List[Tuple[int, str]]
|
||||
|
||||
# Map from event to its auth events
|
||||
event_to_auth_events = {} # type: Dict[str, Set[str]]
|
||||
|
||||
base_sql = """
|
||||
SELECT a.event_id, auth_id, depth
|
||||
FROM event_auth AS a
|
||||
INNER JOIN events AS e ON (e.event_id = a.auth_id)
|
||||
WHERE
|
||||
"""
|
||||
|
||||
while search:
|
||||
# Check whether all our current walks are reachable by all state
|
||||
# sets. If so we can bail.
|
||||
if all(not event_to_missing_sets[eid] for _, eid in search):
|
||||
break
|
||||
|
||||
# Fetch the auth events and their depths of the N last events we're
|
||||
# currently walking
|
||||
search, chunk = search[:-100], search[-100:]
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
|
||||
)
|
||||
txn.execute(base_sql + clause, args)
|
||||
|
||||
for event_id, auth_event_id, auth_event_depth in txn:
|
||||
event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
|
||||
|
||||
sets = event_to_missing_sets.get(auth_event_id)
|
||||
if sets is None:
|
||||
# First time we're seeing this event, so we add it to the
|
||||
# queue of things to fetch.
|
||||
search.append((auth_event_depth, auth_event_id))
|
||||
|
||||
# Assume that this event is unreachable from any of the
|
||||
# state sets until proven otherwise
|
||||
sets = event_to_missing_sets[auth_event_id] = set(
|
||||
range(len(state_sets))
|
||||
)
|
||||
else:
|
||||
# We've previously seen this event, so look up its auth
|
||||
# events and recursively mark all ancestors as reachable
|
||||
# by the current event's state set.
|
||||
a_ids = event_to_auth_events.get(auth_event_id)
|
||||
while a_ids:
|
||||
new_aids = set()
|
||||
for a_id in a_ids:
|
||||
event_to_missing_sets[a_id].intersection_update(
|
||||
event_to_missing_sets[event_id]
|
||||
)
|
||||
|
||||
b = event_to_auth_events.get(a_id)
|
||||
if b:
|
||||
new_aids.update(b)
|
||||
|
||||
a_ids = new_aids
|
||||
|
||||
# Mark that the auth event is reachable by the approriate sets.
|
||||
sets.intersection_update(event_to_missing_sets[event_id])
|
||||
|
||||
search.sort()
|
||||
|
||||
# Return all events where not all sets can reach them.
|
||||
return {eid for eid, n in event_to_missing_sets.items() if n}
|
||||
|
||||
def get_oldest_events_in_room(self, room_id):
|
||||
return self.db.runInteraction(
|
||||
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
|
||||
|
|
|
@ -29,7 +29,11 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.database import DatabaseConnectionConfig
|
||||
from synapse.logging.context import LoggingContext, make_deferred_yieldable
|
||||
from synapse.logging.context import (
|
||||
LoggingContext,
|
||||
LoggingContextOrSentinel,
|
||||
make_deferred_yieldable,
|
||||
)
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
|
@ -543,7 +547,9 @@ class Database(object):
|
|||
Returns:
|
||||
Deferred: The result of func
|
||||
"""
|
||||
parent_context = LoggingContext.current_context()
|
||||
parent_context = (
|
||||
LoggingContext.current_context()
|
||||
) # type: Optional[LoggingContextOrSentinel]
|
||||
if parent_context == LoggingContext.sentinel:
|
||||
logger.warning(
|
||||
"Starting db connection from sentinel context: metrics will be lost"
|
||||
|
|
|
@ -118,30 +118,36 @@ def filter_events_for_client(
|
|||
|
||||
the original event if they can see it as normal.
|
||||
"""
|
||||
if event.type == "org.matrix.dummy_event" and filter_send_to_client:
|
||||
return None
|
||||
# Only run some checks if these events aren't about to be sent to clients. This is
|
||||
# because, if this is not the case, we're probably only checking if the users can
|
||||
# see events in the room at that point in the DAG, and that shouldn't be decided
|
||||
# on those checks.
|
||||
if filter_send_to_client:
|
||||
if event.type == "org.matrix.dummy_event":
|
||||
return None
|
||||
|
||||
if not event.is_state() and event.sender in ignore_list and filter_send_to_client:
|
||||
return None
|
||||
if not event.is_state() and event.sender in ignore_list:
|
||||
return None
|
||||
|
||||
# Until MSC2261 has landed we can't redact malicious alias events, so for
|
||||
# now we temporarily filter out m.room.aliases entirely to mitigate
|
||||
# abuse, while we spec a better solution to advertising aliases
|
||||
# on rooms.
|
||||
if event.type == EventTypes.Aliases:
|
||||
return None
|
||||
# Until MSC2261 has landed we can't redact malicious alias events, so for
|
||||
# now we temporarily filter out m.room.aliases entirely to mitigate
|
||||
# abuse, while we spec a better solution to advertising aliases
|
||||
# on rooms.
|
||||
if event.type == EventTypes.Aliases:
|
||||
return None
|
||||
|
||||
# Don't try to apply the room's retention policy if the event is a state event, as
|
||||
# MSC1763 states that retention is only considered for non-state events.
|
||||
if filter_send_to_client and not event.is_state():
|
||||
retention_policy = retention_policies[event.room_id]
|
||||
max_lifetime = retention_policy.get("max_lifetime")
|
||||
# Don't try to apply the room's retention policy if the event is a state
|
||||
# event, as MSC1763 states that retention is only considered for non-state
|
||||
# events.
|
||||
if not event.is_state():
|
||||
retention_policy = retention_policies[event.room_id]
|
||||
max_lifetime = retention_policy.get("max_lifetime")
|
||||
|
||||
if max_lifetime is not None:
|
||||
oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
|
||||
if max_lifetime is not None:
|
||||
oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
|
||||
|
||||
if event.origin_server_ts < oldest_allowed_ts:
|
||||
return None
|
||||
if event.origin_server_ts < oldest_allowed_ts:
|
||||
return None
|
||||
|
||||
if event.event_id in always_include_ids:
|
||||
return event
|
||||
|
|
|
@ -23,7 +23,7 @@ from OpenSSL import SSL
|
|||
|
||||
from synapse.config._base import Config, RootConfig
|
||||
from synapse.config.tls import ConfigError, TlsConfig
|
||||
from synapse.crypto.context_factory import ClientTLSOptionsFactory
|
||||
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
||||
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
@ -180,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
|
|||
t = TestConfig()
|
||||
t.read_config(config, config_dir_path="", data_dir_path="")
|
||||
|
||||
cf = ClientTLSOptionsFactory(t)
|
||||
cf = FederationPolicyForHTTPS(t)
|
||||
options = _get_ssl_context_options(cf._verify_ssl_context)
|
||||
|
||||
# The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
|
||||
self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
|
||||
self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
|
||||
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
|
||||
self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0)
|
||||
self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
|
||||
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
|
||||
|
||||
def test_tls_client_minimum_set_passed_through_1_0(self):
|
||||
"""
|
||||
|
@ -195,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
|
|||
t = TestConfig()
|
||||
t.read_config(config, config_dir_path="", data_dir_path="")
|
||||
|
||||
cf = ClientTLSOptionsFactory(t)
|
||||
cf = FederationPolicyForHTTPS(t)
|
||||
options = _get_ssl_context_options(cf._verify_ssl_context)
|
||||
|
||||
# The context has not had any of the NO_TLS set.
|
||||
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
|
||||
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
|
||||
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
|
||||
self.assertEqual(options & SSL.OP_NO_TLSv1, 0)
|
||||
self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
|
||||
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
|
||||
|
||||
def test_acme_disabled_in_generated_config_no_acme_domain_provied(self):
|
||||
"""
|
||||
|
@ -273,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
|
|||
t = TestConfig()
|
||||
t.read_config(config, config_dir_path="", data_dir_path="")
|
||||
|
||||
cf = ClientTLSOptionsFactory(t)
|
||||
cf = FederationPolicyForHTTPS(t)
|
||||
|
||||
# Not in the whitelist
|
||||
opts = cf.get_options(b"notexample.com")
|
||||
|
@ -282,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
|
|||
# Caught by the wildcard
|
||||
opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
|
||||
self.assertFalse(opts._verifier._verify_certs)
|
||||
|
||||
|
||||
def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
|
||||
"""get the options bits from an openssl context object"""
|
||||
# the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
|
||||
# use the low-level interface
|
||||
return SSL._lib.SSL_CTX_get_options(ssl_context._context)
|
||||
|
|
|
@ -31,7 +31,7 @@ from twisted.web.http_headers import Headers
|
|||
from twisted.web.iweb import IPolicyForHTTPS
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto.context_factory import ClientTLSOptionsFactory
|
||||
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.http.federation.srv_resolver import Server
|
||||
from synapse.http.federation.well_known_resolver import (
|
||||
|
@ -79,7 +79,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||
self._config = config = HomeServerConfig()
|
||||
config.parse_config_dict(config_dict, "", "")
|
||||
|
||||
self.tls_factory = ClientTLSOptionsFactory(config)
|
||||
self.tls_factory = FederationPolicyForHTTPS(config)
|
||||
|
||||
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
|
||||
self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
|
||||
|
@ -715,7 +715,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||
config = default_config("test", parse=True)
|
||||
|
||||
# Build a new agent and WellKnownResolver with a different tls factory
|
||||
tls_factory = ClientTLSOptionsFactory(config)
|
||||
tls_factory = FederationPolicyForHTTPS(config)
|
||||
agent = MatrixFederationAgent(
|
||||
reactor=self.reactor,
|
||||
tls_client_options_factory=tls_factory,
|
||||
|
|
|
@ -868,6 +868,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
||||
# Set this new alias as the canonical alias for this room
|
||||
self.helper.send_state(
|
||||
room_id,
|
||||
"m.room.aliases",
|
||||
{"aliases": [test_alias]},
|
||||
tok=self.admin_user_tok,
|
||||
state_key="test",
|
||||
)
|
||||
self.helper.send_state(
|
||||
room_id,
|
||||
"m.room.canonical_alias",
|
||||
|
|
|
@ -51,30 +51,26 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
|||
self.user = self.register_user("user", "test")
|
||||
self.user_tok = self.login("user", "test")
|
||||
|
||||
def test_cannot_set_alias_via_state_event(self):
|
||||
self.ensure_user_joined_room()
|
||||
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
|
||||
self.room_id,
|
||||
self.hs.hostname,
|
||||
)
|
||||
|
||||
data = {"aliases": [self.random_alias(5)]}
|
||||
request_data = json.dumps(data)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT", url, request_data, access_token=self.user_tok
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
def test_state_event_not_in_room(self):
|
||||
self.ensure_user_left_room()
|
||||
self.set_alias_via_state_event(403)
|
||||
|
||||
def test_directory_endpoint_not_in_room(self):
|
||||
self.ensure_user_left_room()
|
||||
self.set_alias_via_directory(403)
|
||||
|
||||
def test_state_event_in_room_too_long(self):
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_state_event(400, alias_length=256)
|
||||
|
||||
def test_directory_in_room_too_long(self):
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_directory(400, alias_length=256)
|
||||
|
||||
def test_state_event_in_room(self):
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_state_event(200)
|
||||
|
||||
def test_directory_in_room(self):
|
||||
self.ensure_user_joined_room()
|
||||
self.set_alias_via_directory(200)
|
||||
|
@ -106,6 +102,21 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
|||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
def set_alias_via_state_event(self, expected_code, alias_length=5):
|
||||
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
|
||||
self.room_id,
|
||||
self.hs.hostname,
|
||||
)
|
||||
|
||||
data = {"aliases": [self.random_alias(alias_length)]}
|
||||
request_data = json.dumps(data)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT", url, request_data, access_token=self.user_tok
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
|
||||
def set_alias_via_directory(self, expected_code, alias_length=5):
|
||||
url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
|
||||
data = {"room_id": self.room_id}
|
||||
|
|
|
@ -603,7 +603,7 @@ class TestStateResolutionStore(object):
|
|||
|
||||
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
|
||||
|
||||
def get_auth_chain(self, event_ids, ignore_events):
|
||||
def _get_auth_chain(self, event_ids):
|
||||
"""Gets the full auth chain for a set of events (including rejected
|
||||
events).
|
||||
|
||||
|
@ -617,9 +617,6 @@ class TestStateResolutionStore(object):
|
|||
Args:
|
||||
event_ids (list): The event IDs of the events to fetch the auth
|
||||
chain for. Must be state events.
|
||||
ignore_events: Set of events to exclude from the returned auth
|
||||
chain.
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]: List of event IDs of the auth chain.
|
||||
"""
|
||||
|
@ -629,7 +626,7 @@ class TestStateResolutionStore(object):
|
|||
stack = list(event_ids)
|
||||
while stack:
|
||||
event_id = stack.pop()
|
||||
if event_id in result or event_id in ignore_events:
|
||||
if event_id in result:
|
||||
continue
|
||||
|
||||
result.add(event_id)
|
||||
|
@ -639,3 +636,9 @@ class TestStateResolutionStore(object):
|
|||
stack.append(aid)
|
||||
|
||||
return list(result)
|
||||
|
||||
def get_auth_chain_difference(self, auth_sets):
|
||||
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
|
||||
|
||||
common = set(chains[0]).intersection(*chains[1:])
|
||||
return set(chains[0]).union(*chains[1:]) - common
|
||||
|
|
|
@ -13,19 +13,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
|
||||
|
||||
class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
|
||||
class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_prev_events_for_room(self):
|
||||
room_id = "@ROOM:local"
|
||||
|
||||
|
@ -61,15 +56,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
|
|||
)
|
||||
|
||||
for i in range(0, 20):
|
||||
yield self.store.db.runInteraction("insert", insert_event, i)
|
||||
self.get_success(self.store.db.runInteraction("insert", insert_event, i))
|
||||
|
||||
# this should get the last ten
|
||||
r = yield self.store.get_prev_events_for_room(room_id)
|
||||
r = self.get_success(self.store.get_prev_events_for_room(room_id))
|
||||
self.assertEqual(10, len(r))
|
||||
for i in range(0, 10):
|
||||
self.assertEqual("$event_%i:local" % (19 - i), r[i])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_rooms_with_many_extremities(self):
|
||||
room1 = "#room1"
|
||||
room2 = "#room2"
|
||||
|
@ -86,25 +80,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
|
|||
)
|
||||
|
||||
for i in range(0, 20):
|
||||
yield self.store.db.runInteraction("insert", insert_event, i, room1)
|
||||
yield self.store.db.runInteraction("insert", insert_event, i, room2)
|
||||
yield self.store.db.runInteraction("insert", insert_event, i, room3)
|
||||
self.get_success(
|
||||
self.store.db.runInteraction("insert", insert_event, i, room1)
|
||||
)
|
||||
self.get_success(
|
||||
self.store.db.runInteraction("insert", insert_event, i, room2)
|
||||
)
|
||||
self.get_success(
|
||||
self.store.db.runInteraction("insert", insert_event, i, room3)
|
||||
)
|
||||
|
||||
# Test simple case
|
||||
r = yield self.store.get_rooms_with_many_extremities(5, 5, [])
|
||||
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, []))
|
||||
self.assertEqual(len(r), 3)
|
||||
|
||||
# Does filter work?
|
||||
|
||||
r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1])
|
||||
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1]))
|
||||
self.assertTrue(room2 in r)
|
||||
self.assertTrue(room3 in r)
|
||||
self.assertEqual(len(r), 2)
|
||||
|
||||
r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
|
||||
r = self.get_success(
|
||||
self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
|
||||
)
|
||||
self.assertEqual(r, [room3])
|
||||
|
||||
# Does filter and limit work?
|
||||
|
||||
r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1])
|
||||
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
|
||||
self.assertTrue(r == [room2] or r == [room3])
|
||||
|
||||
def test_auth_difference(self):
|
||||
room_id = "@ROOM:local"
|
||||
|
||||
# The silly auth graph we use to test the auth difference algorithm,
|
||||
# where the top are the most recent events.
|
||||
#
|
||||
# A B
|
||||
# \ /
|
||||
# D E
|
||||
# \ |
|
||||
# ` F C
|
||||
# | /|
|
||||
# G ´ |
|
||||
# | \ |
|
||||
# H I
|
||||
# | |
|
||||
# K J
|
||||
|
||||
auth_graph = {
|
||||
"a": ["e"],
|
||||
"b": ["e"],
|
||||
"c": ["g", "i"],
|
||||
"d": ["f"],
|
||||
"e": ["f"],
|
||||
"f": ["g"],
|
||||
"g": ["h", "i"],
|
||||
"h": ["k"],
|
||||
"i": ["j"],
|
||||
"k": [],
|
||||
"j": [],
|
||||
}
|
||||
|
||||
depth_map = {
|
||||
"a": 7,
|
||||
"b": 7,
|
||||
"c": 4,
|
||||
"d": 6,
|
||||
"e": 6,
|
||||
"f": 5,
|
||||
"g": 3,
|
||||
"h": 2,
|
||||
"i": 2,
|
||||
"k": 1,
|
||||
"j": 1,
|
||||
}
|
||||
|
||||
# We rudely fiddle with the appropriate tables directly, as that's much
|
||||
# easier than constructing events properly.
|
||||
|
||||
def insert_event(txn, event_id, stream_ordering):
|
||||
|
||||
depth = depth_map[event_id]
|
||||
|
||||
self.store.db.simple_insert_txn(
|
||||
txn,
|
||||
table="events",
|
||||
values={
|
||||
"event_id": event_id,
|
||||
"room_id": room_id,
|
||||
"depth": depth,
|
||||
"topological_ordering": depth,
|
||||
"type": "m.test",
|
||||
"processed": True,
|
||||
"outlier": False,
|
||||
"stream_ordering": stream_ordering,
|
||||
},
|
||||
)
|
||||
|
||||
self.store.db.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
values=[
|
||||
{"event_id": event_id, "room_id": room_id, "auth_id": a}
|
||||
for a in auth_graph[event_id]
|
||||
],
|
||||
)
|
||||
|
||||
next_stream_ordering = 0
|
||||
for event_id in auth_graph:
|
||||
next_stream_ordering += 1
|
||||
self.get_success(
|
||||
self.store.db.runInteraction(
|
||||
"insert", insert_event, event_id, next_stream_ordering
|
||||
)
|
||||
)
|
||||
|
||||
# Now actually test that various combinations give the right result:
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b", "c"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b", "d", "e"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
|
||||
|
||||
difference = self.get_success(
|
||||
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
|
||||
)
|
||||
self.assertSetEqual(difference, {"a", "b"})
|
||||
|
||||
difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
|
||||
self.assertSetEqual(difference, set())
|
||||
|
|
Loading…
Reference in New Issue