From 80828eda06f8e3d6a930c9fa45204ad6fef1d411 Mon Sep 17 00:00:00 2001 From: David Teller Date: Wed, 22 Sep 2021 15:09:43 +0200 Subject: [PATCH 01/31] =?UTF-8?q?Extend=20ModuleApi=20with=20the=20methods?= =?UTF-8?q?=20we'll=20need=20to=20reject=20spam=20based=20on=20=E2=80=A6IP?= =?UTF-8?q?=20-=20resolves=20#10832=20(#10833)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend ModuleApi with the methods we'll need to reject spam based on IP - resolves #10832 Signed-off-by: David Teller --- changelog.d/10833.misc | 1 + synapse/module_api/__init__.py | 82 +++++++++++++++++++- synapse/storage/databases/main/client_ips.py | 27 +++++-- tests/module_api/test_api.py | 72 +++++++++++++++++ 4 files changed, 174 insertions(+), 8 deletions(-) create mode 100644 changelog.d/10833.misc diff --git a/changelog.d/10833.misc b/changelog.d/10833.misc new file mode 100644 index 0000000000..f23c0a1a02 --- /dev/null +++ b/changelog.d/10833.misc @@ -0,0 +1 @@ +Extend the ModuleApi to let plug-ins check whether an ID is local and to access IP + User Agent data. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 3196c2bec6..174e6934a8 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -24,8 +24,10 @@ from typing import ( List, Optional, Tuple, + Union, ) +import attr import jinja2 from twisted.internet import defer @@ -46,7 +48,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester +from synapse.types import ( + DomainSpecificString, + JsonDict, + Requester, + UserID, + UserInfo, + create_requester, +) from synapse.util import Clock from synapse.util.caches.descriptors import cached @@ -79,6 +88,18 @@ __all__ = [ logger = logging.getLogger(__name__) +@attr.s(auto_attribs=True) +class UserIpAndAgent: + """ + An IP address and user agent used by a user to connect to this homeserver. + """ + + ip: str + user_agent: str + # The time at which this user agent/ip was last seen. + last_seen: int + + class ModuleApi: """A proxy object that gets passed to various plugin modules so they can register new users etc if necessary. @@ -700,6 +721,65 @@ class ModuleApi: (td for td in (self.custom_template_dir, custom_template_directory) if td), ) + def is_mine(self, id: Union[str, DomainSpecificString]) -> bool: + """ + Checks whether an ID (user id, room, ...) comes from this homeserver. + + Args: + id: any Matrix id (e.g. user id, room id, ...), either as a raw id, + e.g. string "@user:example.com" or as a parsed UserID, RoomID, ... + Returns: + True if id comes from this homeserver, False otherwise. + + Added in Synapse v1.44.0. + """ + if isinstance(id, DomainSpecificString): + return self._hs.is_mine(id) + else: + return self._hs.is_mine_id(id) + + async def get_user_ip_and_agents( + self, user_id: str, since_ts: int = 0 + ) -> List[UserIpAndAgent]: + """ + Return the list of user IPs and agents for a user. + + Args: + user_id: the id of a user, local or remote + since_ts: a timestamp in seconds since the epoch, + or the epoch itself if not specified. + Returns: + The list of all UserIpAndAgent that the user has + used to connect to this homeserver since `since_ts`. + If the user is remote, this list is empty. 
+ + Added in Synapse v1.44.0. + """ + # Don't hit the db if this is not a local user. + is_mine = False + try: + # Let's be defensive against ill-formed strings. + if self.is_mine(user_id): + is_mine = True + except Exception: + pass + + if is_mine: + raw_data = await self._store.get_user_ip_and_agents( + UserID.from_string(user_id), since_ts + ) + # Sanitize some of the data. We don't want to return tokens. + return [ + UserIpAndAgent( + ip=str(data["ip"]), + user_agent=str(data["user_agent"]), + last_seen=int(data["last_seen"]), + ) + for data in raw_data + ] + else: + return [] + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 7a98275d92..7e33ae578c 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -555,8 +555,11 @@ class ClientIpStore(ClientIpWorkerStore): return ret async def get_user_ip_and_agents( - self, user: UserID + self, user: UserID, since_ts: int = 0 ) -> List[Dict[str, Union[str, int]]]: + """ + Fetch IP/User Agent connection since a given timestamp. + """ user_id = user.to_string() results = {} @@ -568,13 +571,23 @@ class ClientIpStore(ClientIpWorkerStore): ) = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] - results[(access_token, ip)] = (user_agent, last_seen) + if last_seen >= since_ts: + results[(access_token, ip)] = (user_agent, last_seen) - rows = await self.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "last_seen"], - desc="get_user_ip_and_agents", + def get_recent(txn): + txn.execute( + """ + SELECT access_token, ip, user_agent, last_seen FROM user_ips + WHERE last_seen >= ? AND user_id = ? + ORDER BY last_seen + DESC + """, + (since_ts, user_id), + ) + return txn.fetchall() + + rows = await self.db_pool.runInteraction( + desc="get_user_ip_and_agents", func=get_recent ) results.update( diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 7dd519cd44..9d38974fba 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -43,6 +43,7 @@ class ModuleApiTestCase(HomeserverTestCase): self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() self.sync_handler = homeserver.get_sync_handler() + self.auth_handler = homeserver.get_auth_handler() def make_homeserver(self, reactor, clock): return self.setup_test_homeserver( @@ -89,6 +90,77 @@ class ModuleApiTestCase(HomeserverTestCase): found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) self.assertIsNone(found_user) + def test_get_user_ip_and_agents(self): + user_id = self.register_user("test_get_user_ip_and_agents_user", "1234") + + # Initially, we should have no ip/agent for our user. + info = self.get_success(self.module_api.get_user_ip_and_agents(user_id)) + self.assertEqual(info, []) + + # Insert a first ip, agent. We should be able to retrieve it. + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip_1", "user_agent_1", "device_1", None + ) + ) + info = self.get_success(self.module_api.get_user_ip_and_agents(user_id)) + + self.assertEqual(len(info), 1) + last_seen_1 = info[0].last_seen + + # Insert a second ip, agent at a later date. We should be able to retrieve it. 
+ last_seen_2 = last_seen_1 + 10000 + print("%s => %s" % (last_seen_1, last_seen_2)) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip_2", "user_agent_2", "device_2", last_seen_2 + ) + ) + info = self.get_success(self.module_api.get_user_ip_and_agents(user_id)) + + self.assertEqual(len(info), 2) + ip_1_seen = False + ip_2_seen = False + + for i in info: + if i.ip == "ip_1": + ip_1_seen = True + self.assertEqual(i.user_agent, "user_agent_1") + self.assertEqual(i.last_seen, last_seen_1) + elif i.ip == "ip_2": + ip_2_seen = True + self.assertEqual(i.user_agent, "user_agent_2") + self.assertEqual(i.last_seen, last_seen_2) + self.assertTrue(ip_1_seen) + self.assertTrue(ip_2_seen) + + # If we fetch from a midpoint between last_seen_1 and last_seen_2, + # we should only find the second ip, agent. + info = self.get_success( + self.module_api.get_user_ip_and_agents( + user_id, (last_seen_1 + last_seen_2) / 2 + ) + ) + self.assertEqual(len(info), 1) + self.assertEqual(info[0].ip, "ip_2") + self.assertEqual(info[0].user_agent, "user_agent_2") + self.assertEqual(info[0].last_seen, last_seen_2) + + # If we fetch from a point later than last_seen_2, we shouldn't + # find anything. + info = self.get_success( + self.module_api.get_user_ip_and_agents(user_id, last_seen_2 + 10000) + ) + self.assertEqual(info, []) + + def test_get_user_ip_and_agents__no_user_found(self): + info = self.get_success( + self.module_api.get_user_ip_and_agents( + "@test_get_user_ip_and_agents_user_nonexistent:example.com" + ) + ) + self.assertEqual(info, []) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent From 724aef9a878cebc137c81f3b261bafb9302fb592 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 22 Sep 2021 14:21:58 +0100 Subject: [PATCH 02/31] Opt out of cache expiry for `get_users_who_share_room_with_user` (#10826) * Allow LruCaches to opt out of time-based expiry * Don't expire `get_users_who_share_room` & friends --- changelog.d/10826.misc | 2 ++ synapse/storage/databases/main/roommember.py | 11 ++++++++--- synapse/util/caches/deferred_cache.py | 2 ++ synapse/util/caches/descriptors.py | 5 +++++ synapse/util/caches/lrucache.py | 16 +++++++++++++--- 5 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10826.misc diff --git a/changelog.d/10826.misc b/changelog.d/10826.misc new file mode 100644 index 0000000000..53e56fc362 --- /dev/null +++ b/changelog.d/10826.misc @@ -0,0 +1,2 @@ +Opt out of cache expiry for `get_users_who_share_room_with_user`, to hopefully improve `/sync` performance when you +haven't synced recently. 
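Mechanically, the new escape hatch works as sketched below; `prune_unread_entries`
is the keyword this patch introduces, and the surrounding lines are illustrative
only:

    from synapse.util.caches.lrucache import LruCache

    # Entries in this cache are never linked into the global access-time list,
    # so the background time-based evictor skips them entirely; only the normal
    # size-based LRU eviction still applies.
    cache = LruCache(max_size=100_000, prune_unread_entries=False)
    cache["key"] = "value"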
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 9beeb96aa9..a4ec6bc328 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -162,7 +162,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._check_safe_current_state_events_membership_updated_txn, ) - @cached(max_entries=100000, iterable=True) + @cached(max_entries=100000, iterable=True, prune_unread_entries=False) async def get_users_in_room(self, room_id: str) -> List[str]: return await self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id @@ -439,7 +439,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results_dict.get("membership"), results_dict.get("event_id") - @cached(max_entries=500000, iterable=True) + @cached(max_entries=500000, iterable=True, prune_unread_entries=False) async def get_rooms_for_user_with_stream_ordering( self, user_id: str ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: @@ -544,7 +544,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) return frozenset(r.room_id for r in rooms) - @cached(max_entries=500000, cache_context=True, iterable=True) + @cached( + max_entries=500000, + cache_context=True, + iterable=True, + prune_unread_entries=False, + ) async def get_users_who_share_room_with_user( self, user_id: str, cache_context: _CacheContext ) -> Set[str]: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index f05590da0d..6262efe072 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -73,6 +73,7 @@ class DeferredCache(Generic[KT, VT]): tree: bool = False, iterable: bool = False, apply_cache_factor_from_config: bool = True, + prune_unread_entries: bool = True, ): """ Args: @@ -105,6 +106,7 @@ class DeferredCache(Generic[KT, VT]): size_callback=(lambda d: len(d) or 1) if iterable else None, metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, + prune_unread_entries=prune_unread_entries, ) self.thread: Optional[threading.Thread] = None diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1ca31e41ac..b9dcca17f1 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -258,6 +258,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): tree=False, cache_context=False, iterable=False, + prune_unread_entries: bool = True, ): super().__init__(orig, num_args=num_args, cache_context=cache_context) @@ -269,6 +270,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): self.max_entries = max_entries self.tree = tree self.iterable = iterable + self.prune_unread_entries = prune_unread_entries def __get__(self, obj, owner): cache: DeferredCache[CacheKey, Any] = DeferredCache( @@ -276,6 +278,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): max_entries=self.max_entries, tree=self.tree, iterable=self.iterable, + prune_unread_entries=self.prune_unread_entries, ) get_cache_key = self.cache_key_builder @@ -507,6 +510,7 @@ def cached( tree: bool = False, cache_context: bool = False, iterable: bool = False, + prune_unread_entries: bool = True, ) -> Callable[[F], _CachedFunction[F]]: func = lambda orig: DeferredCacheDescriptor( orig, @@ -515,6 +519,7 @@ def cached( tree=tree, cache_context=cache_context, iterable=iterable, + prune_unread_entries=prune_unread_entries, ) return cast(Callable[[F], _CachedFunction[F]], func) diff --git 
a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index ea6e8dc8d1..4ff62b403f 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -202,10 +202,11 @@ class _Node: cache: "weakref.ReferenceType[LruCache]", clock: Clock, callbacks: Collection[Callable[[], None]] = (), + prune_unread_entries: bool = True, ): self._list_node = ListNode.insert_after(self, root) - self._global_list_node = None - if USE_GLOBAL_LIST: + self._global_list_node: Optional[_TimedListNode] = None + if USE_GLOBAL_LIST and prune_unread_entries: self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT) self._global_list_node.update_last_access(clock) @@ -314,6 +315,7 @@ class LruCache(Generic[KT, VT]): metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, clock: Optional[Clock] = None, + prune_unread_entries: bool = True, ): """ Args: @@ -427,7 +429,15 @@ class LruCache(Generic[KT, VT]): self.len = synchronized(cache_len) def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()): - node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks) + node = _Node( + list_root, + key, + value, + weak_ref_to_self, + real_clock, + callbacks, + prune_unread_entries, + ) cache[key] = node if size_callback: From 52913d56a5a2b07106774d97f4e188148d85a900 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 22 Sep 2021 09:41:42 -0400 Subject: [PATCH 03/31] Add documentation for experimental feature flags. (#10865) --- changelog.d/10865.doc | 1 + docs/SUMMARY.md | 1 + docs/development/experimental_features.md | 37 +++++++++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 changelog.d/10865.doc create mode 100644 docs/development/experimental_features.md diff --git a/changelog.d/10865.doc b/changelog.d/10865.doc new file mode 100644 index 0000000000..deeb0eedf3 --- /dev/null +++ b/changelog.d/10865.doc @@ -0,0 +1 @@ +Add developer documentation about experimental configuration flags. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index fd0045e1ef..bdb44543b8 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -74,6 +74,7 @@ - [Testing]() - [OpenTracing](opentracing.md) - [Database Schemas](development/database_schema.md) + - [Experimental features](development/experimental_features.md) - [Synapse Architecture]() - [Log Contexts](log_contexts.md) - [Replication](replication.md) diff --git a/docs/development/experimental_features.md b/docs/development/experimental_features.md new file mode 100644 index 0000000000..d6b11496cc --- /dev/null +++ b/docs/development/experimental_features.md @@ -0,0 +1,37 @@ +# Implementing experimental features in Synapse + +It can be desirable to implement "experimental" features which are disabled by +default and must be explicitly enabled via the Synapse configuration. This is +applicable for features which: + +* Are unstable in the Matrix spec (e.g. those defined by an MSC that has not yet been merged). +* Developers are not confident in their use by general Synapse administrators/users + (e.g. a feature is incomplete, buggy, performs poorly, or needs further testing). + +Note that this only really applies to features which are expected to be desirable +to a broad audience. The [module infrastructure](../modules/index.md) should +instead be investigated for non-standard features. 
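+For illustration, an experimental flag is typically read in
+`synapse.config.experimental` along the following lines (a sketch; the MSC
+number is a placeholder rather than a real feature):
+
+    from synapse.config._base import Config
+
+    class ExperimentalConfig(Config):
+        """Config section for enabling experimental features."""
+
+        section = "experimental"
+
+        def read_config(self, config, **kwargs):
+            experimental = config.get("experimental_features") or {}
+
+            # MSC0000: placeholder experimental feature, disabled by default.
+            self.msc0000_enabled: bool = experimental.get("msc0000_enabled", False)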
+ +Guarding experimental features behind configuration flags should help with some +of the following scenarios: + +* Ensure that clients do not assume that unstable features exist (failing + gracefully if they do not). +* Unstable features do not become de-facto standards and can be removed + aggressively (since only those who have opted-in will be affected). +* Ease finding the implementation of unstable features in Synapse (for future + removal or stabilization). +* Ease testing a feature (or removal of feature) due to enabling/disabling without + code changes. It also becomes possible to ask for wider testing, if desired. + +Experimental configuration flags should be disabled by default (requiring Synapse +administrators to explicitly opt-in), although there are situations where it makes +sense (from a product point-of-view) to enable features by default. This is +expected and not an issue. + +It is not a requirement for experimental features to be behind a configuration flag, +but one should be used if unsure. + +New experimental configuration flags should be added under the `experimental` +configuration key (see the `synapse.config.experimental` file) and either explain +(briefly) what is being enabled, or include the MSC number. From 9391de3f373454aeec5b5c2f01b3c576528e76fe Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 22 Sep 2021 14:43:26 +0100 Subject: [PATCH 04/31] Fix /initialSync error due to unhashable `RoomStreamToken` (#10827) The deprecated /initialSync endpoint maintains a cache of responses, using parameter values as part of the cache key. When a `from` or `to` parameter is specified, it gets converted into a `StreamToken`, which contains a `RoomStreamToken` and forms part of the cache key. `RoomStreamToken`s need to be made hashable for this to work. --- changelog.d/10827.bugfix | 1 + synapse/storage/databases/main/stream.py | 4 +++- synapse/types.py | 20 +++++++++++++++----- 3 files changed, 19 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10827.bugfix diff --git a/changelog.d/10827.bugfix b/changelog.d/10827.bugfix new file mode 100644 index 0000000000..11a618bf82 --- /dev/null +++ b/changelog.d/10827.bugfix @@ -0,0 +1 @@ +Fix error in deprecated `/initialSync` endpoint when using the undocumented `from` and `to` parameters. 
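Roughly why this matters (a sketch; before this change `instance_map` was a
plain `dict`):

    tok = RoomStreamToken(None, 5, {"writer1": 3})
    hash(tok)  # raises TypeError: unhashable type: 'dict'

    # With this change the map is a frozendict, so the frozen attrs class
    # hashes cleanly and the token can safely form part of a cache key:
    tok = RoomStreamToken(None, 5, frozendict({"writer1": 3}))
    hash(tok)  # ok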
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 959f13de47..9a3b6f4acf 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -39,6 +39,8 @@ import logging from collections import namedtuple from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +from frozendict import frozendict + from twisted.internet import defer from synapse.api.filtering import Filter @@ -379,7 +381,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if p > min_pos } - return RoomStreamToken(None, min_pos, positions) + return RoomStreamToken(None, min_pos, frozendict(positions)) async def get_room_events_stream_for_rooms( self, diff --git a/synapse/types.py b/synapse/types.py index 90168ce8fa..ed831a5c1d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -30,6 +30,7 @@ from typing import ( ) import attr +from frozendict import frozendict from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -457,6 +458,9 @@ class RoomStreamToken: Note: The `RoomStreamToken` cannot have both a topological part and an instance map. + + For caching purposes, `RoomStreamToken`s and by extension, all their + attributes, must be hashable. """ topological = attr.ib( @@ -466,12 +470,12 @@ class RoomStreamToken: stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) instance_map = attr.ib( - type=Dict[str, int], - factory=dict, + type="frozendict[str, int]", + factory=frozendict, validator=attr.validators.deep_mapping( key_validator=attr.validators.instance_of(str), value_validator=attr.validators.instance_of(int), - mapping_validator=attr.validators.instance_of(dict), + mapping_validator=attr.validators.instance_of(frozendict), ), ) @@ -507,7 +511,7 @@ class RoomStreamToken: return cls( topological=None, stream=stream, - instance_map=instance_map, + instance_map=frozendict(instance_map), ) except Exception: pass @@ -540,7 +544,7 @@ class RoomStreamToken: for instance in set(self.instance_map).union(other.instance_map) } - return RoomStreamToken(None, max_stream, instance_map) + return RoomStreamToken(None, max_stream, frozendict(instance_map)) def as_historical_tuple(self) -> Tuple[int, int]: """Returns a tuple of `(topological, stream)` for historical tokens. @@ -593,6 +597,12 @@ class RoomStreamToken: @attr.s(slots=True, frozen=True) class StreamToken: + """A collection of positions within multiple streams. + + For caching purposes, `StreamToken`s and by extension, all their attributes, + must be hashable. + """ + room_key = attr.ib( type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken) ) From 6fc8be9a1b2046e69e8c6f731442887e3addeec0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 22 Sep 2021 09:45:20 -0400 Subject: [PATCH 05/31] Include more information in oEmbed previews. (#10819) * Improved titles (fall back to the author name if there's not title) and include the site name. * Handle photo/video payloads. * Include the original URL in the Open Graph response. * Fix the expiration time (by properly converting from seconds to milliseconds). 
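In rough terms, the new mapping looks like this (all field values below are
invented for illustration):

    an oEmbed response for http://example.com/some/page such as

        {"type": "rich", "author_name": "Alice", "provider_name": "Example",
         "cache_age": 3600, "html": "<div>Content Preview</div>"}

    now produces the Open Graph data (cached for 3600 * 1000 ms)

        {"og:url": "http://example.com/some/page", "og:title": "Alice",
         "og:site_name": "Example", "og:description": "Content Preview"}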
--- changelog.d/10819.feature | 1 + synapse/rest/media/v1/oembed.py | 49 +++++++++++++++++-- synapse/rest/media/v1/preview_url_resource.py | 2 +- tests/rest/media/v1/test_url_preview.py | 30 ++++++++---- 4 files changed, 68 insertions(+), 14 deletions(-) create mode 100644 changelog.d/10819.feature diff --git a/changelog.d/10819.feature b/changelog.d/10819.feature new file mode 100644 index 0000000000..4fa95a6cc9 --- /dev/null +++ b/changelog.d/10819.feature @@ -0,0 +1 @@ +Improve oEmbed previews by processing the author name, photo, and video information. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 8b74e72655..e04671fb95 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional import attr @@ -22,6 +22,8 @@ from synapse.types import JsonDict from synapse.util import json_decoder if TYPE_CHECKING: + from lxml import etree + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -31,7 +33,7 @@ logger = logging.getLogger(__name__) class OEmbedResult: # The Open Graph result (converted from the oEmbed result). open_graph_result: JsonDict - # Number of seconds to cache the content, according to the oEmbed response. + # Number of milliseconds to cache the content, according to the oEmbed response. # # This will be None if no cache-age is provided in the oEmbed response (or # if the oEmbed response cannot be turned into an Open Graph response). @@ -119,10 +121,22 @@ class OEmbedProvider: # Ensure the cache age is None or an int. cache_age = oembed.get("cache_age") if cache_age: - cache_age = int(cache_age) + cache_age = int(cache_age) * 1000 # The results. - open_graph_response = {"og:title": oembed.get("title")} + open_graph_response = { + "og:url": url, + } + + # Use either title or author's name as the title. + title = oembed.get("title") or oembed.get("author_name") + if title: + open_graph_response["og:title"] = title + + # Use the provider name and as the site. + provider_name = oembed.get("provider_name") + if provider_name: + open_graph_response["og:site_name"] = provider_name # If a thumbnail exists, use it. Note that dimensions will be calculated later. if "thumbnail_url" in oembed: @@ -137,6 +151,15 @@ class OEmbedProvider: # If this is a photo, use the full image, not the thumbnail. open_graph_response["og:image"] = oembed["url"] + elif oembed_type == "video": + open_graph_response["og:type"] = "video.other" + calc_description_and_urls(open_graph_response, oembed["html"]) + open_graph_response["og:video:width"] = oembed["width"] + open_graph_response["og:video:height"] = oembed["height"] + + elif oembed_type == "link": + open_graph_response["og:type"] = "website" + else: raise RuntimeError(f"Unknown oEmbed type: {oembed_type}") @@ -149,6 +172,14 @@ class OEmbedProvider: return OEmbedResult(open_graph_response, cache_age) +def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]: + results = [] + for tag in tree.xpath("//*/" + tag_name): + if "src" in tag.attrib: + results.append(tag.attrib["src"]) + return results + + def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None: """ Calculate description for an HTML document. 
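The `_fetch_urls` helper above simply collects `src` attributes from matching
tags; a rough usage sketch (lxml is already a dependency of the preview code):

    from lxml import etree

    tree = etree.fromstring("<div><img src='a.png'/><img alt='no-src'/></div>")
    _fetch_urls(tree, "img")  # -> ['a.png']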
@@ -179,6 +210,16 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> if tree is None: return + # Attempt to find interesting URLs (images, videos, embeds). + if "og:image" not in open_graph_response: + image_urls = _fetch_urls(tree, "img") + if image_urls: + open_graph_response["og:image"] = image_urls[0] + + video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed") + if video_urls: + open_graph_response["og:video"] = video_urls[0] + from synapse.rest.media.v1.preview_url_resource import _calc_description description = _calc_description(tree) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0a0b476d2b..9ffa983fbb 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -305,7 +305,7 @@ class PreviewUrlResource(DirectServeJsonResource): with open(media_info.filename, "rb") as file: body = file.read() - oembed_response = self._oembed.parse_oembed_response(media_info.uri, body) + oembed_response = self._oembed.parse_oembed_response(url, body) og = oembed_response.open_graph_result # Use the cache age from the oEmbed result, instead of the HTTP response. diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 9d13899584..d83dfacfed 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -620,11 +620,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertIn(b"/matrixdotorg", server.data) self.assertEqual(channel.code, 200) - self.assertIsNone(channel.json_body["og:title"]) - self.assertTrue(channel.json_body["og:image"].startswith("mxc://")) - self.assertEqual(channel.json_body["og:image:height"], 1) - self.assertEqual(channel.json_body["og:image:width"], 1) - self.assertEqual(channel.json_body["og:image:type"], "image/png") + body = channel.json_body + self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") + self.assertTrue(body["og:image"].startswith("mxc://")) + self.assertEqual(body["og:image:height"], 1) + self.assertEqual(body["og:image:width"], 1) + self.assertEqual(body["og:image:type"], "image/png") def test_oembed_rich(self): """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" @@ -633,6 +634,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): result = { "version": "1.0", "type": "rich", + # Note that this provides the author, not the title. + "author_name": "Alice", "html": "
<div>Content Preview</div>
", } end_content = json.dumps(result).encode("utf-8") @@ -660,9 +663,14 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) + body = channel.json_body self.assertEqual( - channel.json_body, - {"og:title": None, "og:description": "Content Preview"}, + body, + { + "og:url": "http://twitter.com/matrixdotorg/status/12345", + "og:title": "Alice", + "og:description": "Content Preview", + }, ) def test_oembed_format(self): @@ -705,7 +713,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertIn(b"format=json", server.data) self.assertEqual(channel.code, 200) + body = channel.json_body self.assertEqual( - channel.json_body, - {"og:title": None, "og:description": "Content Preview"}, + body, + { + "og:url": "http://www.hulu.com/watch/12345", + "og:description": "Content Preview", + }, ) From 8f2a52766bc242c02a309f45406f827e670311e7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Sep 2021 15:20:18 +0100 Subject: [PATCH 06/31] Ensure we mark sent knocks as outliers (#10873) --- changelog.d/10873.bugfix | 1 + synapse/handlers/federation.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changelog.d/10873.bugfix diff --git a/changelog.d/10873.bugfix b/changelog.d/10873.bugfix new file mode 100644 index 0000000000..32b2e50fd9 --- /dev/null +++ b/changelog.d/10873.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.37.0 which caused `knock` events which we sent to remote servers to be incorrectly stored in the local database. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8e2cf3387a..a03d77dffd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -593,6 +593,13 @@ class FederationHandler(BaseHandler): target_hosts, room_id, knockee, Membership.KNOCK, content, params=params ) + # Mark the knock as an outlier as we don't yet have the state at this point in + # the DAG. + event.internal_metadata.outlier = True + + # ... but tell /sync to send it to clients anyway. + event.internal_metadata.out_of_band_membership = True + # Record the room ID and its version so that we have a record of the room await self._maybe_store_room_on_outlier_membership( room_id=event.room_id, room_version=event_format_version From 03db6701d5379f4aa05037bd9ce23942c501874e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Sep 2021 10:31:05 -0400 Subject: [PATCH 07/31] Fix invalidating OTK count cache after claim (#10875) The invalidation was missing in `_claim_e2e_one_time_key_returning`, which is used on SQLite 3.24+ and Postgres. This could break e2ee if nothing else happened to invalidate the caches before the keys ran out. Signed-off-by: Tulir Asokan --- changelog.d/10875.bugfix | 1 + synapse/storage/databases/main/end_to_end_keys.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/10875.bugfix diff --git a/changelog.d/10875.bugfix b/changelog.d/10875.bugfix new file mode 100644 index 0000000000..6f370da5c7 --- /dev/null +++ b/changelog.d/10875.bugfix @@ -0,0 +1 @@ +Fix invalidating one-time key count cache after claiming keys. Contributed by Tulir at Beeper. 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1f0a39eac4..a95ac34f09 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -824,6 +824,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): if otk_row is None: return None + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + key_id, key_json = otk_row return f"{algorithm}:{key_id}", key_json From f78b68a96b1f179043b38b4109e09fa0a315643d Mon Sep 17 00:00:00 2001 From: Hillery Shay Date: Wed, 22 Sep 2021 08:25:26 -0700 Subject: [PATCH 08/31] Treat "\u0000" as "\u0020" for the purposes of message search (message indexing) (#10820) * add test to check if null code points are being inserted * add logic to detect and replace null code points before insertion into db * lints * add license to test * change approach to null substitution * add type hint for SearchEntry * Add changelog entry Signed-off-by: H.Shay * updated changelog * update chanelog message * remove duplicate changelog * Update synapse/storage/databases/main/events.py remove extra space Co-authored-by: Patrick Cloke * rename and move test file, update tests, delete old test file * fix typo in comments * update _find_highlights_in_postgres to replace null byte with space * replace null byte in sqlite search insertion * beef up and reorganize test for this pr * update changelog * add type hints and update docstring * check db engine directly vs using env variable * refactor tests to be less repetetive * move rplace logic into seperate function * requested changes * Fix typo. * Update synapse/storage/databases/main/search.py Co-authored-by: reivilibre * Update changelog.d/10820.misc Co-authored-by: Aaron Raimist Co-authored-by: Patrick Cloke Co-authored-by: reivilibre Co-authored-by: Aaron Raimist --- changelog.d/10820.misc | 1 + synapse/storage/databases/main/search.py | 34 ++++++++--- tests/storage/test_room_search.py | 74 ++++++++++++++++++++++++ 3 files changed, 100 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10820.misc create mode 100644 tests/storage/test_room_search.py diff --git a/changelog.d/10820.misc b/changelog.d/10820.misc new file mode 100644 index 0000000000..4373bf6f6b --- /dev/null +++ b/changelog.d/10820.misc @@ -0,0 +1 @@ +Fix a long-standing bug where an `m.room.message` event containing a null byte would cause an internal server error. 
\ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 6480d5a9f5..2a1e99e17a 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -15,12 +15,12 @@ import logging import re from collections import namedtuple -from typing import Collection, List, Optional, Set +from typing import Collection, Iterable, List, Optional, Set from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -32,14 +32,24 @@ SearchEntry = namedtuple( ) +def _clean_value_for_search(value: str) -> str: + """ + Replaces any null code points in the string with spaces as + Postgres and SQLite do not like the insertion of strings with + null code points into the full-text search tables. + """ + return value.replace("\u0000", " ") + + class SearchWorkerStore(SQLBaseStore): - def store_search_entries_txn(self, txn, entries): + def store_search_entries_txn( + self, txn: LoggingTransaction, entries: Iterable[SearchEntry] + ) -> None: """Add entries to the search table Args: - txn (cursor): - entries (iterable[SearchEntry]): - entries to be added to the table + txn: + entries: entries to be added to the table """ if not self.hs.config.enable_search: return @@ -55,7 +65,7 @@ class SearchWorkerStore(SQLBaseStore): entry.event_id, entry.room_id, entry.key, - entry.value, + _clean_value_for_search(entry.value), entry.stream_ordering, entry.origin_server_ts, ) @@ -70,11 +80,16 @@ class SearchWorkerStore(SQLBaseStore): " VALUES (?,?,?,?)" ) args = ( - (entry.event_id, entry.room_id, entry.key, entry.value) + ( + entry.event_id, + entry.room_id, + entry.key, + _clean_value_for_search(entry.value), + ) for entry in entries ) - txn.execute_batch(sql, args) + else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -646,6 +661,7 @@ class SearchStore(SearchBackgroundUpdateStore): for key in ("body", "name", "topic"): v = event.content.get(key, None) if v: + v = _clean_value_for_search(v) values.append(v) if not values: diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py new file mode 100644 index 0000000000..8971ecccbd --- /dev/null +++ b/tests/storage/test_room_search.py @@ -0,0 +1,74 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import synapse.rest.admin +from synapse.rest.client import login, room +from synapse.storage.engines import PostgresEngine + +from tests.unittest import HomeserverTestCase + + +class NullByteInsertionTest(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + def test_null_byte(self): + """ + Postgres/SQLite don't like null bytes going into the search tables. Internally + we replace those with a space. + + Ensure this doesn't break anything. + """ + + # Register a user and create a room, create some messages + self.register_user("alice", "password") + access_token = self.login("alice", "password") + room_id = self.helper.create_room_as("alice", tok=access_token) + + # Send messages and ensure they don't cause an internal server + # error + for body in ["hi\u0000bob", "another message", "hi alice"]: + response = self.helper.send(room_id, body, tok=access_token) + self.assertIn("event_id", response) + + # Check that search works for the message where the null byte was replaced + store = self.hs.get_datastore() + result = self.get_success( + store.search_msgs([room_id], "hi bob", ["content.body"]) + ) + self.assertEquals(result.get("count"), 1) + if isinstance(store.database_engine, PostgresEngine): + self.assertIn("hi", result.get("highlights")) + self.assertIn("bob", result.get("highlights")) + + # Check that search works for an unrelated message + result = self.get_success( + store.search_msgs([room_id], "another", ["content.body"]) + ) + self.assertEquals(result.get("count"), 1) + if isinstance(store.database_engine, PostgresEngine): + self.assertIn("another", result.get("highlights")) + + # Check that search works for a search term that overlaps with the message + # containing a null byte and an unrelated message. + result = self.get_success(store.search_msgs([room_id], "hi", ["content.body"])) + self.assertEquals(result.get("count"), 2) + result = self.get_success( + store.search_msgs([room_id], "hi alice", ["content.body"]) + ) + if isinstance(store.database_engine, PostgresEngine): + self.assertIn("alice", result.get("highlights")) From 26f2bfedbf5493d8a69d1b38147b6236e7606cd3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Sep 2021 17:58:57 +0100 Subject: [PATCH 09/31] Factor out a separate `EventContext.for_outlier` (#10883) Constructing an EventContext for an outlier is actually really simple, and there's no sense in going via an `async` method in the `StateHandler`. This also means that we can resolve a bunch of FIXMEs. --- changelog.d/10883.misc | 1 + synapse/events/snapshot.py | 14 ++++++++---- synapse/handlers/federation.py | 9 ++++---- synapse/handlers/federation_event.py | 7 ++---- synapse/state/__init__.py | 34 ++++------------------------ 5 files changed, 21 insertions(+), 44 deletions(-) create mode 100644 changelog.d/10883.misc diff --git a/changelog.d/10883.misc b/changelog.d/10883.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10883.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f8d898c3b1..5ba01eeef9 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -80,9 +80,7 @@ class EventContext: (type, state_key) -> event_id - FIXME: what is this for an outlier? it seems ill-defined. 
It seems like - it could be either {}, or the state we were given by the remote - server, depending on $THINGS + For an outlier, this is {} Note that this is a private attribute: it should be accessed via ``get_current_state_ids``. _AsyncEventContext impl calculates this @@ -96,7 +94,7 @@ class EventContext: (type, state_key) -> event_id - FIXME: again, what is this for an outlier? + For an outlier, this is {} As with _current_state_ids, this is a private attribute. It should be accessed via get_prev_state_ids. @@ -130,6 +128,14 @@ class EventContext: delta_ids=delta_ids, ) + @staticmethod + def for_outlier(): + """Return an EventContext instance suitable for persisting an outlier event""" + return EventContext( + current_state_ids={}, + prev_state_ids={}, + ) + async def serialize(self, event: EventBase, store: "DataStore") -> dict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a03d77dffd..0befe9ce43 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -624,7 +624,7 @@ class FederationHandler(BaseHandler): # in the invitee's sync stream. It is stripped out for all other local users. event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] - context = await self.state_handler.compute_event_context(event) + context = EventContext.for_outlier() stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -814,7 +814,7 @@ class FederationHandler(BaseHandler): ) ) - context = await self.state_handler.compute_event_context(event) + context = EventContext.for_outlier() await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -843,7 +843,7 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) - context = await self.state_handler.compute_event_context(event) + context = EventContext.for_outlier() stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1115,8 +1115,7 @@ class FederationHandler(BaseHandler): events_to_context = {} for e in itertools.chain(auth_events, state): e.internal_metadata.outlier = True - ctx = await self.state_handler.compute_event_context(e) - events_to_context[e.event_id] = ctx + events_to_context[e.event_id] = EventContext.for_outlier() event_map = { e.event_id: e for e in itertools.chain(auth_events, state, [event]) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3b95beeb08..10b3fdc222 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1221,7 +1221,7 @@ class FederationEventHandler: async def prep(ev_info: _NewEventInfo) -> EventContext: event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self._state_handler.compute_event_context(event) + res = EventContext.for_outlier() res = await self._check_event_auth( origin, event, @@ -1540,10 +1540,7 @@ class FederationEventHandler: event.event_id, auth_event.event_id, ) - missing_auth_event_context = ( - await self._state_handler.compute_event_context(auth_event) - ) - + missing_auth_event_context = EventContext.for_outlier() missing_auth_event_context = await self._check_event_auth( origin, auth_event, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 463ce58dae..c981df3f18 100644 --- 
a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -263,7 +263,9 @@ class StateHandler: async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None ) -> EventContext: - """Build an EventContext structure for the event. + """Build an EventContext structure for a non-outlier event. + + (for an outlier, call EventContext.for_outlier directly) This works out what the current state should be for the event, and generates a new state group if necessary. @@ -278,35 +280,7 @@ class StateHandler: The event context. """ - if event.internal_metadata.is_outlier(): - # If this is an outlier, then we know it shouldn't have any current - # state. Certainly store.get_current_state won't return any, and - # persisting the event won't store the state group. - - # FIXME: why do we populate current_state_ids? I thought the point was - # that we weren't supposed to have any state for outliers? - if old_state: - prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} - if event.is_state(): - current_state_ids = dict(prev_state_ids) - key = (event.type, event.state_key) - current_state_ids[key] = event.event_id - else: - current_state_ids = prev_state_ids - else: - current_state_ids = {} - prev_state_ids = {} - - # We don't store state for outliers, so we don't generate a state - # group for it. - context = EventContext.with_state( - state_group=None, - state_group_before_event=None, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - ) - - return context + assert not event.internal_metadata.is_outlier() # # first of all, figure out the state before the event From aa2c027792d04c36b17866710e95a41d31f5d99c Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 23 Sep 2021 11:59:07 +0100 Subject: [PATCH 10/31] Remove unnecessary parentheses around tuples returned from methods (#10889) --- changelog.d/10889.misc | 1 + synapse/config/server.py | 2 +- synapse/federation/sender/per_destination_queue.py | 4 ++-- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 4 ++-- synapse/handlers/receipts.py | 4 ++-- synapse/handlers/room.py | 2 +- synapse/handlers/room_summary.py | 2 +- synapse/handlers/typing.py | 4 ++-- synapse/http/matrixfederationclient.py | 2 +- synapse/rest/admin/rooms.py | 4 ++-- synapse/rest/client/devices.py | 4 ++-- synapse/rest/client/password_policy.py | 4 ++-- synapse/storage/databases/main/account_data.py | 2 +- synapse/storage/databases/main/deviceinbox.py | 6 +++--- synapse/storage/databases/main/events_worker.py | 2 +- synapse/storage/databases/main/state_deltas.py | 2 +- synapse/storage/databases/main/stream.py | 4 ++-- synapse/streams/config.py | 2 +- synapse/types.py | 4 ++-- tests/test_state.py | 2 +- tests/utils.py | 2 +- 22 files changed, 33 insertions(+), 32 deletions(-) create mode 100644 changelog.d/10889.misc diff --git a/changelog.d/10889.misc b/changelog.d/10889.misc new file mode 100644 index 0000000000..6d60188f55 --- /dev/null +++ b/changelog.d/10889.misc @@ -0,0 +1 @@ +Clean up some unnecessary parentheses in places around the codebase. 
\ No newline at end of file diff --git a/synapse/config/server.py b/synapse/config/server.py index 7b9109a592..ad8715da29 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1447,7 +1447,7 @@ def read_gc_thresholds(thresholds): return None try: assert len(thresholds) == 3 - return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2])) + return int(thresholds[0]), int(thresholds[1]), int(thresholds[2]) except Exception: raise ConfigError( "Value of `gc_threshold` must be a list of three integers if set" diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index c11d1f6d31..afe35e72b6 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -560,7 +560,7 @@ class PerDestinationQueue: assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs" - return (edus, now_stream_id) + return edus, now_stream_id async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_stream_id = self._last_device_stream_id @@ -593,7 +593,7 @@ class PerDestinationQueue: stream_id, ) - return (edus, stream_id) + return edus, stream_id def _start_catching_up(self) -> None: """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0befe9ce43..4523b25636 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1369,7 +1369,7 @@ class FederationHandler(BaseHandler): builder=builder ) EventValidator().validate_new(event, self.config) - return (event, context) + return event, context async def _check_signature(self, event: EventBase, context: EventContext) -> None: """ diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6cd694b2da..7a5d8e6f4e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -666,7 +666,7 @@ class EventCreationHandler: self.validator.validate_new(event, self.config) - return (event, context) + return event, context async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1004,7 +1004,7 @@ class EventCreationHandler: logger.debug("Created event %s", event.event_id) - return (event, context) + return event, context @measure_func("handle_new_client_event") async def handle_new_client_event( diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 5881f09ebd..f21f33ada2 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -238,7 +238,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): if self.config.experimental.msc2285_enabled: events = ReceiptEventSource.filter_out_hidden(events, user.to_string()) - return (events, to_key) + return events, to_key async def get_new_events_as( self, from_key: int, service: ApplicationService @@ -270,7 +270,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): events.append(event) - return (events, to_key) + return events, to_key def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_receipt_stream_id() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 287ea2fd06..b5768220d9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1235,7 +1235,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): else: end_key = to_key - return (events, end_key) + return events, end_key def get_current_key(self) -> RoomStreamToken: return self.store.get_room_max_token() diff --git 
a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4e28fb9685..fb26ee7ad7 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -1179,4 +1179,4 @@ def _child_events_comparison_key( order = None # Items without an order come last. - return (order is None, order, child.origin_server_ts, child.room_id) + return order is None, order, child.origin_server_ts, child.room_id diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 9326330c90..d10e9b8ec4 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -483,7 +483,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): events.append(self._make_event_for(room_id)) - return (events, handler._latest_room_serial) + return events, handler._latest_room_serial async def get_new_events( self, @@ -507,7 +507,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): events.append(self._make_event_for(room_id)) - return (events, handler._latest_room_serial) + return events, handler._latest_room_serial def get_current_key(self) -> int: return self.get_typing_handler()._latest_room_serial diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index ef10ec0937..e56fa477bb 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -1186,7 +1186,7 @@ class MatrixFederationHttpClient: request.method, request.uri.decode("ascii"), ) - return (length, headers) + return length, headers def _flatten_response_never_received(e): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 8f781f745f..a4823ca6e7 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -213,7 +213,7 @@ class RoomRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret["joined_local_devices"] = await self.store.count_devices_by_users(members) - return (200, ret) + return 200, ret async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -668,4 +668,4 @@ async def _delete_room( if purge: await pagination_handler.purge_room(room_id, force=force_purge) - return (200, ret) + return 200, ret diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 25bc3c8f47..8566dc5cb5 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -211,7 +211,7 @@ class DehydratedDeviceServlet(RestServlet): if dehydrated_device is not None: (device_id, device_data) = dehydrated_device result = {"device_id": device_id, "device_data": device_data} - return (200, result) + return 200, result else: raise errors.NotFoundError("No dehydrated device available") @@ -293,7 +293,7 @@ class ClaimDehydratedDeviceServlet(RestServlet): submission["device_id"], ) - return (200, result) + return 200, result def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py index 6d64efb165..0465fd2292 100644 --- a/synapse/rest/client/password_policy.py +++ b/synapse/rest/client/password_policy.py @@ -40,7 +40,7 @@ class PasswordPolicyServlet(RestServlet): def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.enabled or not self.policy: - return (200, {}) + return 200, {} policy = {} @@ -54,7 +54,7 @@ class PasswordPolicyServlet(RestServlet): if param in self.policy: policy["m.%s" % param] = self.policy[param] - return (200, policy) + return 200, policy def register_servlets(hs: 
"HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index d0cf3460da..70ca3e09f7 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -324,7 +324,7 @@ class AccountDataWorkerStore(SQLBaseStore): user_id, int(stream_id) ) if not changed: - return ({}, {}) + return {}, {} return await self.db_pool.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index c55508867d..3154906d45 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -136,7 +136,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): user_id, last_stream_id ) if not has_changed: - return ([], current_stream_id) + return [], current_stream_id def get_new_messages_for_device_txn(txn): sql = ( @@ -240,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) - return ([], current_stream_id) + return [], current_stream_id if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. - return ([], last_stream_id) + return [], last_stream_id @trace def get_new_messages_for_remote_destination_txn(txn): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d72e716b5c..4a1a2f4a6a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1495,7 +1495,7 @@ class EventsWorkerStore(SQLBaseStore): if not res: raise SynapseError(404, "Could not find event %s" % (event_id,)) - return (int(res["topological_ordering"]), int(res["stream_ordering"])) + return int(res["topological_ordering"]), int(res["stream_ordering"]) async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index bff7d0404f..a89747d741 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -58,7 +58,7 @@ class StateDeltasStore(SQLBaseStore): # if the CSDs haven't changed between prev_stream_id and now, we # know for certain that they haven't changed between prev_stream_id and # max_stream_id. 
- return (max_stream_id, []) + return max_stream_id, [] def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 9a3b6f4acf..dc7884b1c0 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -624,7 +624,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): self._set_before_and_after(events, rows) - return (events, token) + return events, token async def get_recent_event_ids_for_room( self, room_id: str, limit: int, end_token: RoomStreamToken @@ -1242,7 +1242,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): self._set_before_and_after(events, rows) - return (events, token) + return events, token @cached() async def get_id_for_instance(self, instance_name: str) -> int: diff --git a/synapse/streams/config.py b/synapse/streams/config.py index cf4005984b..c08d591f29 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -81,7 +81,7 @@ class PaginationConfig: raise SynapseError(400, "Invalid request.") def __repr__(self) -> str: - return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % ( + return "PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)" % ( self.from_token, self.to_token, self.direction, diff --git a/synapse/types.py b/synapse/types.py index ed831a5c1d..364ecf7d45 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -556,7 +556,7 @@ class RoomStreamToken: "Cannot call `RoomStreamToken.as_historical_tuple` on live token" ) - return (self.topological, self.stream) + return self.topological, self.stream def get_stream_pos_for_instance(self, instance_name: str) -> int: """Get the stream position that the given writer was at at this token. @@ -766,7 +766,7 @@ def get_verify_key_from_cross_signing_key(key_info): raise ValueError("Invalid key") # and return that one key for key_id, key_data in keys.items(): - return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) + return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) @attr.s(auto_attribs=True, frozen=True, slots=True) diff --git a/tests/test_state.py b/tests/test_state.py index e5488df1ac..76e0e8ca7f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -106,7 +106,7 @@ class StateGroupStore: } async def get_state_group_delta(self, name): - return (None, None) + return None, None def register_events(self, events): for e in events: diff --git a/tests/utils.py b/tests/utils.py index f3458ca88d..cf8ba5c5db 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -434,7 +434,7 @@ class MockHttpResource: ) return code, response except CodeMessageException as e: - return (e.code, cs_error(e.msg, code=e.errcode)) + return e.code, cs_error(e.msg, code=e.errcode) raise KeyError("No event can handle %s" % path) From e584534403b55ad3f250f92592e30b15b01f0201 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Sep 2021 07:13:34 -0400 Subject: [PATCH 11/31] Use direct references for some configuration variables (part 3) (#10885) This avoids the overhead of searching through the various configuration classes by directly referencing the class that the attributes are in. It also improves type hints since mypy can now resolve the types of the configuration variables. 
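The pattern, as a minimal self-contained sketch (the class and attribute names below are simplified stand-ins, not Synapse's actual config classes):

    class EmailConfig:
        def __init__(self) -> None:
            self.email_app_name = "Matrix"

    class RootConfig:
        def __init__(self) -> None:
            self.email = EmailConfig()

        def __getattr__(self, item: str):
            # Fallback for legacy call sites: this only runs when normal
            # attribute lookup misses, and scans each sub-config for the
            # name. mypy can only type the result as Any.
            for section in (self.email,):
                if hasattr(section, item):
                    return getattr(section, item)
            raise AttributeError(item)

    config = RootConfig()
    app_name = config.email_app_name        # indirect: scanned at runtime, typed Any
    app_name = config.email.email_app_name  # direct: plain lookup, typed str

Call sites migrated by this patch follow the second form.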
--- changelog.d/10885.misc | 1 + synapse/app/homeserver.py | 2 +- synapse/config/consent.py | 9 +++-- synapse/handlers/account_validity.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 22 +++++----- synapse/handlers/cas.py | 8 ++-- synapse/handlers/identity.py | 12 +++--- synapse/handlers/message.py | 4 +- synapse/handlers/password_policy.py | 4 +- synapse/handlers/register.py | 11 +++-- synapse/handlers/ui_auth/checkers.py | 17 +++++--- synapse/module_api/__init__.py | 8 ++-- synapse/push/pusher.py | 2 +- synapse/rest/admin/users.py | 4 +- synapse/rest/client/account.py | 40 +++++++++---------- synapse/rest/client/auth.py | 10 ++--- synapse/rest/client/login.py | 4 +- synapse/rest/client/password_policy.py | 4 +- synapse/rest/client/register.py | 30 +++++++------- synapse/rest/consent/consent_resource.py | 9 +++-- synapse/rest/synapse/client/password_reset.py | 10 ++--- .../server_notices/consent_server_notices.py | 11 +++-- synapse/storage/databases/main/appservice.py | 2 +- .../databases/main/monthly_active_users.py | 2 +- .../storage/databases/main/registration.py | 2 +- synapse/storage/prepare_database.py | 2 +- .../storage/schema/main/delta/30/as_users.py | 2 +- tests/rest/admin/test_room.py | 2 +- tests/rest/client/test_login.py | 2 +- tests/storage/test_appservice.py | 14 +++---- tests/storage/test_cleanup_extrems.py | 2 +- 32 files changed, 137 insertions(+), 119 deletions(-) create mode 100644 changelog.d/10885.misc diff --git a/changelog.d/10885.misc b/changelog.d/10885.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10885.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b909f8db8d..886e291e4c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -195,7 +195,7 @@ class SynapseHomeServer(HomeServer): } ) - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: from synapse.rest.synapse.client.password_reset import ( PasswordResetSubmitTokenResource, ) diff --git a/synapse/config/consent.py b/synapse/config/consent.py index b05a9bd97f..ecc43b08b9 100644 --- a/synapse/config/consent.py +++ b/synapse/config/consent.py @@ -13,6 +13,7 @@ # limitations under the License. 
from os import path +from typing import Optional from synapse.config import ConfigError @@ -78,8 +79,8 @@ class ConsentConfig(Config): def __init__(self, *args): super().__init__(*args) - self.user_consent_version = None - self.user_consent_template_dir = None + self.user_consent_version: Optional[str] = None + self.user_consent_template_dir: Optional[str] = None self.user_consent_server_notice_content = None self.user_consent_server_notice_to_guests = False self.block_events_without_consent_error = None @@ -94,7 +95,9 @@ class ConsentConfig(Config): return self.user_consent_version = str(consent_config["version"]) self.user_consent_template_dir = self.abspath(consent_config["template_dir"]) - if not path.isdir(self.user_consent_template_dir): + if not isinstance(self.user_consent_template_dir, str) or not path.isdir( + self.user_consent_template_dir + ): raise ConfigError( "Could not find template directory '%s'" % (self.user_consent_template_dir,) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 4724565ba5..5a5f124ddf 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -47,7 +47,7 @@ class AccountValidityHandler: self.send_email_handler = self.hs.get_send_email_handler() self.clock = self.hs.get_clock() - self._app_name = self.hs.config.email_app_name + self._app_name = self.hs.config.email.email_app_name self._account_validity_enabled = ( hs.config.account_validity.account_validity_enabled diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index b7213b67a5..163278708c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -52,7 +52,7 @@ class ApplicationServicesHandler: self.scheduler = hs.get_application_service_scheduler() self.started_scheduler = False self.clock = hs.get_clock() - self.notify_appservices = hs.config.notify_appservices + self.notify_appservices = hs.config.appservice.notify_appservices self.event_sources = hs.get_event_sources() self.current_max = 0 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bcd4249e09..b747f80bc1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -210,15 +210,15 @@ class AuthHandler(BaseHandler): self.password_providers = [ PasswordProvider.load(module, config, account_handler) - for module, config in hs.config.password_providers + for module, config in hs.config.authproviders.password_providers ] logger.info("Extra password_providers: %s", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() - self._password_enabled = hs.config.password_enabled - self._password_localdb_enabled = hs.config.password_localdb_enabled + self._password_enabled = hs.config.auth.password_enabled + self._password_localdb_enabled = hs.config.auth.password_localdb_enabled # start out by assuming PASSWORD is enabled; we will remove it later if not. login_types = set() @@ -250,7 +250,7 @@ class AuthHandler(BaseHandler): ) # The number of seconds to keep a UI auth session active. 
- self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout + self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( @@ -739,19 +739,19 @@ class AuthHandler(BaseHandler): return canonical_id def _get_params_recaptcha(self) -> dict: - return {"public_key": self.hs.config.recaptcha_public_key} + return {"public_key": self.hs.config.captcha.recaptcha_public_key} def _get_params_terms(self) -> dict: return { "policies": { "privacy_policy": { - "version": self.hs.config.user_consent_version, + "version": self.hs.config.consent.user_consent_version, "en": { - "name": self.hs.config.user_consent_policy_name, + "name": self.hs.config.consent.user_consent_policy_name, "url": "%s_matrix/consent?v=%s" % ( self.hs.config.server.public_baseurl, - self.hs.config.user_consent_version, + self.hs.config.consent.user_consent_version, ), }, } @@ -1016,7 +1016,7 @@ class AuthHandler(BaseHandler): def can_change_password(self) -> bool: """Get whether users on this server are allowed to change or set a password. - Both `config.password_enabled` and `config.password_localdb_enabled` must be true. + Both `config.auth.password_enabled` and `config.auth.password_localdb_enabled` must be true. Note that any account (even SSO accounts) are allowed to add passwords if the above is true. @@ -1486,7 +1486,7 @@ class AuthHandler(BaseHandler): pw = unicodedata.normalize("NFKC", password) return bcrypt.hashpw( - pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.auth.password_pepper.encode("utf8"), bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") @@ -1510,7 +1510,7 @@ class AuthHandler(BaseHandler): pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( - pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.auth.password_pepper.encode("utf8"), checked_hash, ) diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index b0b188dc78..5d8f6c50a9 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -65,10 +65,10 @@ class CasHandler: self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() - self._cas_server_url = hs.config.cas_server_url - self._cas_service_url = hs.config.cas_service_url - self._cas_displayname_attribute = hs.config.cas_displayname_attribute - self._cas_required_attributes = hs.config.cas_required_attributes + self._cas_server_url = hs.config.cas.cas_server_url + self._cas_service_url = hs.config.cas.cas_service_url + self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute + self._cas_required_attributes = hs.config.cas.cas_required_attributes self._http_client = hs.get_proxied_http_client() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 8b8f1f41ca..fe8a995892 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -62,7 +62,7 @@ class IdentityHandler(BaseHandler): self.federation_http_client = hs.get_federation_http_client() self.hs = hs - self._web_client_location = hs.config.invite_client_location + self._web_client_location = hs.config.email.invite_client_location # Ratelimiters for `/requestToken` endpoints. 
self._3pid_validation_ratelimiter_ip = Ratelimiter( @@ -419,7 +419,7 @@ class IdentityHandler(BaseHandler): token_expires = ( self.hs.get_clock().time_msec() - + self.hs.config.email_validation_token_lifetime + + self.hs.config.email.email_validation_token_lifetime ) await self.store.start_or_continue_validation_session( @@ -465,7 +465,7 @@ class IdentityHandler(BaseHandler): if next_link: params["next_link"] = next_link - if self.hs.config.using_identity_server_from_trusted_list: + if self.hs.config.email.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use logger.warning( 'The config option "trust_identity_server_for_password_resets" ' @@ -518,7 +518,7 @@ class IdentityHandler(BaseHandler): if next_link: params["next_link"] = next_link - if self.hs.config.using_identity_server_from_trusted_list: + if self.hs.config.email.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use logger.warning( 'The config option "trust_identity_server_for_password_resets" ' @@ -572,12 +572,12 @@ class IdentityHandler(BaseHandler): validation_session = None # Try to validate as email - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: # Ask our delegated email identity server validation_session = await self.threepid_from_creds( self.hs.config.account_threepid_delegate_email, threepid_creds ) - elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details validation_session = await self.store.get_threepid_validation_session( "email", client_secret, sid=sid, validated=True diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7a5d8e6f4e..ad4e4a3d6f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -443,7 +443,7 @@ class EventCreationHandler: ) self._block_events_without_consent_error = ( - self.config.block_events_without_consent_error + self.config.consent.block_events_without_consent_error ) # we need to construct a ConsentURIBuilder here, as it checks that the necessary @@ -744,7 +744,7 @@ class EventCreationHandler: if u["appservice_id"] is not None: # users registered by an appservice are exempt return - if u["consent_version"] == self.config.user_consent_version: + if u["consent_version"] == self.config.consent.user_consent_version: return consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py index cd21efdcc6..eadd7ced09 100644 --- a/synapse/handlers/password_policy.py +++ b/synapse/handlers/password_policy.py @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) class PasswordPolicyHandler: def __init__(self, hs: "HomeServer"): - self.policy = hs.config.password_policy - self.enabled = hs.config.password_policy_enabled + self.policy = hs.config.auth.password_policy + self.enabled = hs.config.auth.password_policy_enabled # Regexps for the spec'd policy parameters. 
self.regexp_digit = re.compile("[0-9]") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 1c195c65db..01c5e1385d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -97,6 +97,7 @@ class RegistrationHandler(BaseHandler): self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() self._account_validity_handler = hs.get_account_validity_handler() + self._user_consent_version = self.hs.config.consent.user_consent_version self._server_notices_mxid = hs.config.server_notices_mxid self._server_name = hs.hostname @@ -339,7 +340,7 @@ class RegistrationHandler(BaseHandler): auth_provider=(auth_provider_id or ""), ).inc() - if not self.hs.config.user_consent_at_registration: + if not self.hs.config.consent.user_consent_at_registration: if not self.hs.config.auto_join_rooms_for_guests and make_guest: logger.info( "Skipping auto-join for %s because auto-join for guests is disabled", @@ -864,7 +865,9 @@ class RegistrationHandler(BaseHandler): await self._register_msisdn_threepid(user_id, threepid) if auth_result and LoginType.TERMS in auth_result: - await self._on_user_consented(user_id, self.hs.config.user_consent_version) + # The terms type should only exist if consent is enabled. + assert self._user_consent_version is not None + await self._on_user_consented(user_id, self._user_consent_version) async def _on_user_consented(self, user_id: str, consent_version: str) -> None: """A user consented to the terms on registration @@ -910,8 +913,8 @@ class RegistrationHandler(BaseHandler): # getting mail spam where they weren't before if email # notifs are set up on a homeserver) if ( - self.hs.config.email_enable_notifs - and self.hs.config.email_notif_for_new_users + self.hs.config.email.email_enable_notifs + and self.hs.config.email.email_notif_for_new_users and token ): # Pull the ID of the access token back out of the db diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index ea9325e96a..8f5d465fa1 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -82,10 +82,10 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self._enabled = bool(hs.config.recaptcha_private_key) + self._enabled = bool(hs.config.captcha.recaptcha_private_key) self._http_client = hs.get_proxied_http_client() - self._url = hs.config.recaptcha_siteverify_api - self._secret = hs.config.recaptcha_private_key + self._url = hs.config.captcha.recaptcha_siteverify_api + self._secret = hs.config.captcha.recaptcha_private_key def is_enabled(self) -> bool: return self._enabled @@ -161,12 +161,17 @@ class _BaseThreepidAuthChecker: self.hs.config.account_threepid_delegate_msisdn, threepid_creds ) elif medium == "email": - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if ( + self.hs.config.email.threepid_behaviour_email + == ThreepidBehaviour.REMOTE + ): assert self.hs.config.account_threepid_delegate_email threepid = await identity_handler.threepid_from_creds( self.hs.config.account_threepid_delegate_email, threepid_creds ) - elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + elif ( + self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL + ): threepid = None row = await self.store.get_threepid_validation_session( medium, @@ -218,7 +223,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec 
_BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return self.hs.config.threepid_behaviour_email in ( + return self.hs.config.email.threepid_behaviour_email in ( ThreepidBehaviour.REMOTE, ThreepidBehaviour.LOCAL, ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 174e6934a8..8ae21bc43c 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -119,14 +119,16 @@ class ModuleApi: self.custom_template_dir = hs.config.server.custom_template_directory try: - app_name = self._hs.config.email_app_name + app_name = self._hs.config.email.email_app_name - self._from_string = self._hs.config.email_notif_from % {"app": app_name} + self._from_string = self._hs.config.email.email_notif_from % { + "app": app_name + } except (KeyError, TypeError): # If substitution failed (which can happen if the string contains # placeholders other than just "app", or if the type of the placeholder is # not a string), fall back to the bare strings. - self._from_string = self._hs.config.email_notif_from + self._from_string = self._hs.config.email.email_notif_from self._raw_from = email.utils.parseaddr(self._from_string)[1] diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 29ed346d37..b57e094091 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -77,4 +77,4 @@ class PusherFactory: if isinstance(brand, str): return brand - return self.config.email_app_name + return self.config.email.email_app_name diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 681e491826..46bfec4623 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -368,8 +368,8 @@ class UserRestServletV2(RestServlet): user_id, medium, address, current_time ) if ( - self.hs.config.email_enable_notifs - and self.hs.config.email_notif_for_new_users + self.hs.config.email.email_enable_notifs + and self.hs.config.email.email_notif_for_new_users ): await self.pusher_pool.add_pusher( user_id=user_id, diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index aefaaa8ae8..6a7608d60b 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -64,17 +64,17 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.config = hs.config self.identity_handler = hs.get_identity_handler() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_password_reset_template_html, - template_text=self.config.email_password_reset_template_text, + app_name=self.config.email.email_app_name, + template_html=self.config.email.email_password_reset_template_html, + template_text=self.config.email.email_password_reset_template_text, ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.email.local_threepid_handling_disabled_due_to_email_config: logger.warning( "User password resets have been disabled due to lack of email config" ) @@ -129,7 +129,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) - if self.config.threepid_behaviour_email == 
ThreepidBehaviour.REMOTE: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request @@ -349,17 +349,17 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.store = self.hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_add_threepid_template_html, - template_text=self.config.email_add_threepid_template_text, + app_name=self.config.email.email_app_name, + template_html=self.config.email.email_add_threepid_template_html, + template_text=self.config.email.email_add_threepid_template_text, ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.email.local_threepid_handling_disabled_due_to_email_config: logger.warning( "Adding emails have been disabled due to lack of an email config" ) @@ -413,7 +413,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request @@ -534,21 +534,21 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( - self.config.email_add_threepid_template_failure_html + self.config.email.email_add_threepid_template_failure_html ) async def on_GET(self, request: Request) -> None: - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.email.local_threepid_handling_disabled_due_to_email_config: logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( 400, "Adding an email to your account is disabled on this server" ) - elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: raise SynapseError( 400, "This homeserver is not validating threepids. 
Use an identity server " @@ -575,7 +575,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_add_threepid_template_success_html_content + html = self.config.email.email_add_threepid_template_success_html_content status_code = 200 except ThreepidValidationError as e: status_code = e.code diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 7bb7801472..282861fae2 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -47,7 +47,7 @@ class AuthRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - self.recaptcha_template = hs.config.recaptcha_template + self.recaptcha_template = hs.config.captcha.recaptcha_template self.terms_template = hs.config.terms_template self.registration_token_template = hs.config.registration_token_template self.success_template = hs.config.fallback_success_template @@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet): session=session, myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - sitekey=self.hs.config.recaptcha_public_key, + sitekey=self.hs.config.captcha.recaptcha_public_key, ) elif stagetype == LoginType.TERMS: html = self.terms_template.render( @@ -70,7 +70,7 @@ class AuthRestServlet(RestServlet): terms_url="%s_matrix/consent?v=%s" % ( self.hs.config.server.public_baseurl, - self.hs.config.user_consent_version, + self.hs.config.consent.user_consent_version, ), myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), @@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet): session=session, myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - sitekey=self.hs.config.recaptcha_public_key, + sitekey=self.hs.config.captcha.recaptcha_public_key, error=e.msg, ) else: @@ -139,7 +139,7 @@ class AuthRestServlet(RestServlet): terms_url="%s_matrix/consent?v=%s" % ( self.hs.config.server.public_baseurl, - self.hs.config.user_consent_version, + self.hs.config.consent.user_consent_version, ), myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index a6ede7e2f3..d766e98dce 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -77,7 +77,7 @@ class LoginRestServlet(RestServlet): # SSO configuration. 
self.saml2_enabled = hs.config.saml2_enabled - self.cas_enabled = hs.config.cas_enabled + self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc_enabled self._msc2918_enabled = hs.config.access_token_lifetime is not None @@ -559,7 +559,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) - if hs.config.cas_enabled: + if hs.config.cas.cas_enabled: CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py index 0465fd2292..9f1908004b 100644 --- a/synapse/rest/client/password_policy.py +++ b/synapse/rest/client/password_policy.py @@ -35,8 +35,8 @@ class PasswordPolicyServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.policy = hs.config.password_policy - self.enabled = hs.config.password_policy_enabled + self.policy = hs.config.auth.password_policy + self.enabled = hs.config.auth.password_policy_enabled def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.enabled or not self.policy: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index abe4d7e205..48b0062cf4 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -75,17 +75,19 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.config = hs.config - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_registration_template_html, - template_text=self.config.email_registration_template_text, + app_name=self.config.email.email_app_name, + template_html=self.config.email.email_registration_template_html, + template_text=self.config.email.email_registration_template_text, ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.hs.config.local_threepid_handling_disabled_due_to_email_config: + if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if ( + self.hs.config.email.local_threepid_handling_disabled_due_to_email_config + ): logger.warning( "Email registration has been disabled due to lack of email config" ) @@ -137,7 +139,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request @@ -259,9 +261,9 @@ class RegistrationSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( - self.config.email_registration_template_failure_html + self.config.email.email_registration_template_failure_html ) async def on_GET(self, request: Request, medium: str) -> None: @@ -269,8 +271,8 @@ class RegistrationSubmitTokenServlet(RestServlet): raise SynapseError( 400, 
"This medium is currently not supported for registration" ) - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.email.local_threepid_handling_disabled_due_to_email_config: logger.warning( "User registration via email has been disabled due to lack of email config" ) @@ -303,7 +305,7 @@ class RegistrationSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_registration_template_success_html_content + html = self.config.email.email_registration_template_success_html_content status_code = 200 except ThreepidValidationError as e: status_code = e.code @@ -897,12 +899,12 @@ def _calculate_registration_flows( flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) # Prepend m.login.terms to all flows if we're requiring consent - if config.user_consent_at_registration: + if config.consent.user_consent_at_registration: for flow in flows: flow.insert(0, LoginType.TERMS) # Prepend recaptcha to all flows if we're requiring captcha - if config.enable_registration_captcha: + if config.captcha.enable_registration_captcha: for flow in flows: flow.insert(0, LoginType.RECAPTCHA) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 06e0fbde22..fc634a492d 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -84,14 +84,15 @@ class ConsentResource(DirectServeHtmlResource): # this is required by the request_handler wrapper self.clock = hs.get_clock() - self._default_consent_version = hs.config.user_consent_version - if self._default_consent_version is None: + # Consent must be configured to create this resource. + default_consent_version = hs.config.consent.user_consent_version + consent_template_directory = hs.config.consent.user_consent_template_dir + if default_consent_version is None or consent_template_directory is None: raise ConfigError( "Consent resource is enabled but user_consent section is " "missing in config file." 
) - - consent_template_directory = hs.config.user_consent_template_dir + self._default_consent_version = default_consent_version # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(consent_template_directory) diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index f2800bf2db..28a67f04e3 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -47,20 +47,20 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): self.store = hs.get_datastore() self._local_threepid_handling_disabled_due_to_email_config = ( - hs.config.local_threepid_handling_disabled_due_to_email_config + hs.config.email.local_threepid_handling_disabled_due_to_email_config ) self._confirmation_email_template = ( - hs.config.email_password_reset_template_confirmation_html + hs.config.email.email_password_reset_template_confirmation_html ) self._email_password_reset_template_success_html = ( - hs.config.email_password_reset_template_success_html_content + hs.config.email.email_password_reset_template_success_html_content ) self._failure_email_template = ( - hs.config.email_password_reset_template_failure_html + hs.config.email.email_password_reset_template_failure_html ) # This resource should not be mounted if threepid behaviour is not LOCAL - assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL + assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]: sid = parse_string(request, "sid", required=True) diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 4e0f814035..e09a25591f 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -36,9 +36,11 @@ class ConsentServerNotices: self._users_in_progress: Set[str] = set() - self._current_consent_version = hs.config.user_consent_version - self._server_notice_content = hs.config.user_consent_server_notice_content - self._send_to_guests = hs.config.user_consent_server_notice_to_guests + self._current_consent_version = hs.config.consent.user_consent_version + self._server_notice_content = ( + hs.config.consent.user_consent_server_notice_content + ) + self._send_to_guests = hs.config.consent.user_consent_server_notice_to_guests if self._server_notice_content is not None: if not self._server_notices_manager.is_enabled(): @@ -63,6 +65,9 @@ class ConsentServerNotices: # not enabled return + # A consent version must be given. 
+ assert self._current_consent_version is not None + # make sure we don't send two messages to the same user at once if user_id in self._users_in_progress: return diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index e2d1b758bd..2da2659f41 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -60,7 +60,7 @@ def _make_exclusive_regex( class ApplicationServiceWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.services_cache = load_appservices( - hs.hostname, hs.config.app_service_config_files + hs.hostname, hs.config.appservice.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index d213b26703..b76ee51a9b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -63,7 +63,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table - `config.track_appservice_user_ips` must be set to `true` for this + `config.appservice.track_appservice_user_ips` must be set to `true` for this method to return anything other than native matrix users. Returns: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index fafadb88fc..52ef9deede 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -388,7 +388,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "get_users_expiring_soon", select_users_txn, self._clock.time_msec(), - self.config.account_validity_renew_at, + self.config.account_validity.account_validity_renew_at, ) async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index d4754c904c..f31880b8ec 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -545,7 +545,7 @@ def _apply_module_schemas( database_engine: config: application config """ - for (mod, _config) in config.password_providers: + for (mod, _config) in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py index 8a1f340083..22a7901e15 100644 --- a/synapse/storage/schema/main/delta/30/as_users.py +++ b/synapse/storage/schema/main/delta/30/as_users.py @@ -33,7 +33,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): config_files = [] try: - config_files = config.app_service_config_files + config_files = config.appservice.app_service_config_files except AttributeError: logger.warning("Could not get app_service_config_files from config") pass diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index e798513ac1..0fa55e03b4 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -47,7 +47,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def prepare(self, 
reactor, clock, hs): self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" + hs.config.consent.user_consent_version = "1" consent_uri_builder = Mock() consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index f5c195a075..414c8781a9 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -97,7 +97,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] - self.hs.config.enable_registration_captcha = False + self.hs.config.captcha.enable_registration_captcha = False return self.hs diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 666bffe257..ebadf47948 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -41,9 +41,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = self.as_yaml_files + hs.config.appservice.app_service_config_files = self.as_yaml_files hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] self.as_token = "token1" self.as_url = "some_url" @@ -108,9 +107,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = self.as_yaml_files + hs.config.appservice.app_service_config_files = self.as_yaml_files hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -496,9 +494,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = [f1, f2] + hs.config.appservice.app_service_config_files = [f1, f2] hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] database = hs.get_datastores().databases[0] ApplicationServiceStore( @@ -514,7 +511,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = [f1, f2] + hs.config.appservice.app_service_config_files = [f1, f2] hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] @@ -540,9 +537,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = [f1, f2] + hs.config.appservice.app_service_config_files = [f1, f2] hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index da98733ce8..7cc5e621ba 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -258,7 +258,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() - homeserver.config.user_consent_version = self.CONSENT_VERSION + 
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION def test_send_dummy_event(self): self._create_extremity_rich_graph() From dcfd8649704bd0a05bfbffdd96d60fc2b1913a2f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 23 Sep 2021 13:02:13 +0100 Subject: [PATCH 12/31] Fix reactivated users not being added to the user directory (#10782) Co-authored-by: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Co-authored-by: reivilibre Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/10782.bugfix | 1 + synapse/handlers/deactivate_account.py | 9 ++++-- tests/handlers/test_user_directory.py | 42 +++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10782.bugfix diff --git a/changelog.d/10782.bugfix b/changelog.d/10782.bugfix new file mode 100644 index 0000000000..3e410447cc --- /dev/null +++ b/changelog.d/10782.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which caused deactivated users that were later reactivated to be missing from the user directory. \ No newline at end of file diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index a03ff9842b..9ae5b7750e 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -255,13 +255,16 @@ class DeactivateAccountHandler(BaseHandler): Args: user_id: ID of user to be re-activated """ - # Add the user to the directory, if necessary. user = UserID.from_string(user_id) - profile = await self.store.get_profileinfo(user.localpart) - await self.user_directory_handler.handle_local_profile_change(user_id, profile) # Ensure the user is not marked as erased. await self.store.mark_user_not_erased(user_id) # Mark the user as active. await self.store.set_user_deactivated_status(user_id, False) + + # Add the user to the directory, if necessary. Note that + # this must be done after the user is re-activated, because + # deactivated users are excluded from the user directory. + profile = await self.store.get_profileinfo(user.localpart) + await self.user_directory_handler.handle_local_profile_change(user_id, profile) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index ae88ed89aa..f3684c34a2 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from unittest.mock import Mock +from urllib.parse import quote from twisted.internet import defer @@ -20,6 +21,7 @@ from synapse.api.constants import UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.rest.client import login, room, user_directory from synapse.storage.roommember import ProfileInfo +from synapse.types import create_requester from tests import unittest from tests.unittest import override_config @@ -32,7 +34,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, + synapse.rest.admin.register_servlets, room.register_servlets, ] @@ -130,6 +132,44 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.handle_local_user_deactivated(r_user_id)) self.store.remove_from_user_dir.called_once_with(r_user_id) + def test_reactivation_makes_regular_user_searchable(self): + user = self.register_user("regular", "pass") + user_token = self.login(user, "pass") + admin_user = self.register_user("admin", "pass", admin=True) + admin_token = self.login(admin_user, "pass") + + # Ensure the regular user is publicly visible and searchable. + self.helper.create_room_as(user, is_public=True, tok=user_token) + s = self.get_success(self.handler.search_users(admin_user, user, 10)) + self.assertEqual(len(s["results"]), 1) + self.assertEqual(s["results"][0]["user_id"], user) + + # Deactivate the user and check they're not searchable. + deactivate_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_handler.deactivate_account( + user, erase_data=False, requester=create_requester(admin_user) + ) + ) + s = self.get_success(self.handler.search_users(admin_user, user, 10)) + self.assertEqual(s["results"], []) + + # Reactivate the user + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{quote(user)}", + access_token=admin_token, + content={"deactivated": False, "password": "pass"}, + ) + self.assertEqual(channel.code, 200) + user_token = self.login(user, "pass") + self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check they're searchable. + s = self.get_success(self.handler.search_users(admin_user, user, 10)) + self.assertEqual(len(s["results"]), 1) + self.assertEqual(s["results"][0]["user_id"], user) + def test_private_room(self): """ A user can be searched for only by people that are either in a public From a10988983a1cd145fc5ae57c9a00ea95fbaece61 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 23 Sep 2021 14:45:32 +0100 Subject: [PATCH 13/31] Break down cache expiry reasons in grafana (#10880) A follow-up to #10829 --- changelog.d/10880.misc | 1 + contrib/grafana/synapse.json | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10880.misc diff --git a/changelog.d/10880.misc b/changelog.d/10880.misc new file mode 100644 index 0000000000..5f58d6198c --- /dev/null +++ b/changelog.d/10880.misc @@ -0,0 +1 @@ +Break down Grafana's cache expiry time series based on reason for eviction---see #10829. 
\ No newline at end of file diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index ed1e8ba7f8..2c839c30d0 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -6785,7 +6785,7 @@ "expr": "rate(synapse_util_caches_cache:evicted_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "time_series", "intervalFactor": 1, - "legendFormat": "{{name}} {{job}}-{{index}}", + "legendFormat": "{{name}} ({{reason}}) {{job}}-{{index}}", "refId": "A" } ], @@ -10888,5 +10888,5 @@ "timezone": "", "title": "Synapse", "uid": "000000012", - "version": 99 + "version": 100 } \ No newline at end of file From 47854c71e9bded2c446a251f3ef16f4d5da96ebe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Sep 2021 12:03:01 -0400 Subject: [PATCH 14/31] Use direct references for configuration variables (part 4). (#10893) --- changelog.d/10893.misc | 1 + synapse/api/urls.py | 4 ++-- synapse/app/_base.py | 6 ++++-- synapse/app/admin_cmd.py | 2 +- synapse/app/generic_worker.py | 4 ++-- synapse/app/homeserver.py | 10 +++++----- synapse/app/phone_stats_home.py | 8 +++++--- synapse/config/logger.py | 2 +- synapse/federation/transport/server/_base.py | 4 +++- synapse/groups/groups_server.py | 6 +++--- synapse/handlers/auth.py | 2 +- synapse/handlers/oidc.py | 2 +- synapse/handlers/profile.py | 2 +- synapse/http/matrixfederationclient.py | 5 +++-- synapse/push/httppusher.py | 4 +++- synapse/rest/client/login.py | 12 ++++++------ synapse/rest/consent/consent_resource.py | 4 ++-- synapse/rest/key/v2/local_key_resource.py | 10 +++++----- synapse/rest/key/v2/remote_key_resource.py | 6 ++++-- synapse/rest/media/v1/media_repository.py | 4 +++- synapse/rest/synapse/client/__init__.py | 2 +- synapse/storage/databases/main/roommember.py | 2 +- tests/api/test_auth.py | 4 ++-- tests/app/test_phone_stats_home.py | 2 +- tests/config/test_load.py | 10 +++++----- tests/config/test_ratelimiting.py | 2 +- tests/handlers/test_auth.py | 2 +- tests/replication/_base.py | 2 +- tests/rest/client/test_login.py | 12 ++++++------ tests/rest/client/test_register.py | 2 +- tests/storage/test_appservice.py | 1 - tests/util/test_ratelimitutils.py | 2 +- 32 files changed, 77 insertions(+), 64 deletions(-) create mode 100644 changelog.d/10893.misc diff --git a/changelog.d/10893.misc b/changelog.d/10893.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10893.misc @@ -0,0 +1 @@ +Use direct references to config flags. 
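One concrete payoff shows up in the `synapse/api/urls.py` hunk just below: once `form_secret` is read off the named sub-config, mypy sees its `Optional[str]` declaration and can narrow it after the None check. A simplified, self-contained sketch of that shape (`KeyConfig` here is a stand-in for illustration, not necessarily the real class):

    from typing import Optional

    class KeyConfig:
        # Declared Optional: unset until populated from the config file.
        form_secret: Optional[str] = None

    def hmac_secret(key: KeyConfig) -> bytes:
        # mypy narrows form_secret from Optional[str] to str past this check.
        if key.form_secret is None:
            raise ValueError("form_secret not set in config")
        return key.form_secret.encode("utf-8")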
diff --git a/synapse/api/urls.py b/synapse/api/urls.py index d3270cd6d2..032c69b210 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -39,12 +39,12 @@ class ConsentURIBuilder: Args: hs_config (synapse.config.homeserver.HomeServerConfig): """ - if hs_config.form_secret is None: + if hs_config.key.form_secret is None: raise ConfigError("form_secret not set in config") if hs_config.server.public_baseurl is None: raise ConfigError("public_baseurl not set in config") - self._hmac_secret = hs_config.form_secret.encode("utf-8") + self._hmac_secret = hs_config.key.form_secret.encode("utf-8") self._public_baseurl = hs_config.server.public_baseurl def build_user_consent_uri(self, user_id): diff --git a/synapse/app/_base.py b/synapse/app/_base.py index d1aa2e7fb5..f657f11f76 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -424,12 +424,14 @@ def setup_sentry(hs): hs (synapse.server.HomeServer) """ - if not hs.config.sentry_enabled: + if not hs.config.metrics.sentry_enabled: return import sentry_sdk - sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse)) + sentry_sdk.init( + dsn=hs.config.metrics.sentry_dsn, release=get_version_string(synapse) + ) # We set some default tags that give some context to this instance with sentry_sdk.configure_scope() as scope: diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 5e956b1e27..259d5ec7cc 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -192,7 +192,7 @@ def start(config_options): ): # Since we're meant to be run as a "command" let's not redirect stdio # unless we've actually set log config. - config.no_redirect_stdio = True + config.logging.no_redirect_stdio = True # Explicitly disable background processes config.update_user_directory = False diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 33afd59c72..e0776689ce 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -395,7 +395,7 @@ class GenericWorkerServer(HomeServer): manhole_globals={"hs": self}, ) elif listener.type == "metrics": - if not self.config.enable_metrics: + if not self.config.metrics.enable_metrics: logger.warning( "Metrics listener configured, but " "enable_metrics is not True!" @@ -488,7 +488,7 @@ def start(config_options): register_start(_base.start, hs) # redirect stdio to the logs, if configured. - if not hs.config.no_redirect_stdio: + if not hs.config.logging.no_redirect_stdio: redirect_stdio_to_logs() _base.start_worker_reactor("synapse-generic-worker", config) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 886e291e4c..f1769f146b 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -269,7 +269,7 @@ class SynapseHomeServer(HomeServer): # https://twistedmatrix.com/trac/ticket/7678 resources[WEB_CLIENT_PREFIX] = File(webclient_loc) - if name == "metrics" and self.config.enable_metrics: + if name == "metrics" and self.config.metrics.enable_metrics: resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) if name == "replication": @@ -278,7 +278,7 @@ class SynapseHomeServer(HomeServer): return resources def start_listening(self): - if self.config.redis_enabled: + if self.config.redis.redis_enabled: # If redis is enabled we connect via the replication command handler # in the same way as the workers (since we're effectively a client # rather than a server). 
@@ -305,7 +305,7 @@ class SynapseHomeServer(HomeServer): for s in services: reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener.type == "metrics": - if not self.config.enable_metrics: + if not self.config.metrics.enable_metrics: logger.warning( "Metrics listener configured, but " "enable_metrics is not True!" @@ -366,7 +366,7 @@ def setup(config_options): async def start(): # Load the OIDC provider metadatas, if OIDC is enabled. - if hs.config.oidc_enabled: + if hs.config.oidc.oidc_enabled: oidc = hs.get_oidc_handler() # Loading the provider metadata also ensures the provider config is valid. await oidc.load_metadata() @@ -455,7 +455,7 @@ def main(): hs = setup(sys.argv[1:]) # redirect stdio to the logs, if configured. - if not hs.config.no_redirect_stdio: + if not hs.config.logging.no_redirect_stdio: redirect_stdio_to_logs() run(hs) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 4a95da90f9..49e7a45e5c 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -131,10 +131,12 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): log_level = synapse_logger.getEffectiveLevel() stats["log_level"] = logging.getLevelName(log_level) - logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + logger.info( + "Reporting stats to %s: %s" % (hs.config.metrics.report_stats_endpoint, stats) + ) try: await hs.get_proxied_http_client().put_json( - hs.config.report_stats_endpoint, stats + hs.config.metrics.report_stats_endpoint, stats ) except Exception as e: logger.warning("Error reporting stats: %s", e) @@ -188,7 +190,7 @@ def start_phone_stats_home(hs): clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) # End of monthly active user settings - if hs.config.report_stats: + if hs.config.metrics.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index aca9d467e6..bf8ca7d5fe 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -322,7 +322,7 @@ def setup_logging( """ log_config_path = ( - config.worker_log_config if use_worker_options else config.log_config + config.worker_log_config if use_worker_options else config.logging.log_config ) # Perform one-time logging configuration. 
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 624c859f1e..cef65929c5 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -49,7 +49,9 @@ class Authenticator: self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) self.notifier = hs.get_notifier() self.replication_client = None diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index d6b75ac27f..449bbc7004 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -847,16 +847,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler): UserID.from_string(requester_user_id) ) if not is_admin: - if not self.hs.config.enable_group_creation: + if not self.hs.config.groups.enable_group_creation: raise SynapseError( 403, "Only a server admin can create groups on this server" ) localpart = group_id_obj.localpart - if not localpart.startswith(self.hs.config.group_creation_prefix): + if not localpart.startswith(self.hs.config.groups.group_creation_prefix): raise SynapseError( 400, "Can only create groups with prefix %r on this server" - % (self.hs.config.group_creation_prefix,), + % (self.hs.config.groups.group_creation_prefix,), ) profile = content.get("profile", {}) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b747f80bc1..0f80dfdc43 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1802,7 +1802,7 @@ class MacaroonGenerator: macaroon = pymacaroons.Macaroon( location=self.hs.config.server.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key, + key=self.hs.config.key.macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index aed5a40a78..3665d91513 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -277,7 +277,7 @@ class OidcProvider: self._token_generator = token_generator self._config = provider - self._callback_url: str = hs.config.oidc_callback_url + self._callback_url: str = hs.config.oidc.oidc_callback_url # Calculate the prefix for OIDC callback paths based on the public_baseurl. # We'll insert this into the Path= parameter of any session cookies we set. 
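The `MacaroonGenerator` hunk earlier in this patch only changes where the signing secret is read from (`hs.config.key.macaroon_secret_key`); the minting pattern itself is unchanged. For context, a rough pymacaroons round-trip of that pattern might look like this (the server name, user id, and secret below are made-up values, not anything from this patch):

    import pymacaroons

    secret_key = "some-macaroon-secret"  # stand-in for macaroon_secret_key

    # Mint: bind the macaroon to this server, then attach first-party
    # caveats describing what it is valid for.
    macaroon = pymacaroons.Macaroon(
        location="example.com",
        identifier="key",
        key=secret_key,
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = @alice:example.com")

    # Verify: every caveat must be satisfied and the signature must
    # match the same secret key.
    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_general(lambda caveat: caveat.startswith("user_id = "))
    assert v.verify(macaroon, secret_key)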
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index f06070bfcf..b23a1541bc 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -309,7 +309,7 @@ class ProfileHandler(BaseHandler): async def on_profile_query(self, args: JsonDict) -> JsonDict: """Handles federation profile query requests.""" - if not self.hs.config.allow_profile_lookup_over_federation: + if not self.hs.config.federation.allow_profile_lookup_over_federation: raise SynapseError( 403, "Profile lookup over federation is disabled on this homeserver", diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index e56fa477bb..cdc36b8d25 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -465,8 +465,9 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout if ( - self.hs.config.federation_domain_whitelist is not None - and request.destination not in self.hs.config.federation_domain_whitelist + self.hs.config.federation.federation_domain_whitelist is not None + and request.destination + not in self.hs.config.federation.federation_domain_whitelist ): raise FederationDeniedError(request.destination) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 065948f982..eac65572b2 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -73,7 +73,9 @@ class HttpPusher(Pusher): self.failing_since = pusher_config.failing_since self.timed_call: Optional[IDelayedCall] = None self._is_processing = False - self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room + self._group_unread_count_by_room = ( + hs.config.push.push_group_unread_count_by_room + ) self._pusherpool = hs.get_pusherpool() self.data = pusher_config.data diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index d766e98dce..64446fc486 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -69,16 +69,16 @@ class LoginRestServlet(RestServlet): self.hs = hs # JWT configuration variables. - self.jwt_enabled = hs.config.jwt_enabled - self.jwt_secret = hs.config.jwt_secret - self.jwt_algorithm = hs.config.jwt_algorithm - self.jwt_issuer = hs.config.jwt_issuer - self.jwt_audiences = hs.config.jwt_audiences + self.jwt_enabled = hs.config.jwt.jwt_enabled + self.jwt_secret = hs.config.jwt.jwt_secret + self.jwt_algorithm = hs.config.jwt.jwt_algorithm + self.jwt_issuer = hs.config.jwt.jwt_issuer + self.jwt_audiences = hs.config.jwt.jwt_audiences # SSO configuration. self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled - self.oidc_enabled = hs.config.oidc_enabled + self.oidc_enabled = hs.config.oidc.oidc_enabled self._msc2918_enabled = hs.config.access_token_lifetime is not None self.auth = hs.get_auth() diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index fc634a492d..3d2afacc50 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -100,13 +100,13 @@ class ConsentResource(DirectServeHtmlResource): loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) ) - if hs.config.form_secret is None: + if hs.config.key.form_secret is None: raise ConfigError( "Consent resource is enabled but form_secret is not set in " "config file. It should be set to an arbitrary secret string." 
) - self._hmac_secret = hs.config.form_secret.encode("utf-8") + self._hmac_secret = hs.config.key.form_secret.encode("utf-8") async def _async_render_GET(self, request: Request) -> None: version = parse_string(request, "v", default=self._default_consent_version) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index ebe243bcfd..12b3ae120c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -70,19 +70,19 @@ class LocalKey(Resource): Resource.__init__(self) def update_response_body(self, time_now_msec: int) -> None: - refresh_interval = self.config.key_refresh_interval + refresh_interval = self.config.key.key_refresh_interval self.valid_until_ts = int(time_now_msec + refresh_interval) self.response_body = encode_canonical_json(self.response_json_object()) def response_json_object(self) -> JsonDict: verify_keys = {} - for key in self.config.signing_key: + for key in self.config.key.signing_key: verify_key_bytes = key.verify_key.encode() key_id = "%s:%s" % (key.alg, key.version) verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)} old_verify_keys = {} - for key_id, key in self.config.old_signing_keys.items(): + for key_id, key in self.config.key.old_signing_keys.items(): verify_key_bytes = key.encode() old_verify_keys[key_id] = { "key": encode_base64(verify_key_bytes), @@ -95,13 +95,13 @@ class LocalKey(Resource): "verify_keys": verify_keys, "old_verify_keys": old_verify_keys, } - for key in self.config.signing_key: + for key in self.config.key.signing_key: json_object = sign_json(json_object, self.config.server.server_name, key) return json_object def render_GET(self, request: Request) -> int: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. 
- if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts: + if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) return respond_with_json_bytes(request, 200, self.response_body) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index d8fd7938a4..c111a9d20f 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -97,7 +97,9 @@ class RemoteKey(DirectServeJsonResource): self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) self.config = hs.config async def _async_render_GET(self, request: Request) -> None: @@ -235,7 +237,7 @@ class RemoteKey(DirectServeJsonResource): signed_keys = [] for key_json in json_results: key_json = json_decoder.decode(key_json.decode("utf-8")) - for signing_key in self.config.key_server_signing_keys: + for signing_key in self.config.key.key_server_signing_keys: key_json = sign_json( key_json, self.config.server.server_name, signing_key ) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 50e4c9e29f..a30007a1e2 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -92,7 +92,9 @@ class MediaRepository: self.recently_accessed_remotes: Set[Tuple[str, str]] = set() self.recently_accessed_locals: Set[str] = set() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) # List of StorageProviders where we should search for media and # potentially upload to. diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 47a2f72b32..086c80b723 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -45,7 +45,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc # provider-specific SSO bits. Only load these if they are enabled, since they # rely on optional dependencies. 
- if hs.config.oidc_enabled: + if hs.config.oidc.oidc_enabled: from synapse.rest.synapse.client.oidc import OIDCResource resources["/_synapse/client/oidc"] = OIDCResource(hs) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index a4ec6bc328..ddb162a4fc 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -82,7 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): if ( self.hs.config.worker.run_background_tasks - and self.hs.config.metrics_flags.known_servers + and self.hs.config.metrics.metrics_flags.known_servers ): self._known_servers_count = 1 self.hs.get_clock().looping_call( diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index f76fea4f66..8a4ef13054 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -217,7 +217,7 @@ class AuthTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key, + key=self.hs.config.key.macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") @@ -239,7 +239,7 @@ class AuthTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key, + key=self.hs.config.key.macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index d66aeb00eb..19eb4c79d0 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -172,7 +172,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # We don't want our tests to actually report statistics, so check # that it's not enabled - assert not hs.config.report_stats + assert not hs.config.metrics.report_stats # This starts the needed data collection that we rely on to calculate # R30v2 metrics. diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 903c69127d..ef6c2beec7 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -52,10 +52,10 @@ class ConfigLoadingTestCase(unittest.TestCase): hasattr(config, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) - if len(config.macaroon_secret_key) < 5: + if len(config.key.macaroon_secret_key) < 5: self.fail( "Want macaroon secret key to be string of at least length 5," - "was: %r" % (config.macaroon_secret_key,) + "was: %r" % (config.key.macaroon_secret_key,) ) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) @@ -63,10 +63,10 @@ class ConfigLoadingTestCase(unittest.TestCase): hasattr(config, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) - if len(config.macaroon_secret_key) < 5: + if len(config.key.macaroon_secret_key) < 5: self.fail( "Want macaroon secret key to be string of at least length 5," - "was: %r" % (config.macaroon_secret_key,) + "was: %r" % (config.key.macaroon_secret_key,) ) def test_load_succeeds_if_macaroon_secret_key_missing(self): @@ -101,7 +101,7 @@ class ConfigLoadingTestCase(unittest.TestCase): # The default Metrics Flags are off by default. 
config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.metrics_flags.known_servers) + self.assertFalse(config.metrics.metrics_flags.known_servers) def generate_config(self): with redirect_stdout(StringIO()): diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py index 3c7bb32e07..1b63e1adfd 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py @@ -30,7 +30,7 @@ class RatelimitConfigTestCase(TestCase): config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") - config_obj = config.rc_federation + config_obj = config.ratelimiting.rc_federation self.assertEqual(config_obj.window_size, 20000) self.assertEqual(config_obj.sleep_limit, 693) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 5f3350e490..12857053e7 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -67,7 +67,7 @@ class AuthTestCase(unittest.HomeserverTestCase): v.satisfy_general(verify_type) v.satisfy_general(verify_nonce) v.satisfy_general(verify_guest) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + v.verify(macaroon, self.hs.config.key.macaroon_secret_key) def test_short_term_login_token_gives_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( diff --git a/tests/replication/_base.py b/tests/replication/_base.py index e9fd991718..c7555c26db 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -328,7 +328,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up TCP replication between master and the new worker if we don't # have Redis support enabled. - if not worker_hs.config.redis_enabled: + if not worker_hs.config.redis.redis_enabled: repl_handler = ReplicationCommandHandler(worker_hs) client = ClientReplicationStreamProtocol( worker_hs, diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 414c8781a9..371615a015 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -815,9 +815,9 @@ class JWTTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_secret - self.hs.config.jwt_algorithm = self.jwt_algorithm + self.hs.config.jwt.jwt_enabled = True + self.hs.config.jwt.jwt_secret = self.jwt_secret + self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm return self.hs def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: @@ -1023,9 +1023,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_pubkey - self.hs.config.jwt_algorithm = "RS256" + self.hs.config.jwt.jwt_enabled = True + self.hs.config.jwt.jwt_secret = self.jwt_pubkey + self.hs.config.jwt.jwt_algorithm = "RS256" return self.hs def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 9f3ab2c985..72a5a11b46 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -146,7 +146,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self): - self.hs.config.macaroon_secret_key = "test" + 
self.hs.config.key.macaroon_secret_key = "test" self.hs.config.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ebadf47948..cf9748f218 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -513,7 +513,6 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.appservice.app_service_config_files = [f1, f2] hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 34aaffe859..89d8656634 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -95,4 +95,4 @@ def build_rc_config(settings: Optional[dict] = None): config_dict.update(settings or {}) config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") - return config.rc_federation + return config.ratelimiting.rc_federation From a7304adc7d383caad1b3f83fa707b1090323ecca Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 23 Sep 2021 17:34:33 +0100 Subject: [PATCH 15/31] Factor out `_get_remote_auth_chain_for_event` from `_update_auth_events_and_context_for_auth` (#10884) * Reload auth events from db after fetching and persisting In `_update_auth_events_and_context_for_auth`, when we fetch the remote auth tree and persist the returned events: load the missing events from the database rather than using the copies we got from the remote server. This is mostly in preparation for additional refactors, but does have an advantage in that if we later get around to checking the rejected status, we'll be able to make use of it. * Factor out `_get_remote_auth_chain_for_event` from `_update_auth_events_and_context_for_auth` * changelog --- changelog.d/10884.misc | 1 + synapse/handlers/federation_event.py | 124 ++++++++++++++++----------- 2 files changed, 73 insertions(+), 52 deletions(-) create mode 100644 changelog.d/10884.misc diff --git a/changelog.d/10884.misc b/changelog.d/10884.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10884.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 10b3fdc222..7d468bd2df 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1505,61 +1505,22 @@ class FederationEventHandler: # If we don't have all the auth events, we need to get them. logger.info("auth_events contains unknown events: %s", missing_auth) try: - try: - remote_auth_chain = await self._federation_client.get_event_auth( - origin, event.room_id, event.event_id - ) - except RequestSendFailed as e1: - # The other side isn't around or doesn't implement the - # endpoint, so lets just bail out. 
- logger.info("Failed to get event auth from remote: %s", e1) - return context, auth_events - - seen_remotes = await self._store.have_seen_events( - event.room_id, [e.event_id for e in remote_auth_chain] + await self._get_remote_auth_chain_for_event( + origin, event.room_id, event.event_id ) - - for auth_event in remote_auth_chain: - if auth_event.event_id in seen_remotes: - continue - - if auth_event.event_id == event.event_id: - continue - - try: - auth_ids = auth_event.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in remote_auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - auth_event.internal_metadata.outlier = True - - logger.debug( - "_check_event_auth %s missing_auth: %s", - event.event_id, - auth_event.event_id, - ) - missing_auth_event_context = EventContext.for_outlier() - missing_auth_event_context = await self._check_event_auth( - origin, - auth_event, - missing_auth_event_context, - claimed_auth_event_map=auth, - ) - await self.persist_events_and_notify( - event.room_id, [(auth_event, missing_auth_event_context)] - ) - - if auth_event.event_id in event_auth_events: - auth_events[ - (auth_event.type, auth_event.state_key) - ] = auth_event - except AuthError: - pass - except Exception: logger.exception("Failed to get auth chain") + else: + # load any auth events we might have persisted from the database. This + # has the side-effect of correctly setting the rejected_reason on them. + auth_events.update( + { + (ae.type, ae.state_key): ae + for ae in await self._store.get_events_as_list( + missing_auth, allow_rejected=True + ) + } + ) if event.internal_metadata.is_outlier(): # XXX: given that, for an outlier, we'll be working with the @@ -1633,6 +1594,65 @@ class FederationEventHandler: return context, auth_events + async def _get_remote_auth_chain_for_event( + self, destination: str, room_id: str, event_id: str + ) -> None: + """If we are missing some of an event's auth events, attempt to request them + + Args: + destination: where to fetch the auth tree from + room_id: the room in which we are lacking auth events + event_id: the event for which we are lacking auth events + """ + try: + remote_auth_chain = await self._federation_client.get_event_auth( + destination, room_id, event_id + ) + except RequestSendFailed as e1: + # The other side isn't around or doesn't implement the + # endpoint, so let's just bail out.
+ logger.info("Failed to get event auth from remote: %s", e1) + return + + seen_remotes = await self._store.have_seen_events( + room_id, [e.event_id for e in remote_auth_chain] + ) + + for auth_event in remote_auth_chain: + if auth_event.event_id in seen_remotes: + continue + + if auth_event.event_id == event_id: + continue + + try: + auth_ids = auth_event.auth_event_ids() + auth = { + (e.type, e.state_key): e + for e in remote_auth_chain + if e.event_id in auth_ids or e.type == EventTypes.Create + } + auth_event.internal_metadata.outlier = True + + logger.debug( + "_check_event_auth %s missing_auth: %s", + event_id, + auth_event.event_id, + ) + missing_auth_event_context = EventContext.for_outlier() + missing_auth_event_context = await self._check_event_auth( + destination, + auth_event, + missing_auth_event_context, + claimed_auth_event_map=auth, + ) + await self.persist_events_and_notify( + room_id, + [(auth_event, missing_auth_event_context)], + ) + except AuthError: + pass + async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] ) -> EventContext: From 90d9fc750514b1ede327f1dfe6e0a1c09b281d6d Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Thu, 23 Sep 2021 18:58:12 +0100 Subject: [PATCH 16/31] Allow `.` and `~` chars in registration tokens (#10887) Per updates to MSC3231 in order to use the same grammar as other identifiers. --- changelog.d/10887.bugfix | 1 + synapse/rest/admin/registration_tokens.py | 2 +- tests/rest/admin/test_registration_tokens.py | 8 +++++--- 3 files changed, 7 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10887.bugfix diff --git a/changelog.d/10887.bugfix b/changelog.d/10887.bugfix new file mode 100644 index 0000000000..2d1f67489a --- /dev/null +++ b/changelog.d/10887.bugfix @@ -0,0 +1 @@ +Allow the `.` and `~` characters when creating registration tokens as per the change to [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231). 
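Per the MSC3231 update, the token alphabet below becomes the RFC 3986 unreserved characters: letters, digits, and `._~-`. A minimal sketch of validating a candidate token against that alphabet (the `is_valid_registration_token` helper is hypothetical; the 64-character bound is the one used by the test below):

    import string

    # The unreserved characters, matching the patch below.
    ALLOWED_CHARS = set(string.ascii_letters + string.digits + "._~-")


    def is_valid_registration_token(token: str) -> bool:
        # Hypothetical check: non-empty, at most 64 characters, and drawn
        # entirely from the allowed alphabet.
        return 0 < len(token) <= 64 and all(c in ALLOWED_CHARS for c in token)


    assert is_valid_registration_token("a.b~c_d-e")
    assert not is_valid_registration_token("has space")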
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 5a1c929d85..aba48f6e7b 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -113,7 +113,7 @@ class NewRegistrationTokenRestServlet(RestServlet): self.store = hs.get_datastore() self.clock = hs.get_clock() # A string of all the characters allowed to be in a registration_token - self.allowed_chars = string.ascii_letters + string.digits + "-_" + self.allowed_chars = string.ascii_letters + string.digits + "._~-" self.allowed_chars_set = set(self.allowed_chars) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 4927321e5a..9bac423ae0 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -95,8 +95,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_create_specifying_fields(self): """Create a token specifying the value of all fields.""" + # As many of the allowed characters as possible with length <= 64 + token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" data = { - "token": "abcd", + "token": token, "uses_allowed": 1, "expiry_time": self.clock.time_msec() + 1000000, } @@ -109,7 +111,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(channel.json_body["token"], "abcd") + self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) @@ -193,7 +195,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): """Check right error is raised when server can't generate unique token.""" # Create all possible single character tokens tokens = [] - for c in string.ascii_letters + string.digits + "-_": + for c in string.ascii_letters + string.digits + "._~-": tokens.append( { "token": c, From e704cc2a48c6adc5d3da79a49ed02961edfc3b4a Mon Sep 17 00:00:00 2001 From: Kokokokoka Date: Fri, 24 Sep 2021 12:19:51 +0300 Subject: [PATCH 17/31] In `_purge_history_txn`, ensure that txn.fetchall has elements before accessing rows (#10690) This change adds a check that rows exist before accessing their elements, which should fix issue #10669. Signed-off-by: Vasya Boytsov vasiliy.boytsov@phystech.edu --- changelog.d/10690.bugfix | 1 + .../storage/databases/main/purge_events.py | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 changelog.d/10690.bugfix diff --git a/changelog.d/10690.bugfix b/changelog.d/10690.bugfix new file mode 100644 index 0000000000..059eea7464 --- /dev/null +++ b/changelog.d/10690.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug that caused an `AssertionError` when purging history in certain rooms. Contributed by @Kokokokoka.
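The fix below is a guard-before-aggregate change: in Python, `max()` over an empty sequence raises `ValueError`, so the depth check must only run when the extremities query actually returned rows. A standalone illustration of the guard, using hypothetical data rather than Synapse code:

    rows = []  # e.g. a room whose forward extremities were already cleared out

    # Unguarded: max() over an empty generator raises ValueError.
    try:
        max_depth = max(row[1] for row in rows)
    except ValueError:
        max_depth = None

    # Guarded, as in the patch: only aggregate and compare when rows exist.
    if rows:
        max_depth = max(row[1] for row in rows)
        # ... then compare max_depth against the purge point, as below ...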
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index bccff5e5b9..3eb30944bf 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -102,15 +102,19 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): (room_id,), ) rows = txn.fetchall() - max_depth = max(row[1] for row in rows) + # if we already have no forwards extremities (for example because they were + # cleared out by the `delete_old_current_state_events` background database + # update), then we may as well carry on. + if rows: + max_depth = max(row[1] for row in rows) - if max_depth < token.topological: - # We need to ensure we don't delete all the events from the database - # otherwise we wouldn't be able to send any events (due to not - # having any backwards extremities) - raise SynapseError( - 400, "topological_ordering is greater than forward extremeties" - ) + if max_depth < token.topological: + # We need to ensure we don't delete all the events from the database + # otherwise we wouldn't be able to send any events (due to not + # having any backwards extremities) + raise SynapseError( + 400, "topological_ordering is greater than forward extremities" + ) logger.info("[purge] looking for events to delete") From 7f3352743e02e0d02ec00eb3a50fd0ceb422286c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 24 Sep 2021 10:38:22 +0100 Subject: [PATCH 18/31] Improve typing in user_directory files (#10891) * Improve typing in user_directory files This makes the user_directory.py in storage pass most of mypy's checks (including `no-untyped-defs`). Unfortunately that file is in the tangled web of Store class inheritance so doesn't pass mypy at the moment. The handlers directory has already been mypyed. Co-authored-by: reivilibre --- changelog.d/10891.misc | 1 + mypy.ini | 2 + .../storage/databases/main/user_directory.py | 124 +++++++++++++----- tests/handlers/test_user_directory.py | 5 +- 4 files changed, 95 insertions(+), 37 deletions(-) create mode 100644 changelog.d/10891.misc diff --git a/changelog.d/10891.misc b/changelog.d/10891.misc new file mode 100644 index 0000000000..6eecea4065 --- /dev/null +++ b/changelog.d/10891.misc @@ -0,0 +1 @@ +Improve type hinting in the user directory code. 
\ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 3cb6cecd7e..437d0a46a5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -85,9 +85,11 @@ files = tests/handlers/test_room_summary.py, tests/handlers/test_send_email.py, tests/handlers/test_sync.py, + tests/handlers/test_user_directory.py, tests/rest/client/test_login.py, tests/rest/client/test_auth.py, tests/storage/test_state.py, + tests/storage/test_user_directory.py, tests/util/test_itertools.py, tests/util/test_stream_change_cache.py diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 718f3e9976..7ca04237a5 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -14,14 +14,28 @@ import logging import re -from typing import Any, Dict, Iterable, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + cast, +) + +if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import get_domain_from_id, get_localpart_from_id +from synapse.storage.types import Connection +from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, + database: DatabasePool, + db_conn: Connection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -57,10 +76,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) - async def _populate_user_directory_createtables(self, progress, batch_size): + async def _populate_user_directory_createtables( + self, progress: JsonDict, batch_size: int + ) -> int: # Get all the rooms that we want to process. - def _make_staging_area(txn): + def _make_staging_area(txn: LoggingTransaction) -> None: sql = ( "CREATE TABLE IF NOT EXISTS " + TEMP_TABLE @@ -110,16 +131,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) return 1 - async def _populate_user_directory_cleanup(self, progress, batch_size): + async def _populate_user_directory_cleanup( + self, + progress: JsonDict, + batch_size: int, + ) -> int: """ Update the user directory stream position, then clean up the old tables. 
""" position = await self.db_pool.simple_select_one_onecol( - TEMP_TABLE + "_position", None, "position" + TEMP_TABLE + "_position", {}, "position" ) await self.update_user_directory_stream_pos(position) - def _delete_staging_area(txn): + def _delete_staging_area(txn: LoggingTransaction) -> None: txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") @@ -133,18 +158,32 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) return 1 - async def _populate_user_directory_process_rooms(self, progress, batch_size): + async def _populate_user_directory_process_rooms( + self, progress: JsonDict, batch_size: int + ) -> int: """ + Rescan the state of all rooms so we can track + + - who's in a public room; + - which local users share a private room with other users (local + and remote); and + - who should be in the user_directory. + Args: progress (dict) batch_size (int): Maximum number of state events to process per cycle. + + Returns: + number of events processed. """ # If we don't have progress filed, delete everything. if not progress: await self.delete_all_from_user_dir() - def _get_next_batch(txn): + def _get_next_batch( + txn: LoggingTransaction, + ) -> Optional[Sequence[Tuple[str, int]]]: # Only fetch 250 rooms, so we don't fetch too many at once, even # if those 250 rooms have less than batch_size state events. sql = """ @@ -155,7 +194,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): TEMP_TABLE + "_rooms", ) txn.execute(sql) - rooms_to_work_on = txn.fetchall() + rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall()) if not rooms_to_work_on: return None @@ -163,7 +202,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # Get how many are left to process, so we can give status on how # far we are in processing txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") - progress["remaining"] = txn.fetchone()[0] + result = txn.fetchone() + assert result is not None + progress["remaining"] = result[0] return rooms_to_work_on @@ -261,29 +302,33 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return processed_event_count - async def _populate_user_directory_process_users(self, progress, batch_size): + async def _populate_user_directory_process_users( + self, progress: JsonDict, batch_size: int + ) -> int: """ Add all local users to the user directory. 
""" - def _get_next_batch(txn): + def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]: sql = "SELECT user_id FROM %s LIMIT %s" % ( TEMP_TABLE + "_users", str(batch_size), ) txn.execute(sql) - users_to_work_on = txn.fetchall() + user_result = cast(List[Tuple[str]], txn.fetchall()) - if not users_to_work_on: + if not user_result: return None - users_to_work_on = [x[0] for x in users_to_work_on] + users_to_work_on = [x[0] for x in user_result] # Get how many are left to process, so we can give status on how # far we are in processing sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users" txn.execute(sql) - progress["remaining"] = txn.fetchone()[0] + count_result = txn.fetchone() + assert count_result is not None + progress["remaining"] = count_result[0] return users_to_work_on @@ -324,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return len(users_to_work_on) - async def is_room_world_readable_or_publicly_joinable(self, room_id): + async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: """Check if the room is either world_readable or publically joinable""" # Create a state filter that only queries join and history state event @@ -368,7 +413,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): if not isinstance(avatar_url, str): avatar_url = None - def _update_profile_in_user_dir_txn(txn): + def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_upsert_txn( txn, table="user_directory", @@ -435,7 +480,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): for user_id, other_user_id in user_id_tuples ], value_names=(), - value_values=None, + value_values=(), desc="add_users_who_share_room", ) @@ -454,14 +499,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): key_names=["user_id", "room_id"], key_values=[(user_id, room_id) for user_id in user_ids], value_names=(), - value_values=None, + value_values=(), desc="add_users_in_public_rooms", ) async def delete_all_from_user_dir(self) -> None: """Delete the entire user directory""" - def _delete_all_from_user_dir_txn(txn): + def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None: txn.execute("DELETE FROM user_directory") txn.execute("DELETE FROM user_directory_search") txn.execute("DELETE FROM users_in_public_rooms") @@ -473,7 +518,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]: return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, @@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? 
SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__( + self, + database: DatabasePool, + db_conn: Connection, + hs: "HomeServer", + ) -> None: super().__init__(database, db_conn, hs) self._prefer_local_users_in_search = ( @@ -506,7 +556,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): self._server_name = hs.config.server.server_name async def remove_from_user_dir(self, user_id: str) -> None: - def _remove_from_user_dir_txn(txn): + def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) @@ -532,7 +582,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): "remove_from_user_dir", _remove_from_user_dir_txn ) - async def get_users_in_dir_due_to_room(self, room_id): + async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]: """Get all user_ids that are in the room directory because they're in the given room_id """ @@ -565,7 +615,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): room_id """ - def _remove_user_who_share_room_txn(txn): + def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", @@ -586,7 +636,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): "remove_user_who_share_room", _remove_user_who_share_room_txn ) - async def get_user_dir_rooms_user_is_in(self, user_id): + async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]: """ Returns the rooms that a user is in. @@ -628,7 +678,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): A set of room ID's that the users share. """ - def _get_shared_rooms_for_users_txn(txn): + def _get_shared_rooms_for_users_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, str]]: txn.execute( """ SELECT p1.room_id @@ -669,7 +721,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): desc="get_user_directory_stream_pos", ) - async def search_user_dir(self, user_id, search_term, limit): + async def search_user_dir( + self, user_id: str, search_term: str, limit: int + ) -> JsonDict: """Searches for users in directory Returns: @@ -705,7 +759,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # We allow manipulating the ranking algorithm by injecting statements # based on config options. additional_ordering_statements = [] - ordering_arguments = () + ordering_arguments: Tuple[str, ...] = () if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) @@ -811,7 +865,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return {"limited": limited, "results": results} -def _parse_query_sqlite(search_term): +def _parse_query_sqlite(search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. We use this so that we can add prefix matching, which isn't something @@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term): return " & ".join("(%s* OR %s)" % (result, result) for result in results) -def _parse_query_postgres(search_term): +def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. 
We use this so that we can add prefix matching, which isn't something diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index f3684c34a2..ba32585a14 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple from unittest.mock import Mock from urllib.parse import quote @@ -325,7 +326,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): r.add((i["user_id"], i["other_user_id"], i["room_id"])) return r - def get_users_in_public_rooms(self): + def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: r = self.get_success( self.store.db_pool.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") @@ -336,7 +337,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): retval.append((i["user_id"], i["room_id"])) return retval - def get_users_who_share_private_rooms(self): + def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]: return self.get_success( self.store.db_pool.simple_select_list( "users_who_share_private_rooms", From fa7453638408c2c55fade2d20dba362ff23226e5 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Fri, 24 Sep 2021 12:41:18 +0300 Subject: [PATCH 19/31] Fix AuthBlocking check when requester is appservice (#10881) If the MAU count had been reached, Synapse incorrectly blocked appservice users even though they've been explicitly configured not to be tracked (the default). This was because the relevant `if` was chained as an `elif` behind an earlier `if` that had already matched, so the appservice check was bypassed. Signed-off-by: Jason Robinson --- changelog.d/10881.bugfix | 1 + synapse/api/auth_blocking.py | 2 +- tests/api/test_auth.py | 62 ++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10881.bugfix diff --git a/changelog.d/10881.bugfix b/changelog.d/10881.bugfix new file mode 100644 index 0000000000..0a8905cc46 --- /dev/null +++ b/changelog.d/10881.bugfix @@ -0,0 +1 @@ +Fix application service users being subject to MAU blocking if MAU had been reached, even if configured not to be blocked. diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index a3b95f4de0..08fe160c98 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -81,7 +81,7 @@ class AuthBlocking: # We never block the server from doing actions on behalf of # users.
return - elif requester.app_service and not self._track_appservice_user_ips: + if requester.app_service and not self._track_appservice_user_ips: # If we're authenticated as an appservice then we only block # auth if `track_appservice_user_ips` is set, as that option # implicitly means that application services are part of MAU diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 8a4ef13054..cccff7af26 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -25,7 +25,9 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.appservice import ApplicationService from synapse.storage.databases.main.registration import TokenLookupResult +from synapse.types import Requester from tests import unittest from tests.test_utils import simple_async_mock @@ -290,6 +292,66 @@ class AuthTestCase(unittest.HomeserverTestCase): # Real users not allowed self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): + self.auth_blocking._max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._track_appservice_user_ips = False + + self.store.get_monthly_active_count = simple_async_mock(100) + self.store.user_last_seen_monthly_active = simple_async_mock() + self.store.is_trial_user = simple_async_mock() + + appservice = ApplicationService( + "abcd", + self.hs.config.server_name, + id="1234", + namespaces={ + "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] + }, + sender="@appservice:sender", + ) + requester = Requester( + user="@appservice:server", + access_token_id=None, + device_id="FOOBAR", + is_guest=False, + shadow_banned=False, + app_service=appservice, + authenticated_entity="@appservice:server", + ) + self.get_success(self.auth.check_auth_blocking(requester=requester)) + + def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): + self.auth_blocking._max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._track_appservice_user_ips = True + + self.store.get_monthly_active_count = simple_async_mock(100) + self.store.user_last_seen_monthly_active = simple_async_mock() + self.store.is_trial_user = simple_async_mock() + + appservice = ApplicationService( + "abcd", + self.hs.config.server_name, + id="1234", + namespaces={ + "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] + }, + sender="@appservice:sender", + ) + requester = Requester( + user="@appservice:server", + access_token_id=None, + device_id="FOOBAR", + is_guest=False, + shadow_banned=False, + app_service=appservice, + authenticated_entity="@appservice:server", + ) + self.get_failure( + self.auth.check_auth_blocking(requester=requester), ResourceLimitError + ) + def test_reserved_threepid(self): self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 From 50022cff966a3991fbd8a1e5c98f490d9b335442 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 24 Sep 2021 11:01:25 +0100 Subject: [PATCH 20/31] Add reactor to `SynapseRequest` and fix up types. 
(#10868) --- changelog.d/10868.feature | 1 + synapse/http/server.py | 4 +- synapse/http/site.py | 37 ++++++++++------ synapse/rest/key/v2/remote_key_resource.py | 9 ++-- synapse/rest/media/v1/_base.py | 7 +-- synapse/rest/media/v1/config_resource.py | 4 +- synapse/rest/media/v1/download_resource.py | 5 +-- synapse/rest/media/v1/media_repository.py | 10 +++-- synapse/rest/media/v1/preview_url_resource.py | 3 +- synapse/rest/media/v1/thumbnail_resource.py | 15 +++---- synapse/rest/media/v1/upload_resource.py | 4 +- tests/http/test_additional_resource.py | 8 +++- tests/logging/test_terse_json.py | 3 +- tests/replication/test_multi_media_repo.py | 2 +- tests/rest/admin/test_admin.py | 6 +-- tests/rest/admin/test_media.py | 6 +-- tests/rest/admin/test_user.py | 2 +- tests/rest/client/test_account.py | 4 +- tests/rest/client/test_consent.py | 12 ++++-- tests/rest/client/utils.py | 2 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- tests/rest/media/v1/test_media_storage.py | 8 ++-- tests/server.py | 6 ++- tests/test_server.py | 43 ++++++++++++++----- 24 files changed, 123 insertions(+), 82 deletions(-) create mode 100644 changelog.d/10868.feature diff --git a/changelog.d/10868.feature b/changelog.d/10868.feature new file mode 100644 index 0000000000..07e7b2c6a7 --- /dev/null +++ b/changelog.d/10868.feature @@ -0,0 +1 @@ +Speed up responding with large JSON objects to requests. diff --git a/synapse/http/server.py b/synapse/http/server.py index b79fa722e9..e28b56abb9 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource): def _send_response( self, - request: Request, + request: SynapseRequest, code: int, response_object: Any, ): @@ -629,7 +629,7 @@ def _encode_json_bytes(json_object: Any) -> Iterator[bytes]: def respond_with_json( - request: Request, + request: SynapseRequest, code: int, json_object: Any, send_cors: bool = False, diff --git a/synapse/http/site.py b/synapse/http/site.py index dd4c749e16..755ad56637 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,13 +14,14 @@ import contextlib import logging import time -from typing import Optional, Tuple, Union +from typing import Generator, Optional, Tuple, Union import attr from zope.interface import implementer from twisted.internet.interfaces import IAddress, IReactorTime from twisted.python.failure import Failure +from twisted.web.http import HTTPChannel from twisted.web.resource import IResource, Resource from twisted.web.server import Request, Site @@ -61,10 +62,18 @@ class SynapseRequest(Request): logcontext: the log context for this request """ - def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw): - Request.__init__(self, channel, *args, **kw) + def __init__( + self, + channel: HTTPChannel, + site: "SynapseSite", + *args, + max_request_body_size: int = 1024, + **kw, + ): + super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size - self.site: SynapseSite = channel.site + self.synapse_site = site + self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 @@ -97,7 +106,7 @@ class SynapseRequest(Request): self.get_method(), self.get_redacted_uri(), self.clientproto.decode("ascii", errors="replace"), - self.site.site_tag, + self.synapse_site.site_tag, ) def handleContentChunk(self, data: bytes) -> None: @@ -216,7 +225,7 @@ class SynapseRequest(Request): request=ContextRequest( request_id=request_id, ip_address=self.getClientIP(), - 
site_tag=self.site.site_tag, + site_tag=self.synapse_site.site_tag, # The requester is going to be unknown at this point. requester=None, authenticated_entity=None, @@ -228,7 +237,7 @@ class SynapseRequest(Request): ) # override the Server header which is set by twisted - self.setHeader("Server", self.site.server_version_string) + self.setHeader("Server", self.synapse_site.server_version_string) with PreserveLoggingContext(self.logcontext): # we start the request metrics timer here with an initial stab @@ -247,7 +256,7 @@ class SynapseRequest(Request): requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager - def processing(self): + def processing(self) -> Generator[None, None, None]: """Record the fact that we are processing this request. Returns a context manager; the correct way to use this is: @@ -346,10 +355,10 @@ class SynapseRequest(Request): self.start_time, name=servlet_name, method=self.get_method() ) - self.site.access_logger.debug( + self.synapse_site.access_logger.debug( "%s - %s - Received request: %s %s", self.getClientIP(), - self.site.site_tag, + self.synapse_site.site_tag, self.get_method(), self.get_redacted_uri(), ) @@ -388,13 +397,13 @@ class SynapseRequest(Request): if authenticated_entity: requester = f"{authenticated_entity}|{requester}" - self.site.access_logger.log( + self.synapse_site.access_logger.log( log_level, "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" ' %sB %s "%s %s %s" "%s" [%d dbevts]', self.getClientIP(), - self.site.site_tag, + self.synapse_site.site_tag, requester, processing_time, response_send_time, @@ -522,7 +531,7 @@ class SynapseSite(Site): site_tag: str, config: ListenerConfig, resource: IResource, - server_version_string, + server_version_string: str, max_request_body_size: int, reactor: IReactorTime, ): @@ -542,6 +551,7 @@ class SynapseSite(Site): Site.__init__(self, resource, reactor=reactor) self.site_tag = site_tag + self.reactor = reactor assert config.http_options is not None proxied = config.http_options.x_forwarded @@ -550,6 +560,7 @@ class SynapseSite(Site): def request_factory(channel, queued: bool) -> Request: return request_class( channel, + self, max_request_body_size=max_request_body_size, queued=queued, ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index c111a9d20f..3923ba8439 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Dict from signedjson.sign import sign_json -from twisted.web.server import Request - from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -102,7 +101,7 @@ class RemoteKey(DirectServeJsonResource): ) self.config = hs.config - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: assert request.postpath is not None if len(request.postpath) == 1: (server,) = request.postpath @@ -119,7 +118,7 @@ class RemoteKey(DirectServeJsonResource): await self.query_keys(request, query, 
query_remote_on_cache_miss=True) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: content = parse_json_object_from_request(request) query = content["server_keys"] @@ -128,7 +127,7 @@ class RemoteKey(DirectServeJsonResource): async def query_keys( self, - request: Request, + request: SynapseRequest, query: JsonDict, query_remote_on_cache_miss: bool = False, ) -> None: diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 7c881f2bdb..014fa893d6 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -27,6 +27,7 @@ from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json +from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.util.stringutils import is_ascii @@ -74,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: ) -def respond_404(request: Request) -> None: +def respond_404(request: SynapseRequest) -> None: respond_with_json( request, 404, @@ -84,7 +85,7 @@ def respond_404(request: Request) -> None: async def respond_with_file( - request: Request, + request: SynapseRequest, media_type: str, file_path: str, file_size: Optional[int] = None, @@ -221,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool: async def respond_with_responder( - request: Request, + request: SynapseRequest, responder: "Optional[Responder]", media_type: str, file_size: Optional[int], diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index a1d36e5cf1..712d4e8368 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -16,8 +16,6 @@ from typing import TYPE_CHECKING -from twisted.web.server import Request - from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.site import SynapseRequest @@ -39,5 +37,5 @@ class MediaConfigResource(DirectServeJsonResource): await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - async def _async_render_OPTIONS(self, request: Request) -> None: + async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index d6d938953e..6180fa575e 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -15,10 +15,9 @@ import logging from typing import TYPE_CHECKING -from twisted.web.server import Request - from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_boolean +from synapse.http.site import SynapseRequest from ._base import parse_media_id, respond_404 @@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource): self.media_repo = media_repo self.server_name = hs.hostname - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: set_cors_headers(request) request.setHeader( b"Content-Security-Policy", diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a30007a1e2..c1bd81100d 100644 --- a/synapse/rest/media/v1/media_repository.py +++ 
b/synapse/rest/media/v1/media_repository.py @@ -23,7 +23,6 @@ import twisted.internet.error import twisted.web.http from twisted.internet.defer import Deferred from twisted.web.resource import Resource -from twisted.web.server import Request from synapse.api.errors import ( FederationDeniedError, @@ -34,6 +33,7 @@ from synapse.api.errors import ( ) from synapse.config._base import ConfigError from synapse.config.repository import ThumbnailRequirement +from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID @@ -189,7 +189,7 @@ class MediaRepository: return "mxc://%s/%s" % (self.server_name, media_id) async def get_local_media( - self, request: Request, media_id: str, name: Optional[str] + self, request: SynapseRequest, media_id: str, name: Optional[str] ) -> None: """Responds to requests for local media, if exists, or returns 404. @@ -223,7 +223,11 @@ class MediaRepository: ) async def get_remote_media( - self, request: Request, server_name: str, media_id: str, name: Optional[str] + self, + request: SynapseRequest, + server_name: str, + media_id: str, + name: Optional[str], ) -> None: """Respond to requests for remote media. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 9ffa983fbb..128706d297 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -29,7 +29,6 @@ import attr from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError -from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient @@ -168,7 +167,7 @@ class PreviewUrlResource(DirectServeJsonResource): self._start_expire_url_cache_data, 10 * 1000 ) - async def _async_render_OPTIONS(self, request: Request) -> None: + async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 22f43d8531..cb2f88676e 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -17,11 +17,10 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from twisted.web.server import Request - from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string +from synapse.http.site import SynapseRequest from synapse.rest.media.v1.media_storage import MediaStorage from ._base import ( @@ -57,7 +56,7 @@ class ThumbnailResource(DirectServeJsonResource): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) @@ -88,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _respond_local_thumbnail( self, - request: Request, + request: SynapseRequest, media_id: str, width: int, height: int, @@ -121,7 +120,7 @@ class ThumbnailResource(DirectServeJsonResource): async def 
_select_or_generate_local_thumbnail( self, - request: Request, + request: SynapseRequest, media_id: str, desired_width: int, desired_height: int, @@ -186,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_remote_thumbnail( self, - request: Request, + request: SynapseRequest, server_name: str, media_id: str, desired_width: int, @@ -249,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _respond_remote_thumbnail( self, - request: Request, + request: SynapseRequest, server_name: str, media_id: str, width: int, @@ -280,7 +279,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_and_respond_with_thumbnail( self, - request: Request, + request: SynapseRequest, desired_width: int, desired_height: int, desired_method: str, diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 146adca8f1..39b29318bb 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -16,8 +16,6 @@ import logging from typing import IO, TYPE_CHECKING, Dict, List, Optional -from twisted.web.server import Request - from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_bytes_from_args @@ -46,7 +44,7 @@ class UploadResource(DirectServeJsonResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - async def _async_render_OPTIONS(self, request: Request) -> None: + async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: respond_with_json(request, 200, {}, send_cors=True) async def _async_render_POST(self, request: SynapseRequest) -> None: diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index 768c2ba4ea..391196425c 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -45,7 +45,9 @@ class AdditionalResourceTests(HomeserverTestCase): handler = _AsyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) - channel = make_request(self.reactor, FakeSite(resource), "GET", "/") + channel = make_request( + self.reactor, FakeSite(resource, self.reactor), "GET", "/" + ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) @@ -54,7 +56,9 @@ class AdditionalResourceTests(HomeserverTestCase): handler = _SyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) - channel = make_request(self.reactor, FakeSite(resource), "GET", "/") + channel = make_request( + self.reactor, FakeSite(resource, self.reactor), "GET", "/" + ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 1160716929..f73fcd684e 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -152,7 +152,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"]) site.site_tag = "test-site" site.server_version_string = "Server v1" - request = SynapseRequest(FakeChannel(site, None)) + site.reactor = Mock() + request = SynapseRequest(FakeChannel(site, None), site) # Call requestReceived to finish instantiating the object. 
request.content = BytesIO() # Partially skip some of the internal processing of SynapseRequest. diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 01b1b0d4a0..13aa5eb51a 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): resource = hs.get_media_repository_resource().children[b"download"] channel = make_request( self.reactor, - FakeSite(resource), + FakeSite(resource, self.reactor), "GET", f"/{target}/{media_id}", shorthand=False, diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index febd40b656..192073c520 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): """Ensure a piece of media is quarantined when trying to access it.""" channel = make_request( self.reactor, - FakeSite(self.download_resource), + FakeSite(self.download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, @@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt to access the media channel = make_request( self.reactor, - FakeSite(self.download_resource), + FakeSite(self.download_resource, self.reactor), "GET", server_name_and_media_id, shorthand=False, @@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt to access each piece of media channel = make_request( self.reactor, - FakeSite(self.download_resource), + FakeSite(self.download_resource, self.reactor), "GET", server_and_media_id_2, shorthand=False, diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 2f02934e72..f813866073 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Attempt to access media channel = make_request( self.reactor, - FakeSite(download_resource), + FakeSite(download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, @@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Attempt to access media channel = make_request( self.reactor, - FakeSite(download_resource), + FakeSite(download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, @@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = make_request( self.reactor, - FakeSite(download_resource), + FakeSite(download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index cc3f16c62a..e79e0e1850 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Try to access a media and to create `last_access_ts` channel = make_request( self.reactor, - FakeSite(download_resource), + FakeSite(download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index b946fca8b3..9e9e953cf4 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Load the password reset confirmation page channel = make_request( self.reactor, - FakeSite(self.submit_token_resource), + 
FakeSite(self.submit_token_resource, self.reactor), "GET", path, shorthand=False, @@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Confirm the password reset channel = make_request( self.reactor, - FakeSite(self.submit_token_resource), + FakeSite(self.submit_token_resource, self.reactor), "POST", path, content=b"", diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 65c58ce70a..84d092ca82 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) channel = make_request( - self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False + self.reactor, + FakeSite(resource, self.reactor), + "GET", + "/consent?v=1", + shorthand=False, ) self.assertEqual(channel.code, 200) @@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): ) channel = make_request( self.reactor, - FakeSite(resource), + FakeSite(resource, self.reactor), "GET", consent_uri, access_token=access_token, @@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # POST to the consent page, saying we've agreed channel = make_request( self.reactor, - FakeSite(resource), + FakeSite(resource, self.reactor), "POST", consent_uri + "&v=" + version, access_token=access_token, @@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # changed channel = make_request( self.reactor, - FakeSite(resource), + FakeSite(resource, self.reactor), "GET", consent_uri, access_token=access_token, diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index c56e45fc10..3075d3f288 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -383,7 +383,7 @@ class RestHelper: path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( self.hs.get_reactor(), - FakeSite(resource), + FakeSite(resource, self.hs.get_reactor()), "POST", path, content=image_data, diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index a75c0ea3f0..4672a68596 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): Checks that the response is a 200 and returns the decoded json body. 
""" channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel) + req = SynapseRequest(channel, self.site) req.content = BytesIO(b"") req.requestReceived( b"GET", @@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): ) channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel) + req = SynapseRequest(channel, self.site) req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 9ea1c2bf25..44a643d506 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel = make_request( self.reactor, - FakeSite(self.download_resource), + FakeSite(self.download_resource, self.reactor), "GET", self.media_id, shorthand=False, @@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): params = "?width=32&height=32&method=scale" channel = make_request( self.reactor, - FakeSite(self.thumbnail_resource), + FakeSite(self.thumbnail_resource, self.reactor), "GET", self.media_id + params, shorthand=False, @@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel = make_request( self.reactor, - FakeSite(self.thumbnail_resource), + FakeSite(self.thumbnail_resource, self.reactor), "GET", self.media_id + params, shorthand=False, @@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, - FakeSite(self.thumbnail_resource), + FakeSite(self.thumbnail_resource, self.reactor), "GET", self.media_id + params, shorthand=False, diff --git a/tests/server.py b/tests/server.py index b861c7b866..88dfa8058e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -19,6 +19,7 @@ from twisted.internet.interfaces import ( IPullProducer, IPushProducer, IReactorPluggableNameResolver, + IReactorTime, IResolverSimple, ITransport, ) @@ -181,13 +182,14 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") - def __init__(self, resource: IResource): + def __init__(self, resource: IResource, reactor: IReactorTime): """ Args: resource: the resource to be used for rendering all requests """ self._resource = resource + self.reactor = reactor def getResourceFor(self, request): return self._resource @@ -268,7 +270,7 @@ def make_request( channel = FakeChannel(site, reactor, ip=client_ip) - req = request(channel) + req = request(channel, site) req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. 
req.content.seek(SEEK_END) diff --git a/tests/test_server.py b/tests/test_server.py index 407e172e41..f2ffbc895b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -65,7 +65,10 @@ class JsonResourceTests(unittest.TestCase): ) make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" + self.reactor, + FakeSite(res, self.reactor), + b"GET", + b"/_matrix/foo/%E2%98%83?a=%E2%98%83", ) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) @@ -84,7 +87,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" + ) self.assertEqual(channel.result["code"], b"500") @@ -100,7 +105,7 @@ class JsonResourceTests(unittest.TestCase): def _callback(request, **kwargs): d = Deferred() d.addCallback(_throw) - self.reactor.callLater(1, d.callback, True) + self.reactor.callLater(0.5, d.callback, True) return make_deferred_yieldable(d) res = JsonResource(self.homeserver) @@ -108,7 +113,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" + ) self.assertEqual(channel.result["code"], b"500") @@ -126,7 +133,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" + ) self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -148,7 +157,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar" + ) self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") @@ -173,7 +184,9 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. 
- channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo" + ) self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) @@ -280,7 +293,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" + ) self.assertEqual(channel.result["code"], b"200") body = channel.result["body"] @@ -298,7 +313,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" + ) self.assertEqual(channel.result["code"], b"301") headers = channel.result["headers"] @@ -319,7 +336,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" + ) self.assertEqual(channel.result["code"], b"304") headers = channel.result["headers"] @@ -338,7 +357,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path" + ) self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) From 261c9763c472f0ea1ceac9729dfc3a5da2799300 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 24 Sep 2021 11:56:13 +0100 Subject: [PATCH 21/31] Simplify `_auth_and_persist_fetched_events` (#10901) Combine the two loops over the list of events, and hence get rid of `_NewEventInfo`. Also pass the event back alongside the context, so that it's easier to process the result. --- changelog.d/10901.misc | 1 + synapse/handlers/federation_event.py | 91 +++++++--------------------- 2 files changed, 23 insertions(+), 69 deletions(-) create mode 100644 changelog.d/10901.misc diff --git a/changelog.d/10901.misc b/changelog.d/10901.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10901.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. 
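The shape of this refactor, for readers skimming the diff that follows: the old code built a list of `_NewEventInfo` records in one loop and consumed it in a second, re-pairing each record with the context computed for it; the new code has the per-event `prep` step return the event together with its context, so a single concurrent gather drives everything. A minimal sketch of the pattern, with plain `asyncio.gather` standing in for Twisted's `yieldable_gather_results` and toy `Event`/`EventContext` types (the names below are illustrative, not Synapse's API):

    import asyncio
    from dataclasses import dataclass
    from typing import List, Tuple

    @dataclass
    class Event:
        event_id: str

    @dataclass
    class EventContext:
        rejected: bool = False

    async def check_event_auth(event: Event) -> EventContext:
        # Stand-in for _check_event_auth: authenticate one outlier event.
        return EventContext(rejected=False)

    async def prep(event: Event) -> Tuple[Event, EventContext]:
        # One loop body does all the per-event work and hands back the
        # event alongside its context, so no intermediate record type
        # is needed to re-pair them afterwards.
        context = await check_event_auth(event)
        return event, context

    async def auth_and_persist(events: List[Event]) -> None:
        # asyncio.gather plays the role of yieldable_gather_results here:
        # run prep concurrently for every event, keeping results in order.
        to_persist = await asyncio.gather(*(prep(e) for e in events))
        print("persisting:", [(e.event_id, c.rejected) for e, c in to_persist])

    asyncio.run(auth_and_persist([Event("$a"), Event("$b")]))

Keeping the pairing inside `prep` is what makes the helper dataclass unnecessary: nothing downstream has to zip two parallel lists back together.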
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7d468bd2df..4eefcc36d8 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -27,11 +27,8 @@ from typing import ( Tuple, ) -import attr from prometheus_client import Counter -from twisted.internet import defer - from synapse import event_auth from synapse.api.constants import ( EventContentFields, @@ -54,11 +51,7 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_client import InvalidResponseError -from synapse.logging.context import ( - make_deferred_yieldable, - nested_logging_context, - run_in_background, -) +from synapse.logging.context import nested_logging_context, run_in_background from synapse.logging.utils import log_function from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet @@ -75,7 +68,11 @@ from synapse.types import ( UserID, get_domain_from_id, ) -from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.async_helpers import ( + Linearizer, + concurrently_execute, + yieldable_gather_results, +) from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr @@ -92,30 +89,6 @@ soft_failed_event_counter = Counter( ) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _NewEventInfo: - """Holds information about a received event, ready for passing to _auth_and_persist_events - - Attributes: - event: the received event - - claimed_auth_event_map: a map of (type, state_key) => event for the event's - claimed auth_events. - - This can include events which have not yet been persisted, in the case that - we are backfilling a batch of events. - - Note: May be incomplete: if we were unable to find all of the claimed auth - events. Also, treat the contents with caution: the events might also have - been rejected, might not yet have been authorized themselves, or they might - be in the wrong room. - - """ - - event: EventBase - claimed_auth_event_map: StateMap[EventBase] - - class FederationEventHandler: """Handles events that originated from federation. 
@@ -1203,47 +1176,27 @@ class FederationEventHandler: allow_rejected=True, ) - event_infos = [] - for event in fetched_events: - auth = {} - for auth_event_id in event.auth_event_ids(): - ae = persisted_events.get(auth_event_id) - if ae: - auth[(ae.type, ae.state_key)] = ae - else: - logger.info("Missing auth event %s", auth_event_id) - - event_infos.append(_NewEventInfo(event, auth)) - - if not event_infos: - return - - async def prep(ev_info: _NewEventInfo) -> EventContext: - event = ev_info.event + async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: with nested_logging_context(suffix=event.event_id): - res = EventContext.for_outlier() - res = await self._check_event_auth( + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = persisted_events.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + else: + logger.info("Missing auth event %s", auth_event_id) + + context = EventContext.for_outlier() + context = await self._check_event_auth( origin, event, - res, - claimed_auth_event_map=ev_info.claimed_auth_event_map, + context, + claimed_auth_event_map=auth, ) - return res + return event, context - contexts = await make_deferred_yieldable( - defer.gatherResults( - [run_in_background(prep, ev_info) for ev_info in event_infos], - consumeErrors=True, - ) - ) - - await self.persist_events_and_notify( - room_id, - [ - (ev_info.event, context) - for ev_info, context in zip(event_infos, contexts) - ], - ) + events_to_persist = await yieldable_gather_results(prep, fetched_events) + await self.persist_events_and_notify(room_id, events_to_persist) async def _check_event_auth( self, From 85551b7a8555eb4e4456d5cf2db0fecd4a44621c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 24 Sep 2021 11:56:33 +0100 Subject: [PATCH 22/31] Factor out common code for persisting fetched auth events (#10896) * Factor more stuff out of `_get_events_and_persist` It turns out that the event-sorting algorithm in `_get_events_and_persist` is also useful in other circumstances. Here we move the current `_auth_and_persist_fetched_events` to `_auth_and_persist_fetched_events_inner`, and then factor the sorting part out to `_auth_and_persist_fetched_events`. * `_get_remote_auth_chain_for_event`: remove redundant `outlier` assignment `get_event_auth` returns events with the outlier flag already set, so this is redundant (though we need to update a test where `get_event_auth` is mocked). * `_get_remote_auth_chain_for_event`: move existing-event tests earlier Move a couple of tests outside the loop. This is a bit inefficient for now, but a future commit will make it better. It should be functionally identical. * `_get_remote_auth_chain_for_event`: use `_auth_and_persist_fetched_events` We can use the same codepath for persisting the events fetched as part of an auth chain as for those fetched individually by `_get_events_and_persist` for building the state at a backwards extremity. * `_get_remote_auth_chain_for_event`: use a dict for efficiency `_auth_and_persist_fetched_events` sorts the events itself, so we no longer need to care about maintaining the ordering from `get_event_auth` (and no longer need to sort by depth in `get_event_auth`). That means that we can use a map, making it easier to filter out events we already have, etc. 
* changelog * `_auth_and_persist_fetched_events`: improve docstring --- changelog.d/10896.misc | 1 + synapse/federation/federation_client.py | 2 - synapse/handlers/federation_event.py | 103 +++++++++++------------- tests/handlers/test_federation.py | 7 +- 4 files changed, 55 insertions(+), 58 deletions(-) create mode 100644 changelog.d/10896.misc diff --git a/changelog.d/10896.misc b/changelog.d/10896.misc new file mode 100644 index 0000000000..41de995842 --- /dev/null +++ b/changelog.d/10896.misc @@ -0,0 +1 @@ + Clean up some of the federation event authentication code for clarity. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 1416abd0fb..584836c04a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -501,8 +501,6 @@ class FederationClient(FederationBase): destination, auth_chain, outlier=True, room_version=room_version ) - signed_auth.sort(key=lambda e: e.depth) - return signed_auth def _is_unknown_endpoint( diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 4eefcc36d8..8fd9e51044 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1080,7 +1080,7 @@ class FederationEventHandler: room_version = await self._store.get_room_version(room_id) - event_map: Dict[str, EventBase] = {} + events: List[EventBase] = [] async def get_event(event_id: str) -> None: with nested_logging_context(event_id): @@ -1098,8 +1098,7 @@ class FederationEventHandler: event_id, ) return - - event_map[event.event_id] = event + events.append(event) except Exception as e: logger.warning( @@ -1110,11 +1109,29 @@ class FederationEventHandler: ) await concurrently_execute(get_event, event_ids, 5) - logger.info("Fetched %i events of %i requested", len(event_map), len(event_ids)) + logger.info("Fetched %i events of %i requested", len(events), len(event_ids)) + await self._auth_and_persist_fetched_events(destination, room_id, events) + + async def _auth_and_persist_fetched_events( + self, origin: str, room_id: str, events: Iterable[EventBase] + ) -> None: + """Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event + + The events to be persisted must be outliers. + + We first sort the events to make sure that we process each event's auth_events + before the event itself, and then auth and persist them. + + Notifies about the events where appropriate. + + Params: + origin: where the events came from + room_id: the room that the events are meant to be in (though this has + not yet been checked) + events: the events that have been fetched + """ + event_map = {event.event_id: event for event in events} - # we now need to auth the events in an order which ensures that each event's - # auth_events are authed before the event itself. - # # XXX: it might be possible to kick this process off in parallel with fetching # the events. while event_map: @@ -1141,22 +1158,18 @@ class FederationEventHandler: "Persisting %i of %i remaining events", len(roots), len(event_map) ) - await self._auth_and_persist_fetched_events(destination, room_id, roots) + await self._auth_and_persist_fetched_events_inner(origin, room_id, roots) for ev in roots: del event_map[ev.event_id] - async def _auth_and_persist_fetched_events( + async def _auth_and_persist_fetched_events_inner( self, origin: str, room_id: str, fetched_events: Collection[EventBase] ) -> None: - """Persist the events fetched by _get_events_and_persist. 
+ """Helper for _auth_and_persist_fetched_events - The events should not depend on one another, e.g. this should be used to persist - a bunch of outliers, but not a chunk of individual events that depend - on each other for state calculations. - - We also assume that all of the auth events for all of the events have already - been persisted. + Persists a batch of events where we have (theoretically) already persisted all + of their auth events. Notifies about the events where appropriate. @@ -1164,7 +1177,7 @@ class FederationEventHandler: origin: where the events came from room_id: the room that the events are meant to be in (though this has not yet been checked) - event_id: map from event_id -> event for the fetched events + fetched_events: the events to persist """ # get all the auth events for all the events in this batch. By now, they should # have been persisted. @@ -1558,53 +1571,33 @@ class FederationEventHandler: event_id: the event for which we are lacking auth events """ try: - remote_auth_chain = await self._federation_client.get_event_auth( - destination, room_id, event_id - ) + remote_event_map = { + e.event_id: e + for e in await self._federation_client.get_event_auth( + destination, room_id, event_id + ) + } except RequestSendFailed as e1: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e1) return + logger.info("/event_auth returned %i events", len(remote_event_map)) + + # `event` may be returned, but we should not yet process it. + remote_event_map.pop(event_id, None) + + # nor should we reprocess any events we have already seen. seen_remotes = await self._store.have_seen_events( - room_id, [e.event_id for e in remote_auth_chain] + room_id, remote_event_map.keys() ) + for s in seen_remotes: + remote_event_map.pop(s, None) - for auth_event in remote_auth_chain: - if auth_event.event_id in seen_remotes: - continue - - if auth_event.event_id == event_id: - continue - - try: - auth_ids = auth_event.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in remote_auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - auth_event.internal_metadata.outlier = True - - logger.debug( - "_check_event_auth %s missing_auth: %s", - event_id, - auth_event.event_id, - ) - missing_auth_event_context = EventContext.for_outlier() - missing_auth_event_context = await self._check_event_auth( - destination, - auth_event, - missing_auth_event_context, - claimed_auth_event_map=auth, - ) - await self.persist_events_and_notify( - room_id, - [(auth_event, missing_auth_event_context)], - ) - except AuthError: - pass + await self._auth_and_persist_fetched_events( + destination, room_id, remote_event_map.values() + ) async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 6c67a16de9..936ebf3dde 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -308,7 +308,12 @@ class FederationTestCase(unittest.HomeserverTestCase): async def get_event_auth( destination: str, room_id: str, event_id: str ) -> List[EventBase]: - return auth_events + return [ + event_from_pdu_json( + ae.get_pdu_json(), room_version=room_version, outlier=True + ) + for ae in auth_events + ] self.handler.federation_client.get_event_auth = get_event_auth From bb7fdd821b07016a43bdbb245eda5b35356863c0 Mon Sep 17 00:00:00 2001 
From: Patrick Cloke Date: Fri, 24 Sep 2021 07:25:21 -0400 Subject: [PATCH 23/31] Use direct references for configuration variables (part 5). (#10897) --- changelog.d/10897.misc | 1 + synapse/app/_base.py | 4 ++-- synapse/app/admin_cmd.py | 6 ++--- synapse/app/generic_worker.py | 6 ++--- synapse/app/homeserver.py | 2 +- synapse/config/logger.py | 4 +++- synapse/crypto/context_factory.py | 4 ++-- synapse/events/spamcheck.py | 2 +- synapse/events/third_party_rules.py | 4 ++-- synapse/handlers/auth.py | 10 ++++---- synapse/handlers/directory.py | 6 ++--- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 8 +++---- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 8 ++++--- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/saml.py | 15 ++++++------ synapse/handlers/sso.py | 10 ++++---- synapse/handlers/stats.py | 2 +- synapse/handlers/user_directory.py | 2 +- synapse/logging/opentracing.py | 6 ++--- synapse/replication/http/_base.py | 4 ++-- synapse/replication/tcp/handler.py | 4 ++-- synapse/rest/admin/__init__.py | 2 +- synapse/rest/client/login.py | 2 +- synapse/rest/client/user_directory.py | 2 +- synapse/rest/client/versions.py | 6 ++--- synapse/rest/client/voip.py | 12 +++++----- synapse/rest/media/v1/config_resource.py | 2 +- synapse/rest/media/v1/media_repository.py | 20 +++++++++------- synapse/rest/media/v1/preview_url_resource.py | 10 ++++---- synapse/rest/media/v1/storage_provider.py | 2 +- synapse/rest/media/v1/thumbnail_resource.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/rest/synapse/client/__init__.py | 2 +- .../synapse/client/saml2/metadata_resource.py | 2 +- .../server_notices/server_notices_manager.py | 23 ++++++++++--------- .../storage/databases/main/registration.py | 2 +- synapse/storage/databases/main/stats.py | 2 +- .../storage/databases/main/user_directory.py | 4 ++-- tests/handlers/test_directory.py | 4 +++- tests/handlers/test_stats.py | 8 +++---- tests/handlers/test_user_directory.py | 6 ++--- tests/rest/admin/test_media.py | 4 ++-- tests/rest/admin/test_user.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- .../test_resource_limits_server_notices.py | 2 +- 48 files changed, 128 insertions(+), 113 deletions(-) create mode 100644 changelog.d/10897.misc diff --git a/changelog.d/10897.misc b/changelog.d/10897.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10897.misc @@ -0,0 +1 @@ +Use direct references to config flags. 
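The pattern throughout this patch is mechanical: every option is read through the config section that owns it (`hs.config.worker.worker_pid_file`, `hs.config.media.max_upload_size`, and so on) rather than through an attribute flattened onto the root config object. A minimal sketch of the two styles; the dataclasses below are stand-ins for Synapse's config classes, not its actual definitions:

    from dataclasses import dataclass, field

    @dataclass
    class WorkerConfig:
        worker_pid_file: str = "/var/run/synapse-worker.pid"
        worker_daemonize: bool = False

    @dataclass
    class MediaConfig:
        max_upload_size: int = 50 * 1024 * 1024  # 50 MiB

    @dataclass
    class RootConfig:
        # Each section owns its options; nothing is copied onto the root.
        worker: WorkerConfig = field(default_factory=WorkerConfig)
        media: MediaConfig = field(default_factory=MediaConfig)

    config = RootConfig()

    # After this patch, call sites name the owning section explicitly:
    pid_file = config.worker.worker_pid_file
    upload_limit = config.media.max_upload_size

    # The legacy style relied on every option also being reachable as a
    # flattened attribute of the root, e.g.:
    #     pid_file = config.worker_pid_file

Presumably the gain is that each option now has exactly one home, so readers and tooling can find its definition from the call site alone.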
diff --git a/synapse/app/_base.py b/synapse/app/_base.py index f657f11f76..548f6dcde9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -88,8 +88,8 @@ def start_worker_reactor(appname, config, run_command=reactor.run): appname, soft_file_limit=config.soft_file_limit, gc_thresholds=config.gc_thresholds, - pid_file=config.worker_pid_file, - daemonize=config.worker_daemonize, + pid_file=config.worker.worker_pid_file, + daemonize=config.worker.worker_daemonize, print_pidfile=config.print_pidfile, logger=logger, run_command=run_command, diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 259d5ec7cc..f2c5b75247 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -186,9 +186,9 @@ def start(config_options): config.worker.worker_app = "synapse.app.admin_cmd" if ( - not config.worker_daemonize - and not config.worker_log_file - and not config.worker_log_config + not config.worker.worker_daemonize + and not config.worker.worker_log_file + and not config.worker.worker_log_config ): # Since we're meant to be run as a "command" let's not redirect stdio # unless we've actually set log config. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e0776689ce..3036e1b4a0 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -140,7 +140,7 @@ class KeyUploadServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.http_client = hs.get_simple_http_client() - self.main_uri = hs.config.worker_main_http_uri + self.main_uri = hs.config.worker.worker_main_http_uri async def on_POST(self, request: Request, device_id: Optional[str]): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -321,7 +321,7 @@ class GenericWorkerServer(HomeServer): elif name == "federation": resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) elif name == "media": - if self.config.can_load_media_repo: + if self.config.media.can_load_media_repo: media_repo = self.get_media_repository_resource() # We need to serve the admin servlets for media on the @@ -384,7 +384,7 @@ class GenericWorkerServer(HomeServer): logger.info("Synapse worker now listening on port %d", port) def start_listening(self): - for listener in self.config.worker_listeners: + for listener in self.config.worker.worker_listeners: if listener.type == "http": self._listen_http(listener) elif listener.type == "manhole": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index f1769f146b..205831dcda 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["media", "federation", "client"]: - if self.config.enable_media_repo: + if self.config.media.enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} diff --git a/synapse/config/logger.py b/synapse/config/logger.py index bf8ca7d5fe..0a08231e5a 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -322,7 +322,9 @@ def setup_logging( """ log_config_path = ( - config.worker_log_config if use_worker_options else config.logging.log_config + config.worker.worker_log_config + if use_worker_options + else config.logging.log_config ) # Perform one-time logging configuration. 
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index d310976fe3..2a6110eb10 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -74,8 +74,8 @@ class ServerContextFactory(ContextFactory): context.set_options( SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1 ) - context.use_certificate_chain_file(config.tls_certificate_file) - context.use_privatekey(config.tls_private_key) + context.use_certificate_chain_file(config.tls.tls_certificate_file) + context.use_privatekey(config.tls.tls_private_key) # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ context.set_cipher_list( diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 57f1d53fa8..19ee246f96 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -78,7 +78,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): """ spam_checkers: List[Any] = [] api = hs.get_module_api() - for module, config in hs.config.spam_checkers: + for module, config in hs.config.spamchecker.spam_checkers: # Older spam checkers don't accept the `api` argument, so we # try and detect support. spam_args = inspect.getfullargspec(module) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 7a6eb3e516..d94b1bb4d2 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -42,10 +42,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): """Wrapper that loads a third party event rules module configured using the old configuration, and registers the hooks they implement. """ - if hs.config.third_party_event_rules is None: + if hs.config.thirdpartyrules.third_party_event_rules is None: return - module, config = hs.config.third_party_event_rules + module, config = hs.config.thirdpartyrules.third_party_event_rules api = hs.get_module_api() third_party_rules = module(config=config, module_api=api) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0f80dfdc43..a8c717efd5 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -277,23 +277,25 @@ class AuthHandler(BaseHandler): # after the SSO completes and before redirecting them back to their client. # It notifies the user they are about to give access to their matrix account # to the client. - self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template + self._sso_redirect_confirm_template = ( + hs.config.sso.sso_redirect_confirm_template + ) # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. - self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template + self._sso_auth_confirm_template = hs.config.sso.sso_auth_confirm_template # The following template is shown during the SSO authentication process if # the account is deactivated. self._sso_account_deactivated_template = ( - hs.config.sso_account_deactivated_template + hs.config.sso.sso_account_deactivated_template ) self._server_name = hs.config.server.server_name # cast to tuple for use with str.startswith - self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + self._whitelisted_sso_clients = tuple(hs.config.sso.sso_client_whitelist) # A mapping of user ID to extra attributes to include in the login # response. 
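The `load_legacy_spam_checkers` hunk above touches code that probes a plugin's constructor signature so that older spam checkers, written before the `api` argument existed, keep working. A sketch of roughly how that detection works, under the assumption of plain-Python checker classes (`inspect.getfullargspec` raises `TypeError` for C-implemented callables); `ModuleApi` and both checker classes here are toy stand-ins:

    import inspect

    class ModuleApi:
        """Toy stand-in for the object handed to module constructors."""

    def instantiate_checker(checker_class, config, api):
        # Probe the constructor: only pass `api` when the signature
        # can actually accept it.
        spec = inspect.getfullargspec(checker_class)
        if "api" in spec.args or "api" in spec.kwonlyargs:
            return checker_class(config=config, api=api)
        return checker_class(config=config)

    class ModernChecker:
        def __init__(self, config, api):
            self.config, self.api = config, api

    class LegacyChecker:
        def __init__(self, config):
            self.config = config

    api = ModuleApi()
    print(type(instantiate_checker(ModernChecker, {}, api)).__name__)  # ModernChecker
    print(type(instantiate_checker(LegacyChecker, {}, api)).__name__)  # LegacyChecker

Signature probing like this trades a little introspection cost at startup for not breaking every third-party module the moment a new constructor argument is introduced.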
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index d487fee627..5cfba3c817 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -48,7 +48,7 @@ class DirectoryHandler(BaseHandler): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() self.config = hs.config - self.enable_room_list_search = hs.config.enable_room_list_search + self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.require_membership_for_aliases self.third_party_event_rules = hs.get_third_party_event_rules() @@ -143,7 +143,7 @@ class DirectoryHandler(BaseHandler): ): raise AuthError(403, "This user is not permitted to create this alias") - if not self.config.is_alias_creation_allowed( + if not self.config.roomdirectory.is_alias_creation_allowed( user_id, room_id, room_alias_str ): # Lets just return a generic message, as there may be all sorts of @@ -459,7 +459,7 @@ class DirectoryHandler(BaseHandler): if canonical_alias: room_aliases.append(canonical_alias) - if not self.config.is_publishing_room_allowed( + if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases ): # Lets just return a generic message, as there may be all sorts of diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4523b25636..b17ef2a9a1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -91,7 +91,7 @@ class FederationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() - self._server_notices_mxid = hs.config.server_notices_mxid + self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self.config = hs.config self.http_client = hs.get_proxied_blacklisted_http_client() self._replication = hs.get_replication_data_handler() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ad4e4a3d6f..c66aefe2c4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -692,10 +692,10 @@ class EventCreationHandler: return False async def _is_server_notices_room(self, room_id: str) -> bool: - if self.config.server_notices_mxid is None: + if self.config.servernotices.server_notices_mxid is None: return False user_ids = await self.store.get_users_in_room(room_id) - return self.config.server_notices_mxid in user_ids + return self.config.servernotices.server_notices_mxid in user_ids async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy @@ -731,8 +731,8 @@ class EventCreationHandler: # exempt the system notices user if ( - self.config.server_notices_mxid is not None - and user_id == self.config.server_notices_mxid + self.config.servernotices.server_notices_mxid is not None + and user_id == self.config.servernotices.server_notices_mxid ): return diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 01c5e1385d..4f99f137a2 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -98,7 +98,7 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._account_validity_handler = hs.get_account_validity_handler() self._user_consent_version = self.hs.config.consent.user_consent_version - self._server_notices_mxid = hs.config.server_notices_mxid + self._server_notices_mxid = 
hs.config.servernotices.server_notices_mxid self._server_name = hs.hostname self.spam_checker = hs.get_spam_checker() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b5768220d9..408b7d7b74 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -126,7 +126,7 @@ class RoomCreationHandler(BaseHandler): for preset_name, preset_config in self._presets_dict.items(): encrypted = ( preset_name - in self.config.encryption_enabled_by_default_for_room_presets + in self.config.room.encryption_enabled_by_default_for_room_presets ) preset_config["encrypted"] = encrypted @@ -141,7 +141,7 @@ class RoomCreationHandler(BaseHandler): self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS ) - self._server_notices_mxid = hs.config.server_notices_mxid + self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() @@ -757,7 +757,9 @@ class RoomCreationHandler(BaseHandler): ) if is_public: - if not self.config.is_publishing_room_allowed(user_id, room_id, room_alias): + if not self.config.roomdirectory.is_publishing_room_allowed( + user_id, room_id, room_alias + ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index c83ff585e3..c3d4199ed1 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -52,7 +52,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.enable_room_list_search = hs.config.enable_room_list_search + self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] ] = ResponseCache(hs.get_clock(), "room_list") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7bb3f0bc47..1a56c82fbd 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -88,7 +88,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() self.third_party_event_rules = hs.get_third_party_event_rules() - self._server_notices_mxid = self.config.server_notices_mxid + self._server_notices_mxid = self.config.servernotices.server_notices_mxid self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 185befbe9f..2fed9f377a 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -54,19 +54,18 @@ class Saml2SessionData: class SamlHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._saml_idp_entityid = hs.config.saml2_idp_entityid + self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) + self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid - self._saml2_session_lifetime = hs.config.saml2_session_lifetime + self._saml2_session_lifetime = hs.config.saml2.saml2_session_lifetime self._grandfathered_mxid_source_attribute = ( - hs.config.saml2_grandfathered_mxid_source_attribute + 
hs.config.saml2.saml2_grandfathered_mxid_source_attribute ) self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements - self._error_template = hs.config.sso_error_template # plugin to do custom mapping from saml response to mxid - self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( - hs.config.saml2_user_mapping_provider_config, + self._user_mapping_provider = hs.config.saml2.saml2_user_mapping_provider_class( + hs.config.saml2.saml2_user_mapping_provider_config, ModuleApi(hs, hs.get_auth_handler()), ) @@ -411,7 +410,7 @@ class DefaultSamlMappingProvider: self._mxid_mapper = parsed_config.mxid_mapper self._grandfathered_mxid_source_attribute = ( - module_api._hs.config.saml2_grandfathered_mxid_source_attribute + module_api._hs.config.saml2.saml2_grandfathered_mxid_source_attribute ) def get_remote_user_id( diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e044251a13..49fde01cf0 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -184,15 +184,17 @@ class SsoHandler: self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() - self._error_template = hs.config.sso_error_template - self._bad_user_template = hs.config.sso_auth_bad_user_template + self._error_template = hs.config.sso.sso_error_template + self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. - self._sso_auth_success_template = hs.config.sso_auth_success_template + self._sso_auth_success_template = hs.config.sso.sso_auth_success_template - self._sso_update_profile_information = hs.config.sso_update_profile_information + self._sso_update_profile_information = ( + hs.config.sso.sso_update_profile_information + ) # a lock on the mappings self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 9fc53333fc..bd3e6f2ec7 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -46,7 +46,7 @@ class StatsHandler: self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.stats_enabled = hs.config.stats_enabled + self.stats_enabled = hs.config.stats.stats_enabled # The current position in the current_state_delta stream self.pos: Optional[int] = None diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 8dc46d7674..b91e7cb501 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -61,7 +61,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory - self.search_all_users = hs.config.user_directory_search_all_users + self.search_all_users = hs.config.userdirectory.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream self.pos: Optional[int] = None diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index c6c4d3bd29..03d2dd94f6 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -363,7 +363,7 @@ def noop_context_manager(*args, **kwargs): def init_tracer(hs: "HomeServer"): """Set the whitelists and initialise the JaegerClient tracer""" global 
opentracing - if not hs.config.opentracer_enabled: + if not hs.config.tracing.opentracer_enabled: # We don't have a tracer opentracing = None return @@ -377,12 +377,12 @@ def init_tracer(hs: "HomeServer"): # Pull out the jaeger config if it was given. Otherwise set it to something sensible. # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py - set_homeserver_whitelist(hs.config.opentracer_whitelist) + set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist) from jaeger_client.metrics.prometheus import PrometheusMetricsFactory config = JaegerConfig( - config=hs.config.jaeger_config, + config=hs.config.tracing.jaeger_config, service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}", scope_manager=LogContextScopeManager(hs.config), metrics_factory=PrometheusMetricsFactory(), diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 25589b0042..f1b78d09f9 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -168,8 +168,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): client = hs.get_simple_http_client() local_instance_name = hs.get_instance_name() - master_host = hs.config.worker_replication_host - master_port = hs.config.worker_replication_http_port + master_host = hs.config.worker.worker_replication_host + master_port = hs.config.worker.worker_replication_http_port instance_map = hs.config.worker.instance_map diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 509ed7fb13..1438a82b60 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -322,8 +322,8 @@ class ReplicationCommandHandler: else: client_name = hs.get_instance_name() self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port + host = hs.config.worker.worker_replication_host + port = hs.config.worker.worker_replication_port hs.get_reactor().connectTCP(host.encode(), port, self._factory) def get_streams(self) -> Dict[str, Stream]: diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index a03774c98a..e1506deb2b 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -267,7 +267,7 @@ def register_servlets_for_client_rest_resource( # Load the media repo ones if we're using them. Otherwise load the servlets which # don't need a media repo (typically readonly admin APIs). - if hs.config.can_load_media_repo: + if hs.config.media.can_load_media_repo: register_servlets_for_media_repo(hs, http_server) else: ListMediaInRoom(hs).register(http_server) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 64446fc486..fa5c173f4b 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -76,7 +76,7 @@ class LoginRestServlet(RestServlet): self.jwt_audiences = hs.config.jwt.jwt_audiences # SSO configuration. 
- self.saml2_enabled = hs.config.saml2_enabled + self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled self._msc2918_enabled = hs.config.access_token_lifetime is not None diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index 8852811114..a47d9bd01d 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -58,7 +58,7 @@ class UserDirectorySearchRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - if not self.hs.config.user_directory_search_enabled: + if not self.hs.config.userdirectory.user_directory_search_enabled: return 200, {"limited": False, "results": []} body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index a1a815cf82..b52a296d8f 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -42,15 +42,15 @@ class VersionsRestServlet(RestServlet): # Calculate these once since they shouldn't change after start-up. self.e2ee_forced_public = ( RoomCreationPreset.PUBLIC_CHAT - in self.config.encryption_enabled_by_default_for_room_presets + in self.config.room.encryption_enabled_by_default_for_room_presets ) self.e2ee_forced_private = ( RoomCreationPreset.PRIVATE_CHAT - in self.config.encryption_enabled_by_default_for_room_presets + in self.config.room.encryption_enabled_by_default_for_room_presets ) self.e2ee_forced_trusted_private = ( RoomCreationPreset.TRUSTED_PRIVATE_CHAT - in self.config.encryption_enabled_by_default_for_room_presets + in self.config.room.encryption_enabled_by_default_for_room_presets ) def on_GET(self, request: Request) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py index 9d46ed3af3..ea2b8aa45f 100644 --- a/synapse/rest/client/voip.py +++ b/synapse/rest/client/voip.py @@ -37,14 +37,14 @@ class VoipRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req( - request, self.hs.config.turn_allow_guests + request, self.hs.config.voip.turn_allow_guests ) - turnUris = self.hs.config.turn_uris - turnSecret = self.hs.config.turn_shared_secret - turnUsername = self.hs.config.turn_username - turnPassword = self.hs.config.turn_password - userLifetime = self.hs.config.turn_user_lifetime + turnUris = self.hs.config.voip.turn_uris + turnSecret = self.hs.config.voip.turn_shared_secret + turnUsername = self.hs.config.voip.turn_username + turnPassword = self.hs.config.voip.turn_password + userLifetime = self.hs.config.voip.turn_user_lifetime if turnUris and turnSecret and userLifetime: expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 712d4e8368..a95804d327 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -31,7 +31,7 @@ class MediaConfigResource(DirectServeJsonResource): config = hs.config self.clock = hs.get_clock() self.auth = hs.get_auth() - self.limits_dict = {"m.upload.size": config.max_upload_size} + self.limits_dict = {"m.upload.size": config.media.max_upload_size} async def _async_render_GET(self, request: SynapseRequest) -> None: await self.auth.get_user_by_req(request) diff --git 
a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index c1bd81100d..abd88a2d4f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -76,16 +76,16 @@ class MediaRepository: self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastore() - self.max_upload_size = hs.config.max_upload_size - self.max_image_pixels = hs.config.max_image_pixels + self.max_upload_size = hs.config.media.max_upload_size + self.max_image_pixels = hs.config.media.max_image_pixels Thumbnailer.set_limits(self.max_image_pixels) - self.primary_base_path: str = hs.config.media_store_path + self.primary_base_path: str = hs.config.media.media_store_path self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) - self.dynamic_thumbnails = hs.config.dynamic_thumbnails - self.thumbnail_requirements = hs.config.thumbnail_requirements + self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails + self.thumbnail_requirements = hs.config.media.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") @@ -100,7 +100,11 @@ class MediaRepository: # potentially upload to. storage_providers = [] - for clz, provider_config, wrapper_config in hs.config.media_storage_providers: + for ( + clz, + provider_config, + wrapper_config, + ) in hs.config.media.media_storage_providers: backend = clz(hs, provider_config) provider = StorageProviderWrapper( backend, @@ -975,7 +979,7 @@ class MediaRepositoryResource(Resource): def __init__(self, hs: "HomeServer"): # If we're not configured to use it, raise if we somehow got here. - if not hs.config.can_load_media_repo: + if not hs.config.media.can_load_media_repo: raise ConfigError("Synapse is not configured to use a media repo.") super().__init__() @@ -986,7 +990,7 @@ class MediaRepositoryResource(Resource): self.putChild( b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) ) - if hs.config.url_preview_enabled: + if hs.config.media.url_preview_enabled: self.putChild( b"preview_url", PreviewUrlResource(hs, media_repo, media_repo.media_storage), diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 128706d297..0b0c4d6469 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -125,14 +125,14 @@ class PreviewUrlResource(DirectServeJsonResource): self.auth = hs.get_auth() self.clock = hs.get_clock() self.filepaths = media_repo.filepaths - self.max_spider_size = hs.config.max_spider_size + self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname self.store = hs.get_datastore() self.client = SimpleHttpClient( hs, treq_args={"browser_like_redirects": True}, - ip_whitelist=hs.config.url_preview_ip_range_whitelist, - ip_blacklist=hs.config.url_preview_ip_range_blacklist, + ip_whitelist=hs.config.media.url_preview_ip_range_whitelist, + ip_blacklist=hs.config.media.url_preview_ip_range_blacklist, use_proxy=True, ) self.media_repo = media_repo @@ -150,8 +150,8 @@ class PreviewUrlResource(DirectServeJsonResource): or instance_running_jobs == hs.get_instance_name() ) - self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist - self.url_preview_accept_language = hs.config.url_preview_accept_language + self.url_preview_url_blacklist = hs.config.media.url_preview_url_blacklist + self.url_preview_accept_language = hs.config.media.url_preview_accept_language # memory cache mapping urls 
to an ObservableDeferred returning # JSON-encoded OG metadata diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 6c9969e55f..289e4297f2 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -125,7 +125,7 @@ class FileStorageProviderBackend(StorageProvider): def __init__(self, hs: "HomeServer", config: str): self.hs = hs - self.cache_directory = hs.config.media_store_path + self.cache_directory = hs.config.media.media_store_path self.base_directory = config def __str__(self) -> str: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index cb2f88676e..ed91ef5a42 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -53,7 +53,7 @@ class ThumbnailResource(DirectServeJsonResource): self.store = hs.get_datastore() self.media_repo = media_repo self.media_storage = media_storage - self.dynamic_thumbnails = hs.config.dynamic_thumbnails + self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails self.server_name = hs.hostname async def _async_render_GET(self, request: SynapseRequest) -> None: diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 39b29318bb..7dcb1428e4 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -41,7 +41,7 @@ class UploadResource(DirectServeJsonResource): self.clock = hs.get_clock() self.server_name = hs.hostname self.auth = hs.get_auth() - self.max_upload_size = hs.config.max_upload_size + self.max_upload_size = hs.config.media.max_upload_size self.clock = hs.get_clock() async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 086c80b723..6ad558f5d1 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -50,7 +50,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc resources["/_synapse/client/oidc"] = OIDCResource(hs) - if hs.config.saml2_enabled: + if hs.config.saml2.saml2_enabled: from synapse.rest.synapse.client.saml2 import SAML2Resource res = SAML2Resource(hs) diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py index 64378ed57b..d8eae3970d 100644 --- a/synapse/rest/synapse/client/saml2/metadata_resource.py +++ b/synapse/rest/synapse/client/saml2/metadata_resource.py @@ -30,7 +30,7 @@ class SAML2MetadataResource(Resource): def __init__(self, hs: "HomeServer"): Resource.__init__(self) - self.sp_config = hs.config.saml2_sp_config + self.sp_config = hs.config.saml2.saml2_sp_config def render_GET(self, request: Request) -> bytes: metadata_xml = saml2.metadata.create_metadata_string( diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index d87a538917..cd1c5ff6f4 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -39,7 +39,7 @@ class ServerNoticesManager: self._server_name = hs.hostname self._notifier = hs.get_notifier() - self.server_notices_mxid = self._config.server_notices_mxid + self.server_notices_mxid = self._config.servernotices.server_notices_mxid def is_enabled(self): """Checks if server notices are enabled on this server. 
@@ -47,7 +47,7 @@ class ServerNoticesManager: Returns: bool """ - return self._config.server_notices_mxid is not None + return self.server_notices_mxid is not None async def send_notice( self, @@ -71,9 +71,9 @@ class ServerNoticesManager: room_id = await self.get_or_create_notice_room_for_user(user_id) await self.maybe_invite_user_to_room(user_id, room_id) - system_mxid = self._config.server_notices_mxid + assert self.server_notices_mxid is not None requester = create_requester( - system_mxid, authenticated_entity=self._server_name + self.server_notices_mxid, authenticated_entity=self._server_name ) logger.info("Sending server notice to %s", user_id) @@ -81,7 +81,7 @@ class ServerNoticesManager: event_dict = { "type": type, "room_id": room_id, - "sender": system_mxid, + "sender": self.server_notices_mxid, "content": event_content, } @@ -106,7 +106,7 @@ class ServerNoticesManager: Returns: room id of notice room. """ - if not self.is_enabled(): + if self.server_notices_mxid is None: raise Exception("Server notices not enabled") assert self._is_mine_id(user_id), "Cannot send server notices to remote users" @@ -139,12 +139,12 @@ class ServerNoticesManager: # avatar, we have to use both. join_profile = None if ( - self._config.server_notices_mxid_display_name is not None - or self._config.server_notices_mxid_avatar_url is not None + self._config.servernotices.server_notices_mxid_display_name is not None + or self._config.servernotices.server_notices_mxid_avatar_url is not None ): join_profile = { - "displayname": self._config.server_notices_mxid_display_name, - "avatar_url": self._config.server_notices_mxid_avatar_url, + "displayname": self._config.servernotices.server_notices_mxid_display_name, + "avatar_url": self._config.servernotices.server_notices_mxid_avatar_url, } requester = create_requester( @@ -154,7 +154,7 @@ class ServerNoticesManager: requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, - "name": self._config.server_notices_room_name, + "name": self._config.servernotices.server_notices_room_name, "power_level_content_override": {"users_default": -10}, }, ratelimit=False, @@ -178,6 +178,7 @@ class ServerNoticesManager: user_id: The ID of the user to invite. room_id: The ID of the room to invite the user to. 
""" + assert self.server_notices_mxid is not None requester = create_requester( self.server_notices_mxid, authenticated_entity=self._server_name ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 52ef9deede..c83089ee63 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -2015,7 +2015,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): (user_id_obj.localpart, create_profile_with_displayname), ) - if self.hs.config.stats_enabled: + if self.hs.config.stats.stats_enabled: # we create a new completed user statistics row # we don't strictly need current_token since this user really can't diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 343d6efc92..e20033bb28 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -98,7 +98,7 @@ class StatsStore(StateDeltasStore): self.server_name = hs.hostname self.clock = self.hs.get_clock() - self.stats_enabled = hs.config.stats_enabled + self.stats_enabled = hs.config.stats.stats_enabled self.stats_delta_processing_lock = DeferredLock() diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 7ca04237a5..90d65edc42 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -551,7 +551,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): super().__init__(database, db_conn, hs) self._prefer_local_users_in_search = ( - hs.config.user_directory_search_prefer_local_users + hs.config.userdirectory.user_directory_search_prefer_local_users ) self._server_name = hs.config.server.server_name @@ -741,7 +741,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): } """ - if self.hs.config.user_directory_search_all_users: + if self.hs.config.userdirectory.user_directory_search_all_users: join_args = (user_id,) where_clause = "user_id != ?" else: diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index a0a48b564e..6a2e76ca4a 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -405,7 +405,9 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): rd_config = RoomDirectoryConfig() rd_config.read_config(config) - self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed + self.hs.config.roomdirectory.is_alias_creation_allowed = ( + rd_config.is_alias_creation_allowed + ) return hs diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 1ba4c05b9b..24b7ef6efc 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -118,7 +118,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(len(r), 0) # Disable stats - self.hs.config.stats_enabled = False + self.hs.config.stats.stats_enabled = False self.handler.stats_enabled = False u1 = self.register_user("u1", "pass") @@ -134,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(len(r), 0) # Enable stats - self.hs.config.stats_enabled = True + self.hs.config.stats.stats_enabled = True self.handler.stats_enabled = True # Do the initial population of the user directory via the background update @@ -469,7 +469,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): behaviour eventually to still keep current rows. 
""" - self.hs.config.stats_enabled = False + self.hs.config.stats.stats_enabled = False self.handler.stats_enabled = False u1 = self.register_user("u1", "pass") @@ -481,7 +481,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertIsNone(self._get_current_stats("room", r1)) self.assertIsNone(self._get_current_stats("user", u1)) - self.hs.config.stats_enabled = True + self.hs.config.stats.stats_enabled = True self.handler.stats_enabled = True self._perform_background_initial_update() diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index ba32585a14..266333c553 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -451,7 +451,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): visible. """ self.handler.search_all_users = True - self.hs.config.user_directory_search_all_users = True + self.hs.config.userdirectory.user_directory_search_all_users = True u1 = self.register_user("user1", "pass") self.register_user("user2", "pass") @@ -607,7 +607,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): return hs def test_disabling_room_list(self): - self.config.user_directory_search_enabled = True + self.config.userdirectory.user_directory_search_enabled = True # First we create a room with another user so that user dir is non-empty # for our user @@ -624,7 +624,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): self.assertTrue(len(channel.json_body["results"]) > 0) # Disable user directory and check search returns nothing - self.config.user_directory_search_enabled = False + self.config.userdirectory.user_directory_search_enabled = False channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index f813866073..ce30a19213 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -43,7 +43,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.filepaths = MediaFilePaths(hs.config.media.media_store_path) def test_no_auth(self): """ @@ -200,7 +200,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name def test_no_auth(self): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e79e0e1850..ee3ae9cce4 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2473,7 +2473,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() - self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 44a643d506..4ae00755c9 100644 --- 
a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -53,7 +53,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.primary_base_path = os.path.join(self.test_dir, "primary") self.secondary_base_path = os.path.join(self.test_dir, "secondary") - hs.config.media_store_path = self.primary_base_path + hs.config.media.media_store_path = self.primary_base_path storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 8701b5f7e3..7f25200a5d 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -326,7 +326,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): for event in events: if ( event["type"] == EventTypes.Message - and event["sender"] == self.hs.config.server_notices_mxid + and event["sender"] == self.hs.config.servernotices.server_notices_mxid ): notice_in_room = True From 0420d4e6a5ceb58a453ce0761a15cd8e144da650 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 24 Sep 2021 14:01:45 +0100 Subject: [PATCH 24/31] Stop trying to auth/persist events whose auth events we do not have. (#10907) --- changelog.d/10907.bugfix | 1 + synapse/handlers/federation_event.py | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) create mode 100644 changelog.d/10907.bugfix diff --git a/changelog.d/10907.bugfix b/changelog.d/10907.bugfix new file mode 100644 index 0000000000..601b341f9f --- /dev/null +++ b/changelog.d/10907.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which could cause events pulled over federation to be incorrectly rejected. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 8fd9e51044..01fd841122 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1194,10 +1194,17 @@ class FederationEventHandler: auth = {} for auth_event_id in event.auth_event_ids(): ae = persisted_events.get(auth_event_id) - if ae: - auth[(ae.type, ae.state_key)] = ae - else: - logger.info("Missing auth event %s", auth_event_id) + if not ae: + logger.warning( + "Event %s relies on auth_event %s, which could not be found.", + event, + auth_event_id, + ) + # the fact we can't find the auth event doesn't mean it doesn't + # exist, which means it is premature to reject `event`. Instead we + # just ignore it for now. + return None + auth[(ae.type, ae.state_key)] = ae context = EventContext.for_outlier() context = await self._check_event_auth( @@ -1208,8 +1215,10 @@ class FederationEventHandler: ) return event, context - events_to_persist = await yieldable_gather_results(prep, fetched_events) - await self.persist_events_and_notify(room_id, events_to_persist) + events_to_persist = ( + x for x in await yieldable_gather_results(prep, fetched_events) if x + ) + await self.persist_events_and_notify(room_id, tuple(events_to_persist)) async def _check_event_auth( self, @@ -1235,8 +1244,7 @@ class FederationEventHandler: claimed_auth_event_map: A map of (type, state_key) => event for the event's claimed auth_events. - Possibly incomplete, and possibly including events that are not yet - persisted, or authed, or in the right room. + Possibly including events that were rejected, or are in the wrong room. 
Only populated when populating outliers.

From ea01d4c2de65f29cf23e2d28786bfc10bd5fd881 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 24 Sep 2021 15:27:09 +0100 Subject: [PATCH 25/31] Update postgresql testing script (#10906)

- Use sytest:bionic. Sytest:latest is two years old (do we want CI to push out latest at all?) and comes with Python 3.5, which we explicitly no longer support. The script now runs under PostgreSQL 10 as a result. - Advertise script in the docs - Move pg testing script to scripts-dev directory - Write to host as the script's executor, not root

A few changes to make it speedier to re-run the tests:

- Create blank DB in the container, not the script, so we don't have to `initdb` each time - Use a named volume to persist the tox environment, so we don't have to fetch and install a bunch of packages from PyPI each time

Co-authored-by: reivilibre --- .gitignore | 1 + changelog.d/10906.misc | 1 + docker/Dockerfile-pgtests | 24 +++++++++++-- docker/run_pg_tests.sh | 7 ++-- docs/development/contributing_guide.md | 47 ++++++++++++++++++++++++++ scripts-dev/test_postgresql.sh | 19 +++++++++++ test_postgresql.sh | 12 ------- 7 files changed, 92 insertions(+), 19 deletions(-) create mode 100644 changelog.d/10906.misc create mode 100755 scripts-dev/test_postgresql.sh delete mode 100755 test_postgresql.sh

diff --git a/.gitignore b/.gitignore index 6b9257b5c9..fe137f3370 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ __pycache__/ /.coverage* /.mypy_cache/ /.tox +/.tox-pg-container /build/ /coverage.* /dist/ diff --git a/changelog.d/10906.misc b/changelog.d/10906.misc new file mode 100644 index 0000000000..20a1cbfbd0 --- /dev/null +++ b/changelog.d/10906.misc @@ -0,0 +1 @@ +Update development testing script `test_postgresql.sh` to use a supported Python version and make re-runs quicker. \ No newline at end of file diff --git a/docker/Dockerfile-pgtests b/docker/Dockerfile-pgtests index 3bfee845c6..92b804d193 100644 --- a/docker/Dockerfile-pgtests +++ b/docker/Dockerfile-pgtests @@ -1,6 +1,6 @@ # Use the Sytest image that comes with a lot of the build dependencies # pre-installed -FROM matrixdotorg/sytest:latest +FROM matrixdotorg/sytest:bionic # The Sytest image doesn't come with python, so install that RUN apt-get update && apt-get -qq install -y python3 python3-dev python3-pip @@ -8,5 +8,23 @@ RUN apt-get update && apt-get -qq install -y python3 python3-dev python3-pip # We need tox to run the tests in run_pg_tests.sh RUN python3 -m pip install tox -ADD run_pg_tests.sh /pg_tests.sh -ENTRYPOINT /pg_tests.sh +# Initialise the db +RUN su -c '/usr/lib/postgresql/10/bin/initdb -D /var/lib/postgresql/data -E "UTF-8" --lc-collate="C.UTF-8" --lc-ctype="C.UTF-8" --username=postgres' postgres + +# Add a user with our UID and GID so that files get created on the host owned +# by us, not root. +ARG UID +ARG GID +RUN groupadd --gid $GID user +RUN useradd --uid $UID --gid $GID --groups sudo --no-create-home user + +# Ensure we can start postgres by sudo-ing as the postgres user.
+RUN apt-get update && apt-get -qq install -y sudo +RUN echo "user ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers + +ADD run_pg_tests.sh /run_pg_tests.sh +# Use the "exec form" of ENTRYPOINT (https://docs.docker.com/engine/reference/builder/#entrypoint) +# so that we can `docker run` this container and pass arguments to run_pg_tests.sh +ENTRYPOINT ["/run_pg_tests.sh"] + +USER user

diff --git a/docker/run_pg_tests.sh b/docker/run_pg_tests.sh index 1fd08cb62b..58e2177d34 100755 --- a/docker/run_pg_tests.sh +++ b/docker/run_pg_tests.sh @@ -10,11 +10,10 @@ set -e # Set PGUSER so Synapse's tests know what user to connect to the database with export PGUSER=postgres -# Initialise & start the database -su -c '/usr/lib/postgresql/9.6/bin/initdb -D /var/lib/postgresql/data -E "UTF-8" --lc-collate="en_US.UTF-8" --lc-ctype="en_US.UTF-8" --username=postgres' postgres -su -c '/usr/lib/postgresql/9.6/bin/pg_ctl -w -D /var/lib/postgresql/data start' postgres +# Start the database +sudo -u postgres /usr/lib/postgresql/10/bin/pg_ctl -w -D /var/lib/postgresql/data start # Run the tests cd /src export TRIAL_FLAGS="-j 4" -tox --workdir=/tmp -e py35-postgres +tox --workdir=./.tox-pg-container -e py36-postgres "$@" diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md index 97352b0f26..713366368c 100644 --- a/docs/development/contributing_guide.md +++ b/docs/development/contributing_guide.md @@ -170,6 +170,53 @@ To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`: SYNAPSE_TEST_LOG_LEVEL=DEBUG trial tests ``` +### Running tests under PostgreSQL + +Invoking `trial` as above will use an in-memory SQLite database. This is great for +quick development and testing. However, we recommend using a PostgreSQL database +in production (and indeed, we have some code paths specific to each database). +This means that we need to run our unit tests against PostgreSQL too. Our CI does +this automatically for pull requests and release candidates, but it's sometimes +useful to reproduce this locally. + +To do so, [configure Postgres](../postgres.md) and run `trial` with the +following environment variables matching your configuration: + +- `SYNAPSE_POSTGRES` to anything nonempty +- `SYNAPSE_POSTGRES_HOST` +- `SYNAPSE_POSTGRES_USER` +- `SYNAPSE_POSTGRES_PASSWORD` + +For example: + +```shell +export SYNAPSE_POSTGRES=1 +export SYNAPSE_POSTGRES_HOST=localhost +export SYNAPSE_POSTGRES_USER=postgres +export SYNAPSE_POSTGRES_PASSWORD=mydevenvpassword +trial +``` + +#### Prebuilt container + +Since configuring PostgreSQL can be fiddly, we can make use of a pre-made +Docker container to set up PostgreSQL and run our tests for us. To do so, run + +```shell +scripts-dev/test_postgresql.sh +``` + +Any extra arguments to the script will be passed to `tox` and then to `trial`, +so we can run a specific test in this container with e.g. + +```shell +scripts-dev/test_postgresql.sh tests.replication.test_sharded_event_persister.EventPersisterShardTestCase +``` + +The container creates a folder in your Synapse checkout called +`.tox-pg-container` and uses this as a tox environment. The output of any +`trial` runs goes into `_trial_temp` in your synapse source directory — the same +as running `trial` directly on your host machine. ## Run the integration tests ([Sytest](https://github.com/matrix-org/sytest)).
diff --git a/scripts-dev/test_postgresql.sh b/scripts-dev/test_postgresql.sh new file mode 100755 index 0000000000..43cfa256e4 --- /dev/null +++ b/scripts-dev/test_postgresql.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +# This script builds the Docker image to run the PostgreSQL tests, and then runs +# the tests. It uses a dedicated tox environment so that we don't have to +# rebuild it each time. + +# Command line arguments to this script are forwarded to "tox" and then to "trial". + +set -e + +# Build, and tag +docker build docker/ \ + --build-arg "UID=$(id -u)" \ + --build-arg "GID=$(id -g)" \ + -f docker/Dockerfile-pgtests \ + -t synapsepgtests + +# Run, mounting the current directory into /src +docker run --rm -it -v "$(pwd):/src" -v synapse-pg-test-tox:/tox synapsepgtests "$@" diff --git a/test_postgresql.sh b/test_postgresql.sh deleted file mode 100755 index c10828fbbc..0000000000 --- a/test_postgresql.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash - -# This script builds the Docker image to run the PostgreSQL tests, and then runs -# the tests. - -set -e - -# Build, and tag -docker build docker/ -f docker/Dockerfile-pgtests -t synapsepgtests - -# Run, mounting the current directory into /src -docker run --rm -it -v $(pwd)\:/src synapsepgtests From b10257e87972d158f4b6a0c7d1fe7239014ea10a Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 24 Sep 2021 16:38:23 +0200 Subject: [PATCH 26/31] Add a spamchecker callback to allow or deny room creation based on invites (#10898) This is in the context of creating new module callbacks that modules in https://github.com/matrix-org/synapse-dinsic can use, in an effort to reconcile the spam checker API in synapse-dinsic with the one in mainline. This adds a callback that's fairly similar to user_may_create_room except it also allows processing based on the invites sent at room creation. --- changelog.d/10898.feature | 1 + docs/modules/spam_checker_callbacks.md | 29 ++++++ synapse/events/spamcheck.py | 42 +++++++++ synapse/handlers/room.py | 14 ++- tests/rest/client/test_rooms.py | 119 ++++++++++++++++++++++++- 5 files changed, 199 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10898.feature diff --git a/changelog.d/10898.feature b/changelog.d/10898.feature new file mode 100644 index 0000000000..97fa39fd0c --- /dev/null +++ b/changelog.d/10898.feature @@ -0,0 +1 @@ +Add a `user_may_create_room_with_invites` spam checker callback to allow modules to allow or deny a room creation request based on the invites and/or 3PID invites it includes. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 81574a015c..7920ac5f8f 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -38,6 +38,35 @@ async def user_may_create_room(user: str) -> bool Called when processing a room creation request. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed to create a room. +### `user_may_create_room_with_invites` + +```python +async def user_may_create_room_with_invites( + user: str, + invites: List[str], + threepid_invites: List[Dict[str, str]], +) -> bool +``` + +Called when processing a room creation request (right after `user_may_create_room`). +The module is given the Matrix user ID of the user trying to create a room, as well as a +list of Matrix users to invite and a list of third-party identifiers (3PID, e.g. email +addresses) to invite. 
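As an illustration, here is a minimal module sketch that uses this callback to cap the number of invites in a room creation request. The class name and the `max_invites_at_creation` option are hypothetical, invented for the example; only the callback signature and the `register_spam_checker_callbacks` registration come from Synapse's module API.

```python
from typing import Any, Dict, List

from synapse.module_api import ModuleApi


class ExampleInviteLimiter:
    """Hypothetical module: denies room creation requests carrying too many invites."""

    def __init__(self, config: Dict[str, Any], api: ModuleApi):
        self._api = api
        # Invented config option: the maximum number of invites (Matrix users
        # plus 3PIDs) allowed in a single room creation request.
        self._max_invites = config.get("max_invites_at_creation", 10)

        self._api.register_spam_checker_callbacks(
            user_may_create_room_with_invites=self.user_may_create_room_with_invites,
        )

    async def user_may_create_room_with_invites(
        self,
        user: str,
        invites: List[str],
        threepid_invites: List[Dict[str, str]],
    ) -> bool:
        # Allow the request only if the combined invite count is small enough.
        return len(invites) + len(threepid_invites) <= self._max_invites
```

Returning `True` defers to any other registered callbacks, matching the short-circuit-on-`False` loop this patch adds to `SpamChecker` further down.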
+ +A Matrix user to invite is represented by their Matrix user ID, and an invited +3PID is represented by a dict that includes the 3PID medium (e.g. "email") through its +`medium` key and its address (e.g. "alice@example.com") through its `address` key. + +See [the Matrix specification](https://matrix.org/docs/spec/appendices#pid-types) for more +information regarding third-party identifiers. + +If no invite and/or 3PID invite were specified in the room creation request, the +corresponding list(s) will be empty. + +**Note**: This callback is not called when a room is cloned (e.g. during a room upgrade) +since no invites are sent when cloning a room. To cover this case, modules also need to +implement `user_may_create_room`. + ### `user_may_create_room_alias` ```python diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 19ee246f96..c389f70b8d 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -46,6 +46,9 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ] USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]] USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] +USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[ + [str, List[str], List[Dict[str, str]]], Awaitable[bool] +] USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]] @@ -164,6 +167,9 @@ class SpamChecker: self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = [] + self._user_may_create_room_with_invites_callbacks: List[ + USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK + ] = [] self._user_may_create_room_alias_callbacks: List[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = [] @@ -183,6 +189,9 @@ class SpamChecker: check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, + user_may_create_room_with_invites: Optional[ + USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK + ] = None, user_may_create_room_alias: Optional[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = None, @@ -203,6 +212,11 @@ class SpamChecker: if user_may_create_room is not None: self._user_may_create_room_callbacks.append(user_may_create_room) + if user_may_create_room_with_invites is not None: + self._user_may_create_room_with_invites_callbacks.append( + user_may_create_room_with_invites, + ) + if user_may_create_room_alias is not None: self._user_may_create_room_alias_callbacks.append( user_may_create_room_alias, ) @@ -283,6 +297,34 @@ class SpamChecker: return True + async def user_may_create_room_with_invites( + self, + userid: str, + invites: List[str], + threepid_invites: List[Dict[str, str]], + ) -> bool: + """Checks if a given user may create a room with invites + + If this method returns false, the creation request will be rejected. + + Args: + userid: The ID of the user attempting to create a room + invites: The IDs of the Matrix users to be invited if the room creation is + allowed. + threepid_invites: The threepids to be invited if the room creation is allowed, + each given as a dict including a "medium" key indicating the threepid's medium (e.g. + "email") and an "address" key indicating the threepid's address (e.g. 
+ "alice@example.com") + + Returns: + True if the user may create the room, otherwise False + """ + for callback in self._user_may_create_room_with_invites_callbacks: + if await callback(userid, invites, threepid_invites) is False: + return False + + return True + async def user_may_create_room_alias( self, userid: str, room_alias: RoomAlias ) -> bool: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 408b7d7b74..8fede5e935 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -649,8 +649,16 @@ class RoomCreationHandler(BaseHandler): requester, config, is_requester_admin=is_requester_admin ) - if not is_requester_admin and not await self.spam_checker.user_may_create_room( - user_id + invite_3pid_list = config.get("invite_3pid", []) + invite_list = config.get("invite", []) + + if not is_requester_admin and not ( + await self.spam_checker.user_may_create_room(user_id) + and await self.spam_checker.user_may_create_room_with_invites( + user_id, + invite_list, + invite_3pid_list, + ) ): raise SynapseError(403, "You are not permitted to create rooms") @@ -684,8 +692,6 @@ class RoomCreationHandler(BaseHandler): if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) - invite_3pid_list = config.get("invite_3pid", []) - invite_list = config.get("invite", []) for i in invite_list: try: uid = UserID.from_string(i) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index ef847f0f5f..30bdaa9c27 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,7 +18,7 @@ """Tests REST events for /rooms paths.""" import json -from typing import Iterable +from typing import Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse @@ -30,7 +30,7 @@ from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync -from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -669,6 +669,121 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) + def test_spamchecker_invites(self): + """Tests the user_may_create_room_with_invites spam checker callback.""" + + # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an + # IS. + async def do_3pid_invite( + room_id: str, + inviter: UserID, + medium: str, + address: str, + id_server: str, + requester: Requester, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> int: + return 0 + + do_3pid_invite_mock = Mock(side_effect=do_3pid_invite) + self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock + + # Add a mock callback for user_may_create_room_with_invites. Make it allow any + # room creation request for now. + return_value = True + + async def user_may_create_room_with_invites( + user: str, + invites: List[str], + threepid_invites: List[Dict[str, str]], + ) -> bool: + return return_value + + callback_mock = Mock(side_effect=user_may_create_room_with_invites) + self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append( + callback_mock, + ) + + # The MXIDs we'll try to invite. 
+ invited_mxids = [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + + # The 3PIDs we'll try to invite. + invited_3pids = [ + { + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": "alice1@example.com", + }, + { + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": "alice2@example.com", + }, + { + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": "alice3@example.com", + }, + ] + + # Create a room and invite the Matrix users, and check that it succeeded. + channel = self.make_request( + "POST", + "/createRoom", + json.dumps({"invite": invited_mxids}).encode("utf8"), + ) + self.assertEqual(200, channel.code) + + # Check that the callback was called with the right arguments. + expected_call_args = ((self.user_id, invited_mxids, []),) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Create a room and invite the 3PIDs, and check that it succeeded. + channel = self.make_request( + "POST", + "/createRoom", + json.dumps({"invite_3pid": invited_3pids}).encode("utf8"), + ) + self.assertEqual(200, channel.code) + + # Check that do_3pid_invite was called the right amount of time + self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) + + # Check that the callback was called with the right arguments. + expected_call_args = ((self.user_id, [], invited_3pids),) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now deny any room creation. + return_value = False + + # Create a room and invite the 3PIDs, and check that it failed. + channel = self.make_request( + "POST", + "/createRoom", + json.dumps({"invite_3pid": invited_3pids}).encode("utf8"), + ) + self.assertEqual(403, channel.code) + + # Check that do_3pid_invite wasn't called this time. + self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" From d138187045dd3c51689c19124d65ee62e37db755 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 24 Sep 2021 17:09:12 -0500 Subject: [PATCH 27/31] Document changes to schema version 61 - 64 (#10917) As pointed out by @richvdh, https://github.com/matrix-org/synapse/pull/10838#discussion_r715424244 Retroactively summarize `61` - `64` --- changelog.d/10917.misc | 1 + synapse/storage/schema/__init__.py | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 changelog.d/10917.misc diff --git a/changelog.d/10917.misc b/changelog.d/10917.misc new file mode 100644 index 0000000000..9ce6eef94b --- /dev/null +++ b/changelog.d/10917.misc @@ -0,0 +1 @@ +Document and summarize changes in schema version `61` - `64`. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index aa2ce44c6c..573e05a482 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -27,11 +27,22 @@ for more information on how this works. Changes in SCHEMA_VERSION = 61: - The `user_stats_historical` and `room_stats_historical` tables are not written and are not read (previously, they were written but not read). + - MSC2716: Add `insertion_events` and `insertion_event_edges` tables to keep track + of insertion events in order to navigate historical chunks of messages. + - MSC2716: Add `chunk_events` table to track how the chunk is labeled and + determines which insertion event it points to. 
+ +Changes in SCHEMA_VERSION = 62: + - MSC2716: Add `insertion_event_extremities` table that keeps track of which + insertion events need to be backfilled. Changes in SCHEMA_VERSION = 63: - The `public_room_list_stream` table is not written nor read to (previously, it was written and read to, but not for any significant purpose). https://github.com/matrix-org/synapse/pull/10565 + +Changes in SCHEMA_VERSION = 64: + - MSC2716: Rename related tables and columns from "chunks" to "batches". """ From 6c83c2710760a4f551d1a925fc9b1a19ae8797c1 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 27 Sep 2021 11:29:23 +0100 Subject: [PATCH 28/31] Fix race conditions when creating media store and config directories (#10913) --- changelog.d/10913.bugfix | 1 + synapse/config/_base.py | 9 ++------- synapse/rest/media/v1/media_storage.py | 6 ++---- synapse/rest/media/v1/storage_provider.py | 3 +-- 4 files changed, 6 insertions(+), 13 deletions(-) create mode 100644 changelog.d/10913.bugfix diff --git a/changelog.d/10913.bugfix b/changelog.d/10913.bugfix new file mode 100644 index 0000000000..a0015c8241 --- /dev/null +++ b/changelog.d/10913.bugfix @@ -0,0 +1 @@ +Fix race conditions when creating media store and config directories. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 2cc242782a..d974a1a2a8 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -200,11 +200,7 @@ class Config: @classmethod def ensure_directory(cls, dir_path): dir_path = cls.abspath(dir_path) - try: - os.makedirs(dir_path) - except OSError as e: - if e.errno != errno.EEXIST: - raise + os.makedirs(dir_path, exist_ok=True) if not os.path.isdir(dir_path): raise ConfigError("%s is not a directory" % (dir_path,)) return dir_path @@ -693,8 +689,7 @@ class RootConfig: open_private_ports=config_args.open_private_ports, ) - if not path_exists(config_dir_path): - os.makedirs(config_dir_path) + os.makedirs(config_dir_path, exist_ok=True) with open(config_path, "w") as config_file: config_file.write(config_str) config_file.write("\n\n# vim:ft=yaml") diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 01fada8fb5..fca239d8c7 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -132,8 +132,7 @@ class MediaStorage: fname = os.path.join(self.local_media_directory, path) dirname = os.path.dirname(fname) - if not os.path.exists(dirname): - os.makedirs(dirname) + os.makedirs(dirname, exist_ok=True) finished_called = [False] @@ -244,8 +243,7 @@ class MediaStorage: return legacy_local_path dirname = os.path.dirname(local_path) - if not os.path.exists(dirname): - os.makedirs(dirname) + os.makedirs(dirname, exist_ok=True) for provider in self.storage_providers: res: Any = await provider.fetch(path, file_info) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 289e4297f2..da78fcee5e 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -138,8 +138,7 @@ class FileStorageProviderBackend(StorageProvider): backup_fname = os.path.join(self.base_directory, path) dirname = os.path.dirname(backup_fname) - if not os.path.exists(dirname): - os.makedirs(dirname) + os.makedirs(dirname, exist_ok=True) await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname From f7768f62cbf7579a1a91e694f83d47d275373369 Mon Sep 17 00:00:00 2001 From: Sean Quah 
<8349537+squahtx@users.noreply.github.com> Date: Mon, 27 Sep 2021 12:55:27 +0100 Subject: [PATCH 29/31] Avoid storing URL cache files in storage providers (#10911) URL cache files are short-lived and it does not make sense to offload them (eg. to the cloud) or back them up. --- changelog.d/10911.bugfix | 1 + docs/upgrade.md | 7 + synapse/rest/media/v1/filepath.py | 11 +- synapse/rest/media/v1/preview_url_resource.py | 1 - synapse/rest/media/v1/storage_provider.py | 10 ++ tests/rest/media/v1/test_url_preview.py | 130 ++++++++++++++++++ 6 files changed, 154 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10911.bugfix diff --git a/changelog.d/10911.bugfix b/changelog.d/10911.bugfix new file mode 100644 index 0000000000..96e36bb15a --- /dev/null +++ b/changelog.d/10911.bugfix @@ -0,0 +1 @@ +Avoid storing URL cache files in storage providers. Server admins may safely delete the `url_cache/` and `url_cache_thumbnails/` directories from any configured storage providers to reclaim space. diff --git a/docs/upgrade.md b/docs/upgrade.md index f9b832cb3f..a8221372df 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,13 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.44.0 + +## The URL preview cache is no longer mirrored to storage providers +The `url_cache/` and `url_cache_thumbnails/` directories in the media store are +no longer mirrored to storage providers. These two directories can be safely +deleted from any configured storage providers to reclaim space. + # Upgrading to v1.43.0 ## The spaces summary APIs can now be handled by workers diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 39bbe4e874..08bd85f664 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -195,23 +195,24 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - def url_cache_thumbnail_directory(self, media_id: str) -> str: + def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf if NEW_FORMAT_ID_RE.match(media_id): - return os.path.join( - self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] - ) + return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:]) else: return os.path.join( - self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], ) + url_cache_thumbnail_directory = _wrap_in_base_path( + url_cache_thumbnail_directory_rel + ) + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0b0c4d6469..79a42b2455 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -485,7 +485,6 @@ class PreviewUrlResource(DirectServeJsonResource): async def _expire_url_cache_data(self) -> None: """Clean up expired url cache content, media and thumbnails.""" - # TODO: Delete from backup media store assert self._worker_run_media_background_jobs diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index da78fcee5e..18bf977d3d 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -93,6 +93,11 @@ class StorageProviderWrapper(StorageProvider): if 
file_info.server_name and not self.store_remote: return None + if file_info.url_cache: + # The URL preview cache is short lived and not worth offloading or + # backing up. + return None + if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. @@ -110,6 +115,11 @@ class StorageProviderWrapper(StorageProvider): run_in_background(store) async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + if file_info.url_cache: + # Files in the URL preview cache definitely aren't stored here, + # so avoid any potentially slow I/O or network access. + return None + # store_file is supposed to return an Awaitable, but guard # against improper implementations. return await maybe_awaitable(self.backend.fetch(path, file_info)) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index d83dfacfed..4d09b5d07e 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -21,6 +21,7 @@ from twisted.internet.error import DNSLookupError from twisted.test.proto_helpers import AccumulatingProtocol from synapse.config.oembed import OEmbedEndpointConfig +from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest from tests.server import FakeTransport @@ -721,3 +722,132 @@ class URLPreviewTests(unittest.HomeserverTestCase): "og:description": "Content Preview", }, ) + + def _download_image(self): + """Downloads an image into the URL cache. + + Returns: + A (host, media_id) tuple representing the MXC URI of the image. + """ + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + channel = self.make_request( + "GET", + "preview_url?url=http://cdn.twitter.com/matrixdotorg", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: image/png\r\n\r\n" + % (len(SMALL_PNG),) + + SMALL_PNG + ) + + self.pump() + self.assertEqual(channel.code, 200) + body = channel.json_body + mxc_uri = body["og:image"] + host, _port, media_id = parse_and_validate_mxc_uri(mxc_uri) + self.assertIsNone(_port) + return host, media_id + + def test_storage_providers_exclude_files(self): + """Test that files are not stored in or fetched from storage providers.""" + host, media_id = self._download_image() + + rel_file_path = self.preview_url.filepaths.url_cache_filepath_rel(media_id) + media_store_path = os.path.join(self.media_store_path, rel_file_path) + storage_provider_path = os.path.join(self.storage_path, rel_file_path) + + # Check storage + self.assertTrue(os.path.isfile(media_store_path)) + self.assertFalse( + os.path.isfile(storage_provider_path), + "URL cache file was unexpectedly stored in a storage provider", + ) + + # Check fetching + channel = self.make_request( + "GET", + f"download/{host}/{media_id}", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 200) + + # Move cached file into the storage provider + os.makedirs(os.path.dirname(storage_provider_path), exist_ok=True) + os.rename(media_store_path, storage_provider_path) + + channel = self.make_request( + "GET", + f"download/{host}/{media_id}", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual( + 
channel.code, + 404, + "URL cache file was unexpectedly retrieved from a storage provider", + ) + + def test_storage_providers_exclude_thumbnails(self): + """Test that thumbnails are not stored in or fetched from storage providers.""" + host, media_id = self._download_image() + + rel_thumbnail_path = ( + self.preview_url.filepaths.url_cache_thumbnail_directory_rel(media_id) + ) + media_store_thumbnail_path = os.path.join( + self.media_store_path, rel_thumbnail_path + ) + storage_provider_thumbnail_path = os.path.join( + self.storage_path, rel_thumbnail_path + ) + + # Check storage + self.assertTrue(os.path.isdir(media_store_thumbnail_path)) + self.assertFalse( + os.path.isdir(storage_provider_thumbnail_path), + "URL cache thumbnails were unexpectedly stored in a storage provider", + ) + + # Check fetching + channel = self.make_request( + "GET", + f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 200) + + # Remove the original, otherwise thumbnails will regenerate + rel_file_path = self.preview_url.filepaths.url_cache_filepath_rel(media_id) + media_store_path = os.path.join(self.media_store_path, rel_file_path) + os.remove(media_store_path) + + # Move cached thumbnails into the storage provider + os.makedirs(os.path.dirname(storage_provider_thumbnail_path), exist_ok=True) + os.rename(media_store_thumbnail_path, storage_provider_thumbnail_path) + + channel = self.make_request( + "GET", + f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual( + channel.code, + 404, + "URL cache thumbnail was unexpectedly retrieved from a storage provider", + )

From d37841787a9e152938ddb39af5bc1d93d04bc640 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 27 Sep 2021 15:39:49 +0100 Subject: [PATCH 30/31] Sign the git tag in release script (#10925)

--- changelog.d/10925.misc | 1 + scripts-dev/release.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10925.misc diff --git a/changelog.d/10925.misc b/changelog.d/10925.misc new file mode 100644 index 0000000000..0c8027ecc2 --- /dev/null +++ b/changelog.d/10925.misc @@ -0,0 +1 @@ +Update release script to sign the newly created git tags. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index a339260c43..ab2d860ab8 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -276,7 +276,7 @@ def tag(gh_token: Optional[str]): if click.confirm("Edit text?", default=False): changes = click.edit(changes, require_save=False) - repo.create_tag(tag_name, message=changes) + repo.create_tag(tag_name, message=changes, sign=True) if not click.confirm("Push tag to GitHub?", default=True): print("")

From 707d5e4e48e839dabd34e4b67426fe8382a2c978 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 28 Sep 2021 10:37:58 +0100 Subject: [PATCH 31/31] Encode JSON responses on a thread in C, mk2 (#10905)

Currently we use `JsonEncoder.iterencode` to write JSON responses, which ensures that we don't block the main reactor thread when encoding huge objects. The downside to this is that `iterencode` falls back to using a pure Python encoder that is *much* less efficient and can easily burn a lot of CPU for huge responses. To fix this, while still ensuring we don't block the reactor loop, we encode the JSON on a threadpool using the standard `JsonEncoder.encode` function, which is backed by a C library.
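For illustration, the general shape of that pattern using Twisted's public `deferToThread` helper is sketched below; this is a standalone approximation under invented names, not the helper this patch itself adds to `synapse/http/server.py`:

```python
import json

from twisted.internet import threads
from twisted.internet.defer import Deferred


def encode_json_off_reactor(json_object: object) -> "Deferred[bytes]":
    # json.dumps (with default options) uses the C-accelerated encoder, so
    # the expensive encoding happens on a threadpool thread rather than the
    # reactor thread.
    return threads.deferToThread(lambda: json.dumps(json_object).encode("utf-8"))
```

Calling this requires a running reactor, since `deferToThread` hands the callable to the reactor's threadpool.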
Doing so, however, requires `respond_with_json` to have access to the reactor, which it previously didn't. There are two ways of doing this: 1. threading through the reactor object, which is a bit fiddly as e.g. `DirectServeJsonResource` doesn't currently take a reactor, but is exposed to modules and so is a PITA to change; or 2. expose the reactor in `SynapseRequest`, which requires updating a bunch of servlet types. I went with the latter as that is just a mechanical change, and I think it makes sense as a request already has a reactor associated with it (via its http channel).

--- changelog.d/10905.feature | 1 + synapse/http/server.py | 72 +++++++++++++++++++++++++++++-------- synapse/push/emailpusher.py | 2 +- synapse/util/iterutils.py | 19 ++++++++-- 4 files changed, 76 insertions(+), 18 deletions(-) create mode 100644 changelog.d/10905.feature diff --git a/changelog.d/10905.feature b/changelog.d/10905.feature new file mode 100644 index 0000000000..07e7b2c6a7 --- /dev/null +++ b/changelog.d/10905.feature @@ -0,0 +1 @@ +Speed up responding with large JSON objects to requests. diff --git a/synapse/http/server.py b/synapse/http/server.py index e28b56abb9..1a50305dcf 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -21,7 +21,6 @@ import types import urllib from http import HTTPStatus from inspect import isawaitable -from io import BytesIO from typing import ( Any, Awaitable, @@ -37,7 +36,7 @@ from typing import ( ) import jinja2 -from canonicaljson import iterencode_canonical_json +from canonicaljson import encode_canonical_json from typing_extensions import Protocol from zope.interface import implementer @@ -45,7 +44,7 @@ from twisted.internet import defer, interfaces from twisted.python import failure from twisted.web import resource from twisted.web.server import NOT_DONE_YET, Request -from twisted.web.static import File, NoRangeStaticProducer +from twisted.web.static import File from twisted.web.util import redirectTo from synapse.api.errors import ( @@ -56,10 +55,11 @@ from synapse.api.errors import ( UnrecognizedRequestError, ) from synapse.http.site import SynapseRequest -from synapse.logging.context import preserve_fn +from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background from synapse.logging.opentracing import trace_servlet from synapse.util import json_encoder from synapse.util.caches import intern_dict +from synapse.util.iterutils import chunk_seq logger = logging.getLogger(__name__) @@ -620,12 +620,11 @@ class _ByteProducer: self._request = None -def _encode_json_bytes(json_object: Any) -> Iterator[bytes]: +def _encode_json_bytes(json_object: Any) -> bytes: """ - Encode an object into JSON. Returns an iterator of bytes. + Encode an object into JSON, returning the encoded bytes. """ - for chunk in json_encoder.iterencode(json_object): - yield chunk.encode("utf-8") + return json_encoder.encode(json_object).encode("utf-8") def respond_with_json( @@ -659,7 +658,7 @@ def respond_with_json( return None if canonical_json: - encoder = iterencode_canonical_json + encoder = encode_canonical_json else: encoder = _encode_json_bytes @@ -670,7 +669,9 @@ def respond_with_json( if send_cors: set_cors_headers(request) - _ByteProducer(request, encoder(json_object)) + run_in_background( + _async_write_json_to_request_in_thread, request, encoder, json_object + ) return NOT_DONE_YET @@ -706,15 +707,56 @@ def respond_with_json_bytes( if send_cors: set_cors_headers(request) - # note that this is zero-copy (the bytesio shares a copy-on-write buffer with - # the original `bytes`). 
- bytes_io = BytesIO(json_bytes) - - producer = NoRangeStaticProducer(request, bytes_io) - producer.start() + _write_bytes_to_request(request, json_bytes) return NOT_DONE_YET +async def _async_write_json_to_request_in_thread( + request: SynapseRequest, + json_encoder: Callable[[Any], bytes], + json_object: Any, +): + """Encodes the given JSON object on a thread and then writes it to the + request. + + This is done so that encoding large JSON objects doesn't block the reactor + thread. + + Note: We don't use JsonEncoder.iterencode here as that falls back to the + Python implementation (rather than the C backend), which is *much* more + expensive. + """ + + json_bytes = await defer_to_thread(request.reactor, json_encoder, json_object) + + _write_bytes_to_request(request, json_bytes) + + +def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: + """Writes the bytes to the request using an appropriate producer. + + Note: This should be used instead of `Request.write` to correctly handle + large response bodies. + """ + + # The problem with dumping all of the response into the `Request` object at + # once (via `Request.write`) is that doing so starts the timeout for the + # next request to be received: so if it takes longer than 60s to stream back + # the response to the client, the client never gets it. + # + # The correct solution is to use a Producer; then the timeout is only + # started once all of the content is sent over the TCP connection. + + # To make sure we don't write all of the bytes at once we split it up into + # chunks. + chunk_size = 4096 + bytes_generator = chunk_seq(bytes_to_write, chunk_size) + + # We use a `_ByteProducer` here rather than `NoRangeStaticProducer` as the + # unit tests can't cope with being given a pull producer. + _ByteProducer(request, bytes_generator) + + def set_cors_headers(request: Request): """Set the CORS headers so that javascript running in a web browser can use this API diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index e08e125cb8..cf5abdfbda 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -184,7 +184,7 @@ class EmailPusher(Pusher): should_notify_at = max(notif_ready_at, room_ready_at) - if should_notify_at < self.clock.time_msec(): + if should_notify_at <= self.clock.time_msec(): # one of our notifications is ready for sending, so we send # *one* email updating the user on their notifications, # we then consider all previously outstanding notifications diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 8ac3eab2f5..4938ddf703 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -21,13 +21,28 @@ from typing import ( Iterable, Iterator, Mapping, - Sequence, Set, + Sized, Tuple, TypeVar, ) +from typing_extensions import Protocol + T = TypeVar("T") +S = TypeVar("S", bound="_SelfSlice") + + +class _SelfSlice(Sized, Protocol): + """A helper protocol that matches types where taking a slice results in the + same type being returned. + + This is more specific than `Sequence`, which allows another `Sequence` to be + returned. + """ + + def __getitem__(self: S, i: slice) -> S: + ...
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]: @@ -46,7 +61,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]: return iter(lambda: tuple(islice(sourceiter, size)), ()) -def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]: +def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]: """Split the given sequence into chunks of the given size. The last chunk may be shorter than the given size.
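As an aside, a minimal usage sketch of the tightened `chunk_seq` contract (my example, not part of the patch): slicing a `bytes` object yields `bytes`, so the chunks that `_write_bytes_to_request` hands to `_ByteProducer` are themselves `bytes`:

    from synapse.util.iterutils import chunk_seq

    # Each chunk is a slice of the input: at most maxlen items long and of the
    # same type as the input.
    chunks = list(chunk_seq(b"abcdefghij", 4))
    assert chunks == [b"abcd", b"efgh", b"ij"]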