Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
commit
93a0751302
|
@ -0,0 +1 @@
|
|||
Refactor _EventInternalMetadata object to improve type safety.
|
|
@ -0,0 +1 @@
|
|||
Update Synapse's documentation to warn about the deprecation of ACME v1.
|
|
@ -0,0 +1 @@
|
|||
Increase DB/CPU perf of `_is_server_still_joined` check.
|
|
@ -0,0 +1 @@
|
|||
Implement `GET /_matrix/client/r0/rooms/{roomId}/aliases` endpoint as per [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432).
|
|
@ -0,0 +1 @@
|
|||
Fix errors from logging in the purge jobs related to the message retention policies support.
|
|
@ -0,0 +1 @@
|
|||
Increase perf of `get_auth_chain_ids` used in state res v2.
|
|
@ -476,6 +476,11 @@ retention:
|
|||
# ACME support: This will configure Synapse to request a valid TLS certificate
|
||||
# for your configured `server_name` via Let's Encrypt.
|
||||
#
|
||||
# Note that ACME v1 is now deprecated, and Synapse currently doesn't support
|
||||
# ACME v2. This means that this feature currently won't work with installs set
|
||||
# up after November 2019. For more info, and alternative solutions, see
|
||||
# https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
|
||||
#
|
||||
# Note that provisioning a certificate in this way requires port 80 to be
|
||||
# routed to Synapse so that it can complete the http-01 ACME challenge.
|
||||
# By default, if you enable ACME support, Synapse will attempt to listen on
|
||||
|
|
|
@ -32,6 +32,17 @@ from synapse.util import glob_to_regex
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACME_SUPPORT_ENABLED_WARN = """\
|
||||
This server uses Synapse's built-in ACME support. Note that ACME v1 has been
|
||||
deprecated by Let's Encrypt, and that Synapse doesn't currently support ACME v2,
|
||||
which means that this feature will not work with Synapse installs set up after
|
||||
November 2019, and that it may stop working on June 2020 for installs set up
|
||||
before that date.
|
||||
|
||||
For more info and alternative solutions, see
|
||||
https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
|
||||
--------------------------------------------------------------------------------"""
|
||||
|
||||
|
||||
class TlsConfig(Config):
|
||||
section = "tls"
|
||||
|
@ -44,6 +55,9 @@ class TlsConfig(Config):
|
|||
|
||||
self.acme_enabled = acme_config.get("enabled", False)
|
||||
|
||||
if self.acme_enabled:
|
||||
logger.warning(ACME_SUPPORT_ENABLED_WARN)
|
||||
|
||||
# hyperlink complains on py2 if this is not a Unicode
|
||||
self.acme_url = six.text_type(
|
||||
acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory")
|
||||
|
@ -362,6 +376,11 @@ class TlsConfig(Config):
|
|||
# ACME support: This will configure Synapse to request a valid TLS certificate
|
||||
# for your configured `server_name` via Let's Encrypt.
|
||||
#
|
||||
# Note that ACME v1 is now deprecated, and Synapse currently doesn't support
|
||||
# ACME v2. This means that this feature currently won't work with installs set
|
||||
# up after November 2019. For more info, and alternative solutions, see
|
||||
# https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
|
||||
#
|
||||
# Note that provisioning a certificate in this way requires port 80 to be
|
||||
# routed to Synapse so that it can complete the http-01 ACME challenge.
|
||||
# By default, if you enable ACME support, Synapse will attempt to listen on
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -37,34 +38,115 @@ from synapse.util.frozenutils import freeze
|
|||
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
|
||||
|
||||
|
||||
class DictProperty:
|
||||
"""An object property which delegates to the `_dict` within its parent object."""
|
||||
|
||||
__slots__ = ["key"]
|
||||
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
def __get__(self, instance, owner=None):
|
||||
# if the property is accessed as a class property rather than an instance
|
||||
# property, return the property itself rather than the value
|
||||
if instance is None:
|
||||
return self
|
||||
try:
|
||||
return instance._dict[self.key]
|
||||
except KeyError as e1:
|
||||
# We want this to look like a regular attribute error (mostly so that
|
||||
# hasattr() works correctly), so we convert the KeyError into an
|
||||
# AttributeError.
|
||||
#
|
||||
# To exclude the KeyError from the traceback, we explicitly
|
||||
# 'raise from e1.__context__' (which is better than 'raise from None',
|
||||
# becuase that would omit any *earlier* exceptions).
|
||||
#
|
||||
raise AttributeError(
|
||||
"'%s' has no '%s' property" % (type(instance), self.key)
|
||||
) from e1.__context__
|
||||
|
||||
def __set__(self, instance, v):
|
||||
instance._dict[self.key] = v
|
||||
|
||||
def __delete__(self, instance):
|
||||
try:
|
||||
del instance._dict[self.key]
|
||||
except KeyError as e1:
|
||||
raise AttributeError(
|
||||
"'%s' has no '%s' property" % (type(instance), self.key)
|
||||
) from e1.__context__
|
||||
|
||||
|
||||
class DefaultDictProperty(DictProperty):
|
||||
"""An extension of DictProperty which provides a default if the property is
|
||||
not present in the parent's _dict.
|
||||
|
||||
Note that this means that hasattr() on the property always returns True.
|
||||
"""
|
||||
|
||||
__slots__ = ["default"]
|
||||
|
||||
def __init__(self, key, default):
|
||||
super().__init__(key)
|
||||
self.default = default
|
||||
|
||||
def __get__(self, instance, owner=None):
|
||||
if instance is None:
|
||||
return self
|
||||
return instance._dict.get(self.key, self.default)
|
||||
|
||||
|
||||
class _EventInternalMetadata(object):
|
||||
def __init__(self, internal_metadata_dict):
|
||||
self.__dict__ = dict(internal_metadata_dict)
|
||||
__slots__ = ["_dict"]
|
||||
|
||||
def get_dict(self):
|
||||
return dict(self.__dict__)
|
||||
def __init__(self, internal_metadata_dict: JsonDict):
|
||||
# we have to copy the dict, because it turns out that the same dict is
|
||||
# reused. TODO: fix that
|
||||
self._dict = dict(internal_metadata_dict)
|
||||
|
||||
def is_outlier(self):
|
||||
return getattr(self, "outlier", False)
|
||||
outlier = DictProperty("outlier") # type: bool
|
||||
out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
|
||||
send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str
|
||||
recheck_redaction = DictProperty("recheck_redaction") # type: bool
|
||||
soft_failed = DictProperty("soft_failed") # type: bool
|
||||
proactively_send = DictProperty("proactively_send") # type: bool
|
||||
redacted = DictProperty("redacted") # type: bool
|
||||
txn_id = DictProperty("txn_id") # type: str
|
||||
token_id = DictProperty("token_id") # type: str
|
||||
stream_ordering = DictProperty("stream_ordering") # type: int
|
||||
|
||||
def is_out_of_band_membership(self):
|
||||
# XXX: These are set by StreamWorkerStore._set_before_and_after.
|
||||
# I'm pretty sure that these are never persisted to the database, so shouldn't
|
||||
# be here
|
||||
before = DictProperty("before") # type: str
|
||||
after = DictProperty("after") # type: str
|
||||
order = DictProperty("order") # type: int
|
||||
|
||||
def get_dict(self) -> JsonDict:
|
||||
return dict(self._dict)
|
||||
|
||||
def is_outlier(self) -> bool:
|
||||
return self._dict.get("outlier", False)
|
||||
|
||||
def is_out_of_band_membership(self) -> bool:
|
||||
"""Whether this is an out of band membership, like an invite or an invite
|
||||
rejection. This is needed as those events are marked as outliers, but
|
||||
they still need to be processed as if they're new events (e.g. updating
|
||||
invite state in the database, relaying to clients, etc).
|
||||
"""
|
||||
return getattr(self, "out_of_band_membership", False)
|
||||
return self._dict.get("out_of_band_membership", False)
|
||||
|
||||
def get_send_on_behalf_of(self):
|
||||
def get_send_on_behalf_of(self) -> Optional[str]:
|
||||
"""Whether this server should send the event on behalf of another server.
|
||||
This is used by the federation "send_join" API to forward the initial join
|
||||
event for a server in the room.
|
||||
|
||||
returns a str with the name of the server this event is sent on behalf of.
|
||||
"""
|
||||
return getattr(self, "send_on_behalf_of", None)
|
||||
return self._dict.get("send_on_behalf_of")
|
||||
|
||||
def need_to_check_redaction(self):
|
||||
def need_to_check_redaction(self) -> bool:
|
||||
"""Whether the redaction event needs to be rechecked when fetching
|
||||
from the database.
|
||||
|
||||
|
@ -77,9 +159,9 @@ class _EventInternalMetadata(object):
|
|||
Returns:
|
||||
bool
|
||||
"""
|
||||
return getattr(self, "recheck_redaction", False)
|
||||
return self._dict.get("recheck_redaction", False)
|
||||
|
||||
def is_soft_failed(self):
|
||||
def is_soft_failed(self) -> bool:
|
||||
"""Whether the event has been soft failed.
|
||||
|
||||
Soft failed events should be handled as usual, except:
|
||||
|
@ -91,7 +173,7 @@ class _EventInternalMetadata(object):
|
|||
Returns:
|
||||
bool
|
||||
"""
|
||||
return getattr(self, "soft_failed", False)
|
||||
return self._dict.get("soft_failed", False)
|
||||
|
||||
def should_proactively_send(self):
|
||||
"""Whether the event, if ours, should be sent to other clients and
|
||||
|
@ -103,7 +185,7 @@ class _EventInternalMetadata(object):
|
|||
Returns:
|
||||
bool
|
||||
"""
|
||||
return getattr(self, "proactively_send", True)
|
||||
return self._dict.get("proactively_send", True)
|
||||
|
||||
def is_redacted(self):
|
||||
"""Whether the event has been redacted.
|
||||
|
@ -114,52 +196,7 @@ class _EventInternalMetadata(object):
|
|||
Returns:
|
||||
bool
|
||||
"""
|
||||
return getattr(self, "redacted", False)
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def _event_dict_property(key, default=_SENTINEL):
|
||||
"""Creates a new property for the given key that delegates access to
|
||||
`self._event_dict`.
|
||||
|
||||
The default is used if the key is missing from the `_event_dict`, if given,
|
||||
otherwise an AttributeError will be raised.
|
||||
|
||||
Note: If a default is given then `hasattr` will always return true.
|
||||
"""
|
||||
|
||||
# We want to be able to use hasattr with the event dict properties.
|
||||
# However, (on python3) hasattr expects AttributeError to be raised. Hence,
|
||||
# we need to transform the KeyError into an AttributeError
|
||||
|
||||
def getter_raises(self):
|
||||
try:
|
||||
return self._event_dict[key]
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
def getter_default(self):
|
||||
return self._event_dict.get(key, default)
|
||||
|
||||
def setter(self, v):
|
||||
try:
|
||||
self._event_dict[key] = v
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
def delete(self):
|
||||
try:
|
||||
del self._event_dict[key]
|
||||
except KeyError:
|
||||
raise AttributeError(key)
|
||||
|
||||
if default is _SENTINEL:
|
||||
# No default given, so use the getter that raises
|
||||
return property(getter_raises, setter, delete)
|
||||
else:
|
||||
return property(getter_default, setter, delete)
|
||||
return self._dict.get("redacted", False)
|
||||
|
||||
|
||||
class EventBase(object):
|
||||
|
@ -175,23 +212,23 @@ class EventBase(object):
|
|||
self.unsigned = unsigned
|
||||
self.rejected_reason = rejected_reason
|
||||
|
||||
self._event_dict = event_dict
|
||||
self._dict = event_dict
|
||||
|
||||
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
|
||||
|
||||
auth_events = _event_dict_property("auth_events")
|
||||
depth = _event_dict_property("depth")
|
||||
content = _event_dict_property("content")
|
||||
hashes = _event_dict_property("hashes")
|
||||
origin = _event_dict_property("origin")
|
||||
origin_server_ts = _event_dict_property("origin_server_ts")
|
||||
prev_events = _event_dict_property("prev_events")
|
||||
redacts = _event_dict_property("redacts", None)
|
||||
room_id = _event_dict_property("room_id")
|
||||
sender = _event_dict_property("sender")
|
||||
state_key = _event_dict_property("state_key")
|
||||
type = _event_dict_property("type")
|
||||
user_id = _event_dict_property("sender")
|
||||
auth_events = DictProperty("auth_events")
|
||||
depth = DictProperty("depth")
|
||||
content = DictProperty("content")
|
||||
hashes = DictProperty("hashes")
|
||||
origin = DictProperty("origin")
|
||||
origin_server_ts = DictProperty("origin_server_ts")
|
||||
prev_events = DictProperty("prev_events")
|
||||
redacts = DefaultDictProperty("redacts", None)
|
||||
room_id = DictProperty("room_id")
|
||||
sender = DictProperty("sender")
|
||||
state_key = DictProperty("state_key")
|
||||
type = DictProperty("type")
|
||||
user_id = DictProperty("sender")
|
||||
|
||||
@property
|
||||
def event_id(self) -> str:
|
||||
|
@ -205,13 +242,13 @@ class EventBase(object):
|
|||
return hasattr(self, "state_key") and self.state_key is not None
|
||||
|
||||
def get_dict(self) -> JsonDict:
|
||||
d = dict(self._event_dict)
|
||||
d = dict(self._dict)
|
||||
d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
|
||||
|
||||
return d
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._event_dict.get(key, default)
|
||||
return self._dict.get(key, default)
|
||||
|
||||
def get_internal_metadata_dict(self):
|
||||
return self.internal_metadata.get_dict()
|
||||
|
@ -233,16 +270,16 @@ class EventBase(object):
|
|||
raise AttributeError("Unrecognized attribute %s" % (instance,))
|
||||
|
||||
def __getitem__(self, field):
|
||||
return self._event_dict[field]
|
||||
return self._dict[field]
|
||||
|
||||
def __contains__(self, field):
|
||||
return field in self._event_dict
|
||||
return field in self._dict
|
||||
|
||||
def items(self):
|
||||
return list(self._event_dict.items())
|
||||
return list(self._dict.items())
|
||||
|
||||
def keys(self):
|
||||
return six.iterkeys(self._event_dict)
|
||||
return six.iterkeys(self._dict)
|
||||
|
||||
def prev_event_ids(self):
|
||||
"""Returns the list of prev event IDs. The order matches the order
|
||||
|
|
|
@ -25,6 +25,15 @@ from synapse.app import check_bind_error
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACME_REGISTER_FAIL_ERROR = """
|
||||
--------------------------------------------------------------------------------
|
||||
Failed to register with the ACME provider. This is likely happening because the install
|
||||
is new, and ACME v1 has been deprecated by Let's Encrypt and is disabled for installs set
|
||||
up after November 2019.
|
||||
At the moment, Synapse doesn't support ACME v2. For more info and alternative solution,
|
||||
check out https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
|
||||
--------------------------------------------------------------------------------"""
|
||||
|
||||
|
||||
class AcmeHandler(object):
|
||||
def __init__(self, hs):
|
||||
|
@ -71,7 +80,12 @@ class AcmeHandler(object):
|
|||
# want it to control where we save the certificates, we have to reach in
|
||||
# and trigger the registration machinery ourselves.
|
||||
self._issuer._registered = False
|
||||
yield self._issuer._ensure_registered()
|
||||
|
||||
try:
|
||||
yield self._issuer._ensure_registered()
|
||||
except Exception:
|
||||
logger.error(ACME_REGISTER_FAIL_ERROR)
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def provision_certificate(self):
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
import logging
|
||||
import string
|
||||
from typing import List
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -28,7 +29,7 @@ from synapse.api.errors import (
|
|||
StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.types import RoomAlias, UserID, get_domain_from_id
|
||||
from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -452,3 +453,17 @@ class DirectoryHandler(BaseHandler):
|
|||
yield self.store.set_room_is_public_appservice(
|
||||
room_id, appservice_id, network_id, visibility == "public"
|
||||
)
|
||||
|
||||
async def get_aliases_for_room(
|
||||
self, requester: Requester, room_id: str
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get a list of the aliases that currently point to this room on this server
|
||||
"""
|
||||
# allow access to server admins and current members of the room
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
if not is_admin:
|
||||
await self.auth.check_joined_room(room_id, requester.user.to_string())
|
||||
|
||||
aliases = await self.store.get_aliases_for_room(room_id)
|
||||
return aliases
|
||||
|
|
|
@ -133,7 +133,7 @@ class PaginationHandler(object):
|
|||
include_null = False
|
||||
|
||||
logger.info(
|
||||
"[purge] Running purge job for %d < max_lifetime <= %d (include NULLs = %s)",
|
||||
"[purge] Running purge job for %s < max_lifetime <= %s (include NULLs = %s)",
|
||||
min_ms,
|
||||
max_ms,
|
||||
include_null,
|
||||
|
|
|
@ -45,6 +45,10 @@ from synapse.storage.state import StateFilter
|
|||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
import synapse.server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -843,6 +847,24 @@ class RoomTypingRestServlet(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
class RoomAliasListServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/aliases", unstable=False)
|
||||
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.directory_handler = hs.get_handlers().directory_handler
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
alias_list = await self.directory_handler.get_aliases_for_room(
|
||||
requester, room_id
|
||||
)
|
||||
|
||||
return 200, {"aliases": alias_list}
|
||||
|
||||
|
||||
class SearchRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/search$", v1=True)
|
||||
|
||||
|
@ -931,6 +953,7 @@ def register_servlets(hs, http_server):
|
|||
JoinedRoomsRestServlet(hs).register(http_server)
|
||||
RoomEventServlet(hs).register(http_server)
|
||||
RoomEventContextServlet(hs).register(http_server)
|
||||
RoomAliasListServlet(hs).register(http_server)
|
||||
|
||||
|
||||
def register_deprecated_servlets(hs, http_server):
|
||||
|
|
|
@ -62,32 +62,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
|
||||
if include_given:
|
||||
results = set(event_ids)
|
||||
else:
|
||||
results = set()
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# For efficiency we make the database do this if we can.
|
||||
sql = """
|
||||
WITH RECURSIVE auth_chain(event_id) AS (
|
||||
SELECT auth_id FROM event_auth WHERE event_id = ANY(?)
|
||||
UNION
|
||||
SELECT auth_id FROM event_auth
|
||||
INNER JOIN auth_chain USING (event_id)
|
||||
)
|
||||
SELECT event_id FROM auth_chain
|
||||
"""
|
||||
txn.execute(sql, (list(event_ids),))
|
||||
|
||||
results = set(event_id for event_id, in txn)
|
||||
# We need to be a little careful with querying large amounts at
|
||||
# once, for some reason postgres really doesn't like it. We do this
|
||||
# by only asking for auth chain of 500 events at a time.
|
||||
event_ids = list(event_ids)
|
||||
chunks = [event_ids[x : x + 500] for x in range(0, len(event_ids), 500)]
|
||||
for chunk in chunks:
|
||||
sql = """
|
||||
WITH RECURSIVE auth_chain(event_id) AS (
|
||||
SELECT auth_id FROM event_auth WHERE event_id = ANY(?)
|
||||
UNION
|
||||
SELECT auth_id FROM event_auth
|
||||
INNER JOIN auth_chain USING (event_id)
|
||||
)
|
||||
SELECT event_id FROM auth_chain
|
||||
"""
|
||||
txn.execute(sql, (chunk,))
|
||||
|
||||
if include_given:
|
||||
results.update(event_ids)
|
||||
results.update(event_id for event_id, in txn)
|
||||
|
||||
return list(results)
|
||||
|
||||
# Database doesn't necessarily support recursive CTE, so we fall
|
||||
# back to do doing it manually.
|
||||
if include_given:
|
||||
results = set(event_ids)
|
||||
else:
|
||||
results = set()
|
||||
|
||||
base_sql = "SELECT auth_id FROM event_auth WHERE "
|
||||
|
||||
|
|
|
@ -868,6 +868,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
desc="get_membership_from_event_ids",
|
||||
)
|
||||
|
||||
async def is_local_host_in_room_ignoring_users(
|
||||
self, room_id: str, ignore_users: Collection[str]
|
||||
) -> bool:
|
||||
"""Check if there are any local users, excluding those in the given
|
||||
list, in the room.
|
||||
"""
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "user_id", ignore_users
|
||||
)
|
||||
|
||||
sql = """
|
||||
SELECT 1 FROM local_current_membership
|
||||
WHERE
|
||||
room_id = ? AND membership = ?
|
||||
AND NOT (%s)
|
||||
LIMIT 1
|
||||
""" % (
|
||||
clause,
|
||||
)
|
||||
|
||||
def _is_local_host_in_room_ignoring_users_txn(txn):
|
||||
txn.execute(sql, (room_id, Membership.JOIN, *args))
|
||||
|
||||
return bool(txn.fetchone())
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"is_local_host_in_room_ignoring_users",
|
||||
_is_local_host_in_room_ignoring_users_txn,
|
||||
)
|
||||
|
||||
|
||||
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
|
|
|
@ -727,6 +727,7 @@ class EventsPersistenceStorage(object):
|
|||
|
||||
# Check if any of the given events are a local join that appear in the
|
||||
# current state
|
||||
events_to_check = [] # Event IDs that aren't an event we're persisting
|
||||
for (typ, state_key), event_id in delta.to_insert.items():
|
||||
if typ != EventTypes.Member or not self.is_mine_id(state_key):
|
||||
continue
|
||||
|
@ -736,8 +737,33 @@ class EventsPersistenceStorage(object):
|
|||
if event.membership == Membership.JOIN:
|
||||
return True
|
||||
|
||||
# There's been a change of membership but we don't have a local join
|
||||
# event in the new events, so we need to check the full state.
|
||||
# The event is not in `ev_ctx_rm`, so we need to pull it out of
|
||||
# the DB.
|
||||
events_to_check.append(event_id)
|
||||
|
||||
# Check if any of the changes that we don't have events for are joins.
|
||||
if events_to_check:
|
||||
rows = await self.main_store.get_membership_from_event_ids(events_to_check)
|
||||
is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
|
||||
if is_still_joined:
|
||||
return True
|
||||
|
||||
# None of the new state events are local joins, so we check the database
|
||||
# to see if there are any other local users in the room. We ignore users
|
||||
# whose state has changed as we've already their new state above.
|
||||
users_to_ignore = [
|
||||
state_key
|
||||
for _, state_key in itertools.chain(delta.to_insert, delta.to_delete)
|
||||
if self.is_mine_id(state_key)
|
||||
]
|
||||
|
||||
if await self.main_store.is_local_host_in_room_ignoring_users(
|
||||
room_id, users_to_ignore
|
||||
):
|
||||
return True
|
||||
|
||||
# The server will leave the room, so we go and find out which remote
|
||||
# users will still be joined when we leave.
|
||||
if current_state is None:
|
||||
current_state = await self.main_store.get_current_state_ids(room_id)
|
||||
current_state = dict(current_state)
|
||||
|
@ -746,19 +772,6 @@ class EventsPersistenceStorage(object):
|
|||
|
||||
current_state.update(delta.to_insert)
|
||||
|
||||
event_ids = [
|
||||
event_id
|
||||
for (typ, state_key,), event_id in current_state.items()
|
||||
if typ == EventTypes.Member and self.is_mine_id(state_key)
|
||||
]
|
||||
|
||||
rows = await self.main_store.get_membership_from_event_ids(event_ids)
|
||||
is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
|
||||
if is_still_joined:
|
||||
return True
|
||||
|
||||
# The server will leave the room, so we go and find out which remote
|
||||
# users will still be joined when we leave.
|
||||
remote_event_ids = [
|
||||
event_id
|
||||
for (typ, state_key,), event_id in current_state.items()
|
||||
|
|
|
@ -28,8 +28,9 @@ from twisted.internet import defer
|
|||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.handlers.pagination import PurgeStatus
|
||||
from synapse.rest.client.v1 import login, profile, room
|
||||
from synapse.rest.client.v1 import directory, login, profile, room
|
||||
from synapse.rest.client.v2_alpha import account
|
||||
from synapse.types import JsonDict, RoomAlias
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from tests import unittest
|
||||
|
@ -1726,3 +1727,70 @@ class ContextTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(len(events_after), 2, events_after)
|
||||
self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
|
||||
self.assertEqual(events_after[1].get("content"), {}, events_after[1])
|
||||
|
||||
|
||||
class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
directory.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.room_owner = self.register_user("room_owner", "test")
|
||||
self.room_owner_tok = self.login("room_owner", "test")
|
||||
|
||||
self.room_id = self.helper.create_room_as(
|
||||
self.room_owner, tok=self.room_owner_tok
|
||||
)
|
||||
|
||||
def test_no_aliases(self):
|
||||
res = self._get_aliases(self.room_owner_tok)
|
||||
self.assertEqual(res["aliases"], [])
|
||||
|
||||
def test_not_in_room(self):
|
||||
self.register_user("user", "test")
|
||||
user_tok = self.login("user", "test")
|
||||
res = self._get_aliases(user_tok, expected_code=403)
|
||||
self.assertEqual(res["errcode"], "M_FORBIDDEN")
|
||||
|
||||
def test_with_aliases(self):
|
||||
alias1 = self._random_alias()
|
||||
alias2 = self._random_alias()
|
||||
|
||||
self._set_alias_via_directory(alias1)
|
||||
self._set_alias_via_directory(alias2)
|
||||
|
||||
res = self._get_aliases(self.room_owner_tok)
|
||||
self.assertEqual(set(res["aliases"]), {alias1, alias2})
|
||||
|
||||
def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
|
||||
"""Calls the endpoint under test. returns the json response object."""
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,),
|
||||
access_token=access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
res = channel.json_body
|
||||
self.assertIsInstance(res, dict)
|
||||
if expected_code == 200:
|
||||
self.assertIsInstance(res["aliases"], list)
|
||||
return res
|
||||
|
||||
def _random_alias(self) -> str:
|
||||
return RoomAlias(random_string(5), self.hs.hostname).to_string()
|
||||
|
||||
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
|
||||
url = "/_matrix/client/r0/directory/room/" + alias
|
||||
data = {"room_id": self.room_id}
|
||||
request_data = json.dumps(data)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT", url, request_data, access_token=self.room_owner_tok
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
|
|
|
@ -240,7 +240,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
|||
built_event = yield self._base_builder.build(prev_event_ids)
|
||||
|
||||
built_event._event_id = self._event_id
|
||||
built_event._event_dict["event_id"] = self._event_id
|
||||
built_event._dict["event_id"] = self._event_id
|
||||
assert built_event.event_id == self._event_id
|
||||
|
||||
return built_event
|
||||
|
|
|
@ -21,6 +21,7 @@ import hmac
|
|||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
from mock import Mock
|
||||
|
||||
|
@ -42,7 +43,13 @@ from synapse.server import HomeServer
|
|||
from synapse.types import Requester, UserID, create_requester
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
|
||||
from tests.server import get_clock, make_request, render, setup_test_homeserver
|
||||
from tests.server import (
|
||||
FakeChannel,
|
||||
get_clock,
|
||||
make_request,
|
||||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
from tests.test_utils.logging_setup import setup_logging
|
||||
from tests.utils import default_config, setupdb
|
||||
|
||||
|
@ -71,6 +78,9 @@ def around(target):
|
|||
return _around
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
|
||||
attributes on both itself and its individual test methods, to override the
|
||||
|
@ -334,14 +344,14 @@ class HomeserverTestCase(TestCase):
|
|||
|
||||
def make_request(
|
||||
self,
|
||||
method,
|
||||
path,
|
||||
content=b"",
|
||||
access_token=None,
|
||||
request=SynapseRequest,
|
||||
shorthand=True,
|
||||
federation_auth_origin=None,
|
||||
):
|
||||
method: Union[bytes, str],
|
||||
path: Union[bytes, str],
|
||||
content: Union[bytes, dict] = b"",
|
||||
access_token: Optional[str] = None,
|
||||
request: Type[T] = SynapseRequest,
|
||||
shorthand: bool = True,
|
||||
federation_auth_origin: str = None,
|
||||
) -> Tuple[T, FakeChannel]:
|
||||
"""
|
||||
Create a SynapseRequest at the path using the method and containing the
|
||||
given content.
|
||||
|
|
Loading…
Reference in New Issue