From 88ce3080d4d064b9872c9867208116dc9db73d7e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 24 May 2022 09:23:23 -0400 Subject: [PATCH 01/74] Experimental support for MSC3772 (#12740) Implements the following behind an experimental configuration flag: * A new push rule kind for mutually related events. * A new default push rule (`.m.rule.thread_reply`) under an unstable prefix. This is missing part of MSC3772: * The `.m.rule.thread_reply_to_me` push rule, this depends on MSC3664 / #11804. --- changelog.d/12740.feature | 1 + synapse/config/experimental.py | 3 + synapse/push/baserules.py | 14 ++++ synapse/push/bulk_push_rule_evaluator.py | 71 ++++++++++++++++- synapse/push/clientformat.py | 4 + synapse/push/push_rule_evaluator.py | 50 +++++++++++- synapse/storage/databases/main/events.py | 9 +++ synapse/storage/databases/main/push_rule.py | 5 ++ synapse/storage/databases/main/relations.py | 52 +++++++++++++ tests/push/test_push_rule_evaluator.py | 84 ++++++++++++++++++++- 10 files changed, 287 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12740.feature diff --git a/changelog.d/12740.feature b/changelog.d/12740.feature new file mode 100644 index 0000000000..e674c31ae8 --- /dev/null +++ b/changelog.d/12740.feature @@ -0,0 +1 @@ +Experimental support for [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772): Push rule for mutually related events. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b20d949689..cc417e2fbf 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -84,3 +84,6 @@ class ExperimentalConfig(Config): # MSC3786 (Add a default push rule to ignore m.room.server_acl events) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) + + # MSC3772: A push rule for mutual relations. + self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index a17b35a605..4c7278b5a1 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -139,6 +139,7 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [ { "kind": "event_match", "key": "content.body", + # Match the localpart of the requester's MXID. "pattern_type": "user_localpart", } ], @@ -191,6 +192,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ "pattern": "invite", "_cache_key": "_invite_member", }, + # Match the requester's MXID. {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, ], "actions": [ @@ -350,6 +352,18 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [ {"set_tweak": "highlight", "value": False}, ], }, + { + "rule_id": "global/underride/.org.matrix.msc3772.thread_reply", + "conditions": [ + { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.thread", + # Match the requester's MXID. + "sender_type": "user_id", + } + ], + "actions": ["notify", {"set_tweak": "highlight", "value": False}], + }, { "rule_id": "global/underride/.m.rule.message", "conditions": [ diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4cc8a2ecca..1a8e7ef3dc 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union

 import attr
 from prometheus_client import Counter

@@ -121,6 +122,9 @@ class BulkPushRuleEvaluator:
             resizable=False,
         )

+        # Whether support for MSC3772 is enabled.
+        self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
+
     async def _get_rules_for_event(
         self, event: EventBase, context: EventContext
     ) -> Dict[str, List[Dict[str, Any]]]:
@@ -192,6 +196,60 @@ class BulkPushRuleEvaluator:

         return pl_event.content if pl_event else {}, sender_level

+    async def _get_mutual_relations(
+        self, event: EventBase, rules: Iterable[Dict[str, Any]]
+    ) -> Dict[str, Set[Tuple[str, str]]]:
+        """
+        Fetch event metadata for events which relate to the same event as the given event.
+
+        If the given event has no relation information, returns an empty dictionary.
+
+        Args:
+            event: The event to check for mutual relations.
+            rules: The push rules which will be processed for this event.
+
+        Returns:
+            A dictionary of relation type to:
+                A set of tuples of:
+                    The sender
+                    The event type
+        """
+
+        # If the experimental feature is not enabled, skip fetching relations.
+        if not self._relations_match_enabled:
+            return {}
+
+        # If the event does not have a relation, then it cannot have any mutual
+        # relations.
+        relation = relation_from_event(event)
+        if not relation:
+            return {}
+
+        # Pre-filter to figure out which relation types are interesting.
+        rel_types = set()
+        for rule in rules:
+            # Skip disabled rules.
+            if "enabled" in rule and not rule["enabled"]:
+                continue
+
+            for condition in rule["conditions"]:
+                if condition["kind"] != "org.matrix.msc3772.relation_match":
+                    continue
+
+                # rel_type is required.
+                rel_type = condition.get("rel_type")
+                if rel_type:
+                    rel_types.add(rel_type)
+
+        # If no valid rules were found, no mutual relations.
+        if not rel_types:
+            return {}
+
+        # If any valid rules were found, fetch the mutual relations.
+        return await self.store.get_mutual_event_relations(
+            relation.parent_id, rel_types
+        )
+
     @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
         self, event: EventBase, context: EventContext
@@ -216,8 +274,17 @@ class BulkPushRuleEvaluator:
             sender_power_level,
         ) = await self._get_power_levels_and_sender_level(event, context)

+        relations = await self._get_mutual_relations(
+            event, itertools.chain(*rules_by_user.values())
+        )
+
         evaluator = PushRuleEvaluatorForEvent(
-            event, len(room_members), sender_power_level, power_levels
+            event,
+            len(room_members),
+            sender_power_level,
+            power_levels,
+            relations,
+            self._relations_match_enabled,
         )

         # If the event is not a state event check if any users ignore the sender.
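For context, a minimal sketch of the condition shape this evaluator consumes, as a client could upload it via the push rules API. The field names are taken from the diff above; the concrete rule itself is hypothetical:

    # A hypothetical MSC3772 push rule using the unstable condition kind.
    # `rel_type` is required; `sender` and `type` are optional glob patterns
    # matched against the sender and type of the mutually related events.
    example_rule = {
        "conditions": [
            {
                "kind": "org.matrix.msc3772.relation_match",
                "rel_type": "m.thread",
                "sender": "@alice:example.org",
                "type": "m.room.message",
            }
        ],
        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
    }

Server-defined base rules instead carry `sender_type: "user_id"`, which `format_push_rules_for_user` (in the next diff) expands to the requesting user's MXID.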
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 63b22d50ae..5117ef6854 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -48,6 +48,10 @@ def format_push_rules_for_user(
             elif pattern_type == "user_localpart":
                 c["pattern"] = user.localpart

+            sender_type = c.pop("sender_type", None)
+            if sender_type == "user_id":
+                c["sender"] = user.to_string()
+
         rulearray = rules["global"][template_name]

         template_rule = _rule_to_template(r)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 54db6b5612..2e8a017add 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,7 @@

 import logging
 import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union

 from matrix_common.regex import glob_to_regex, to_word_pattern

@@ -120,11 +120,15 @@ class PushRuleEvaluatorForEvent:
         room_member_count: int,
         sender_power_level: int,
         power_levels: Dict[str, Union[int, Dict[str, int]]],
+        relations: Dict[str, Set[Tuple[str, str]]],
+        relations_match_enabled: bool,
     ):
         self._event = event
         self._room_member_count = room_member_count
         self._sender_power_level = sender_power_level
         self._power_levels = power_levels
+        self._relations = relations
+        self._relations_match_enabled = relations_match_enabled

         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
@@ -188,7 +192,16 @@ class PushRuleEvaluatorForEvent:
             return _sender_notification_permission(
                 self._event, condition, self._sender_power_level, self._power_levels
             )
+        elif (
+            condition["kind"] == "org.matrix.msc3772.relation_match"
+            and self._relations_match_enabled
+        ):
+            return self._relation_match(condition, user_id)
         else:
+            # XXX This looks incorrect -- we have reached an unknown condition
+            # kind and are unconditionally returning that it matches. Note
+            # that it seems possible to provide a condition to the /pushrules
+            # endpoint with an unknown kind, see _rule_tuple_from_request_object.
             return True

     def _event_match(self, condition: dict, user_id: str) -> bool:
@@ -256,6 +269,41 @@ class PushRuleEvaluatorForEvent:

         return bool(r.search(body))

+    def _relation_match(self, condition: dict, user_id: str) -> bool:
+        """
+        Check a "relation_match" push rule condition.
+
+        Args:
+            condition: The "relation_match" push rule condition to match.
+            user_id: The user's MXID.
+
+        Returns:
+            True if the condition matches the event, False otherwise.
+        """
+        rel_type = condition.get("rel_type")
+        if not rel_type:
+            logger.warning("relation_match condition missing rel_type")
+            return False
+
+        sender_pattern = condition.get("sender")
+        if sender_pattern is None:
+            sender_type = condition.get("sender_type")
+            if sender_type == "user_id":
+                sender_pattern = user_id
+        type_pattern = condition.get("type")
+
+        # If any of the relations match, return True.
+        for sender, event_type in self._relations.get(rel_type, ()):
+            if sender_pattern and not _glob_matches(sender_pattern, sender):
+                continue
+            if type_pattern and not _glob_matches(type_pattern, event_type):
+                continue
+            # All values must have matched.
+            return True
+
+        # No relations matched.
+        return False
+

 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
 regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 0df8ff5395..17e35cf63e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1828,6 +1828,10 @@ class PersistEventsStore:
                 self.store.get_aggregation_groups_for_event.invalidate,
                 (relation.parent_id,),
             )
+            txn.call_after(
+                self.store.get_mutual_event_relations_for_rel_type.invalidate,
+                (relation.parent_id,),
+            )

         if relation.rel_type == RelationTypes.REPLACE:
             txn.call_after(
@@ -2004,6 +2008,11 @@ class PersistEventsStore:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_thread_participated, (redacted_relates_to,)
             )
+            self.store._invalidate_cache_and_stream(
+                txn,
+                self.store.get_mutual_event_relations_for_rel_type,
+                (redacted_relates_to,),
+            )

         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index ad67901cc1..4adabc88cc 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -61,6 +61,11 @@ def _is_experimental_rule_enabled(
         and not experimental_config.msc3786_enabled
     ):
         return False
+    if (
+        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+        and not experimental_config.msc3772_enabled
+    ):
+        return False
     return True
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index fe8fded88b..3b1b2ce6cb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging
+from collections import defaultdict
 from typing import (
     Collection,
     Dict,
@@ -767,6 +768,57 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )

+    @cached(iterable=True)
+    async def get_mutual_event_relations_for_rel_type(
+        self, event_id: str, relation_type: str
+    ) -> Set[Tuple[str, str]]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="get_mutual_event_relations_for_rel_type",
+        list_name="relation_types",
+    )
+    async def get_mutual_event_relations(
+        self, event_id: str, relation_types: Collection[str]
+    ) -> Dict[str, Set[Tuple[str, str]]]:
+        """
+        Fetch event metadata for events which relate to the same event as the given event.
+
+        If the given event has no relation information, returns an empty dictionary.
+
+        Args:
+            event_id: The event ID which is targeted by relations.
+            relation_types: The relation types to check for mutual relations.
+
+        Returns:
+            A dictionary of relation type to:
+                A set of tuples of:
+                    The sender
+                    The event type
+        """
+        rel_type_sql, rel_type_args = make_in_list_sql_clause(
+            self.database_engine, "relation_type", relation_types
+        )
+
+        sql = f"""
+            SELECT DISTINCT relation_type, sender, type FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE relates_to_id = ?
+                AND {rel_type_sql}
+        """

+        def _get_event_relations(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Set[Tuple[str, str]]]:
+            txn.execute(sql, [event_id] + rel_type_args)
+            result = defaultdict(set)
+            for rel_type, sender, type in txn.fetchall():
+                result[rel_type].add((sender, type))
+            return result
+
+        return await self.db_pool.runInteraction(
+            "get_event_relations", _get_event_relations
+        )
+

 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 5dba187076..9b623d0033 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Set, Tuple, Union

 import frozendict

@@ -26,7 +26,12 @@ from tests import unittest


 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
+    def _get_evaluator(
+        self,
+        content: JsonDict,
+        relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
+        relations_match_enabled: bool = False,
+    ) -> PushRuleEvaluatorForEvent:
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -42,7 +47,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         sender_power_level = 0
         power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
         return PushRuleEvaluatorForEvent(
-            event, room_member_count, sender_power_level, power_levels
+            event,
+            room_member_count,
+            sender_power_level,
+            power_levels,
+            relations or {},
+            relations_match_enabled,
         )

     def test_display_name(self) -> None:
@@ -276,3 +286,71 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             push_rule_evaluator.tweaks_for_actions(actions),
             {"sound": "default", "highlight": True},
         )
+
+    def test_relation_match(self) -> None:
+        """Test the relation_match push rule kind."""
+
+        # Check if the experimental feature is disabled.
+        evaluator = self._get_evaluator(
+            {}, {"m.annotation": {("@user:test", "m.reaction")}}
+        )
+        condition = {"kind": "relation_match"}
+        # Oddly, an unknown condition always matches.
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # A push rule evaluator with the experimental rule enabled.
+        evaluator = self._get_evaluator(
+            {}, {"m.annotation": {("@user:test", "m.reaction")}}, True
+        )
+
+        # Check just relation type.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check relation type and sender.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@user:test",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@other:test",
+        }
+        self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check relation type and event type.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "type": "m.reaction",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check just sender, this fails since rel_type is required.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "sender": "@user:test",
+        }
+        self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check sender glob.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@*:test",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check event type glob.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "type": "*.reaction",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))

From 6855024e0a363ff09d50586dcf1b089b77ac3b0c Mon Sep 17 00:00:00 2001
From: Will Hunt
Date: Tue, 24 May 2022 15:39:54 +0100
Subject: [PATCH 02/74] Add authentication to thirdparty bridge APIs (#12746)

Co-authored-by: Brendan Abolivier
---
 changelog.d/12746.bugfix     |   1 +
 synapse/appservice/api.py    |  15 ++++-
 tests/appservice/test_api.py | 102 +++++++++++++++++++++++++++++++++++
 3 files changed, 115 insertions(+), 3 deletions(-)
 create mode 100644 changelog.d/12746.bugfix
 create mode 100644 tests/appservice/test_api.py

diff --git a/changelog.d/12746.bugfix b/changelog.d/12746.bugfix
new file mode 100644
index 0000000000..67e7fc854c
--- /dev/null
+++ b/changelog.d/12746.bugfix
@@ -0,0 +1 @@
+Always send an `access_token` in `/thirdparty/` requests to appservices, as required by the [Matrix specification](https://spec.matrix.org/v1.1/application-service-api/#third-party-networks).
\ No newline at end of file
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index d19f8dd996..df1c214462 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple

 from prometheus_client import Counter
 from typing_extensions import TypeGuard

@@ -155,6 +155,9 @@ class ApplicationServiceApi(SimpleHttpClient):
         if service.url is None:
             return []

+        # This is required by the configuration.
+        assert service.hs_token is not None
+
         uri = "%s%s/thirdparty/%s/%s" % (
             service.url,
             APP_SERVICE_PREFIX,
             kind,
             urllib.parse.quote(protocol),
         )
         try:
-            response = await self.get_json(uri, fields)
+            args: Mapping[Any, Any] = {
+                **fields,
+                b"access_token": service.hs_token,
+            }
+            response = await self.get_json(uri, args=args)
             if not isinstance(response, list):
                 logger.warning(
                     "query_3pe to %s returned an invalid response %r", uri, response
@@ -190,13 +197,15 @@ class ApplicationServiceApi(SimpleHttpClient):
             return {}

         async def _get() -> Optional[JsonDict]:
+            # This is required by the configuration.
+            assert service.hs_token is not None
             uri = "%s%s/thirdparty/protocol/%s" % (
                 service.url,
                 APP_SERVICE_PREFIX,
                 urllib.parse.quote(protocol),
             )
             try:
-                info = await self.get_json(uri)
+                info = await self.get_json(uri, {"access_token": service.hs_token})

                 if not _is_valid_3pe_metadata(info):
                     logger.warning(
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
new file mode 100644
index 0000000000..3e0db4dd98
--- /dev/null
+++ b/tests/appservice/test_api.py
@@ -0,0 +1,102 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Mapping +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.appservice import ApplicationService +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest + +PROTOCOL = "myproto" +TOKEN = "myastoken" +URL = "http://mytestservice" + + +class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.api = hs.get_application_service_api() + self.service = ApplicationService( + id="unique_identifier", + sender="@as:test", + url=URL, + token="unused", + hs_token=TOKEN, + hostname="myserver", + ) + + def test_query_3pe_authenticates_token(self): + """ + Tests that 3pe queries to the appservice are authenticated + with the appservice's token. + """ + + SUCCESS_RESULT_USER = [ + { + "protocol": PROTOCOL, + "userid": "@a:user", + "fields": { + "more": "fields", + }, + } + ] + SUCCESS_RESULT_LOCATION = [ + { + "protocol": PROTOCOL, + "alias": "#a:room", + "fields": { + "more": "fields", + }, + } + ] + + URL_USER = f"{URL}/_matrix/app/unstable/thirdparty/user/{PROTOCOL}" + URL_LOCATION = f"{URL}/_matrix/app/unstable/thirdparty/location/{PROTOCOL}" + + self.request_url = None + + async def get_json(url: str, args: Mapping[Any, Any]) -> List[JsonDict]: + if not args.get(b"access_token"): + raise RuntimeError("Access token not provided") + + self.assertEqual(args.get(b"access_token"), TOKEN) + self.request_url = url + if url == URL_USER: + return SUCCESS_RESULT_USER + elif url == URL_LOCATION: + return SUCCESS_RESULT_LOCATION + else: + raise RuntimeError( + "URL provided was invalid. This should never be seen." + ) + + # We assign to a method, which mypy doesn't like. + self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] + + result = self.get_success( + self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) + ) + self.assertEqual(self.request_url, URL_USER) + self.assertEqual(result, SUCCESS_RESULT_USER) + result = self.get_success( + self.api.query_3pe( + self.service, "location", PROTOCOL, {b"some": [b"field"]} + ) + ) + self.assertEqual(self.request_url, URL_LOCATION) + self.assertEqual(result, SUCCESS_RESULT_LOCATION) From 042e47970b15260eeb7e3162e4406b4f2e94008c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Tue, 24 May 2022 18:42:32 +0200 Subject: [PATCH 03/74] Remove `dont_notify` from the `.m.rule.room.server_acl` rule (#12849) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- changelog.d/12849.misc | 1 + synapse/push/baserules.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12849.misc diff --git a/changelog.d/12849.misc b/changelog.d/12849.misc new file mode 100644 index 0000000000..4c2a15ce2b --- /dev/null +++ b/changelog.d/12849.misc @@ -0,0 +1 @@ +Remove `dont_notify` from the `.m.rule.room.server_acl` rule. 
\ No newline at end of file diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 4c7278b5a1..819bc9e9b6 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -292,7 +292,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ "_cache_key": "_room_server_acl", } ], - "actions": ["dont_notify"], + "actions": [], }, ] From 81d9f2a8e9ee2d18f4ed9cc6d39fd9c2e793bc62 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 24 May 2022 17:50:50 +0100 Subject: [PATCH 04/74] Fixes to MSC3787 implementation (#12858) --- changelog.d/12858.bugfix | 1 + scripts-dev/complement.sh | 2 +- synapse/handlers/room_summary.py | 3 ++- synapse/storage/databases/main/room.py | 35 +++++++++++++------------- 4 files changed, 21 insertions(+), 20 deletions(-) create mode 100644 changelog.d/12858.bugfix diff --git a/changelog.d/12858.bugfix b/changelog.d/12858.bugfix new file mode 100644 index 0000000000..7a7ddc9a13 --- /dev/null +++ b/changelog.d/12858.bugfix @@ -0,0 +1 @@ +Fix [MSC3878](https://github.com/matrix-org/matrix-spec-proposals/pull/3787) rooms being omitted from room directory, room summary and space hierarchy responses. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index ca476d9a5e..3c472c576e 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -45,7 +45,7 @@ docker build -t matrixdotorg/synapse -f "docker/Dockerfile" . extra_test_args=() -test_tags="synapse_blacklist,msc2716,msc3030" +test_tags="synapse_blacklist,msc2716,msc3030,msc3787" # If we're using workers, modify the docker files slightly. if [[ -n "$WORKERS" ]]; then diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index af83de3193..1dd74912fa 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -662,7 +662,8 @@ class RoomSummaryHandler: # The API doesn't return the room version so assume that a # join rule of knock is valid. 
if ( - room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) + room.get("join_rules") + in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED) or room.get("world_readable") is True ): return True diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index ded15b92ef..10f2ceb50b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -233,24 +233,23 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): UNION SELECT room_id from appservice_room_list """ - sql = """ + sql = f""" SELECT COUNT(*) FROM ( - %(published_sql)s + {published_sql} ) published INNER JOIN room_stats_state USING (room_id) INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + join_rules = '{JoinRules.PUBLIC}' + OR join_rules = '{JoinRules.KNOCK}' + OR join_rules = '{JoinRules.KNOCK_RESTRICTED}' OR history_visibility = 'world_readable' ) AND joined_members > 0 - """ % { - "published_sql": published_sql, - "knock_join_rule": JoinRules.KNOCK, - } + """ txn.execute(sql, query_args) return cast(Tuple[int], txn.fetchone())[0] @@ -369,29 +368,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): if where_clauses: where_clause = " AND " + " AND ".join(where_clauses) - sql = """ + dir = "DESC" if forwards else "ASC" + sql = f""" SELECT room_id, name, topic, canonical_alias, joined_members, avatar, history_visibility, guest_access, join_rules FROM ( - %(published_sql)s + {published_sql} ) published INNER JOIN room_stats_state USING (room_id) INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + join_rules = '{JoinRules.PUBLIC}' + OR join_rules = '{JoinRules.KNOCK}' + OR join_rules = '{JoinRules.KNOCK_RESTRICTED}' OR history_visibility = 'world_readable' ) AND joined_members > 0 - %(where_clause)s - ORDER BY joined_members %(dir)s, room_id %(dir)s - """ % { - "published_sql": published_sql, - "where_clause": where_clause, - "dir": "DESC" if forwards else "ASC", - "knock_join_rule": JoinRules.KNOCK, - } + {where_clause} + ORDER BY + joined_members {dir}, + room_id {dir} + """ if limit is not None: query_args.append(limit) From e7c77a8750094616419720379afa02506e716c7d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 24 May 2022 19:17:21 +0100 Subject: [PATCH 05/74] Correct annotation of `_iterate_over_text` (#12860) --- changelog.d/12860.misc | 1 + synapse/rest/media/v1/preview_html.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12860.misc diff --git a/changelog.d/12860.misc b/changelog.d/12860.misc new file mode 100644 index 0000000000..b7d2943023 --- /dev/null +++ b/changelog.d/12860.misc @@ -0,0 +1 @@ +Correct a type annotation in the URL preview source code. diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index ca73965fc2..e72c8987cc 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -281,7 +281,7 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]: def _iterate_over_text( - tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] + tree: "etree.Element", *tags_to_ignore: Union[str, "etree.Comment"] ) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, skipping text nodes inside certain tags. 
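The corrected signature relies on how Python types variadic parameters: the annotation on `*args` describes each individual positional argument, while the parameter itself becomes a tuple of them. A minimal illustration of that rule (hypothetical function, standard `typing` semantics, not Synapse code):

    from typing import Union

    def collect(*tags_to_ignore: Union[str, int]) -> int:
        # Inside the body, tags_to_ignore is a Tuple[Union[str, int], ...]:
        # each argument passed by the caller must be a str or an int.
        return len(tags_to_ignore)

    collect("script", 42)  # OK: each argument matches Union[str, int]

The previous `Iterable[...]` annotation wrongly claimed that each argument was itself an iterable.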
From 298911555c2572da823398f2816846f7353e89e9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 25 May 2022 11:14:03 +0200 Subject: [PATCH 06/74] Fix typos in documentation (#12863) --- changelog.d/12863.doc | 1 + docs/message_retention_policies.md | 2 +- docs/structured_logging.md | 2 +- docs/workers.md | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12863.doc diff --git a/changelog.d/12863.doc b/changelog.d/12863.doc new file mode 100644 index 0000000000..94f7b8371a --- /dev/null +++ b/changelog.d/12863.doc @@ -0,0 +1 @@ +Fix typos in documentation. \ No newline at end of file diff --git a/docs/message_retention_policies.md b/docs/message_retention_policies.md index 9214d6d7e9..b52c4aaa24 100644 --- a/docs/message_retention_policies.md +++ b/docs/message_retention_policies.md @@ -117,7 +117,7 @@ In this example, we define three jobs: Note that this example is tailored to show different configurations and features slightly more jobs than it's probably necessary (in practice, a server admin would probably consider it better to replace the two last -jobs with one that runs once a day and handles rooms which which +jobs with one that runs once a day and handles rooms which policy's `max_lifetime` is greater than 3 days). Keep in mind, when configuring these jobs, that a purge job can become diff --git a/docs/structured_logging.md b/docs/structured_logging.md index a6667e1a11..d43dc9eb6e 100644 --- a/docs/structured_logging.md +++ b/docs/structured_logging.md @@ -43,7 +43,7 @@ loggers: The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will log to a file, stored as JSON. -It is also possible to figure Synapse to log to a remote endpoint by using the +It is also possible to configure Synapse to log to a remote endpoint by using the `synapse.logging.RemoteHandler` class included with Synapse. It takes the following arguments: diff --git a/docs/workers.md b/docs/workers.md index 779069b817..5033722098 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -1,6 +1,6 @@ # Scaling synapse via workers -For small instances it recommended to run Synapse in the default monolith mode. +For small instances it is recommended to run Synapse in the default monolith mode. For larger instances where performance is a concern it can be helpful to split out functionality into multiple separate python processes. These processes are called 'workers', and are (eventually) intended to scale horizontally From 774ac4930dbb0e6f2f6dad4b9eb4630154e1e161 Mon Sep 17 00:00:00 2001 From: Carl Bordum Hansen Date: Wed, 25 May 2022 11:14:45 +0200 Subject: [PATCH 07/74] Make sure `prev_ids` defaults to empty list (#12829) Signed-off-by: Carl Bordum Hansen --- changelog.d/12829.bugfix | 1 + synapse/handlers/device.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/12829.bugfix diff --git a/changelog.d/12829.bugfix b/changelog.d/12829.bugfix new file mode 100644 index 0000000000..dfa1fed34e --- /dev/null +++ b/changelog.d/12829.bugfix @@ -0,0 +1 @@ +Fix a bug where we did not correctly handle invalid device list updates over federation. Contributed by Carl Bordum Hansen. 
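The validation added in the diff below rejects a malformed `prev_id` field with a 400 error instead of failing part-way through processing. A sketch of the kind of EDU content that previously tripped this code path (the payload is illustrative, not taken from a real incident):

    # A hypothetical m.device_list_update EDU body with an invalid prev_id:
    # the field must be a list of stream IDs, but a bare integer is supplied.
    bad_edu_content = {
        "user_id": "@user:remote.example",
        "device_id": "ADEVICE",
        "stream_id": 6,
        "prev_id": 5,  # invalid; should be a list such as [5]
    }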
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1d6d1f8a92..e59937fd75 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -763,6 +763,10 @@ class DeviceListUpdater: device_id = edu_content.pop("device_id") stream_id = str(edu_content.pop("stream_id")) # They may come as ints prev_ids = edu_content.pop("prev_id", []) + if not isinstance(prev_ids, list): + raise SynapseError( + 400, "Device list update had an invalid 'prev_ids' field" + ) prev_ids = [str(p) for p in prev_ids] # They may come as ints if get_domain_from_id(user_id) != origin: From b4fab0b14f7167c907286ea065d65b5370ba8221 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 25 May 2022 10:20:34 +0100 Subject: [PATCH 08/74] Fix incorrect worker-allowed path in documentation (#12867) --- changelog.d/12867.doc | 1 + docs/workers.md | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) create mode 100644 changelog.d/12867.doc diff --git a/changelog.d/12867.doc b/changelog.d/12867.doc new file mode 100644 index 0000000000..1caeb7a290 --- /dev/null +++ b/changelog.d/12867.doc @@ -0,0 +1 @@ +Fix documentation incorrectly stating the `sendToDevice` endpoint can be directed at generic workers. Contributed by Nick @ Beeper. diff --git a/docs/workers.md b/docs/workers.md index 5033722098..25b9338e57 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -237,9 +237,6 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/join/ ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/ - # Device requests - ^/_matrix/client/(r0|v3|unstable)/sendToDevice/ - # Account data requests ^/_matrix/client/(r0|v3|unstable)/.*/tags ^/_matrix/client/(r0|v3|unstable)/.*/account_data From 2e5f88b5e69fa4d7385b32d9c439e0073e8d6916 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 25 May 2022 10:41:41 +0100 Subject: [PATCH 09/74] Add the `/account/whoami` endpoint to generic workers (#12866) --- changelog.d/12866.misc | 1 + docs/workers.md | 1 + synapse/app/generic_worker.py | 3 ++- 3 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12866.misc diff --git a/changelog.d/12866.misc b/changelog.d/12866.misc new file mode 100644 index 0000000000..3f7ef59253 --- /dev/null +++ b/changelog.d/12866.misc @@ -0,0 +1 @@ +Enable the `/account/whoami` endpoint on synapse worker processes. Contributed by Nick @ Beeper. diff --git a/docs/workers.md b/docs/workers.md index 25b9338e57..3c3360ccb4 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -208,6 +208,7 @@ information. 
^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ + ^/_matrix/client/(r0|v3|unstable)/account/whoami$ ^/_matrix/client/(r0|v3|unstable)/devices$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 2a9480a5c1..39d9db8d98 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -87,7 +87,7 @@ from synapse.rest.client import ( voip, ) from synapse.rest.client._base import client_patterns -from synapse.rest.client.account import ThreepidRestServlet +from synapse.rest.client.account import ThreepidRestServlet, WhoamiRestServlet from synapse.rest.client.devices import DevicesRestServlet from synapse.rest.client.keys import ( KeyChangesServlet, @@ -289,6 +289,7 @@ class GenericWorkerServer(HomeServer): RegistrationTokenValidityRestServlet(self).register(resource) login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) + WhoamiRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) # Read-only From 33e2916858c0503a54be1c01e242123dcfb02e21 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 25 May 2022 10:46:05 +0100 Subject: [PATCH 10/74] Don't create empty AS txns when the AS is down (#12869) --- changelog.d/12869.misc | 1 + synapse/appservice/scheduler.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12869.misc diff --git a/changelog.d/12869.misc b/changelog.d/12869.misc new file mode 100644 index 0000000000..1d9d1c8921 --- /dev/null +++ b/changelog.d/12869.misc @@ -0,0 +1 @@ +Don't generate empty AS transactions when the AS is flagged as down. Contributed by Nick @ Beeper. diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 3b49e60716..de5e5216c2 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -384,6 +384,11 @@ class _TransactionController: device_list_summary: The device list summary to include in the transaction. """ try: + service_is_up = await self._is_service_up(service) + # Don't create empty txns when in recovery mode (ephemeral events are dropped) + if not service_is_up and not events: + return + txn = await self.store.create_appservice_txn( service=service, events=events, @@ -393,7 +398,6 @@ class _TransactionController: unused_fallback_keys=unused_fallback_keys or {}, device_list_summary=device_list_summary or DeviceListUpdates(), ) - service_is_up = await self._is_service_up(service) if service_is_up: sent = await txn.send(self.as_api) if sent: From 1f9013ce60ac7c2b75ea1bfacb9314239e4e0cff Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 25 May 2022 10:51:07 +0100 Subject: [PATCH 11/74] Add the `batch_send` endpoint to generic workers (#12868) --- changelog.d/12868.misc | 1 + docker/configure_workers_and_start.py | 1 + docs/workers.md | 1 + synapse/app/generic_worker.py | 2 ++ 4 files changed, 5 insertions(+) create mode 100644 changelog.d/12868.misc diff --git a/changelog.d/12868.misc b/changelog.d/12868.misc new file mode 100644 index 0000000000..382a876dab --- /dev/null +++ b/changelog.d/12868.misc @@ -0,0 +1 @@ +Enable the `batch_send` endpoint on synapse worker processes. Contributed by Nick @ Beeper. 
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index b6ad141173..f7dac90222 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -158,6 +158,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$", "^/_matrix/client/(api/v1|r0|v3|unstable)/join/", "^/_matrix/client/(api/v1|r0|v3|unstable)/profile/", + "^/_matrix/client/(v1|unstable/org.matrix.msc2716)/rooms/.*/batch_send", ], "shared_extra_conf": {}, "worker_extra_conf": "", diff --git a/docs/workers.md b/docs/workers.md index 3c3360ccb4..6a76f43fa1 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -206,6 +206,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ + ^/_matrix/client/(v1|unstable/org.matrix.msc2716)/rooms/.*/batch_send$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ ^/_matrix/client/(r0|v3|unstable)/account/whoami$ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 39d9db8d98..c0d007bb79 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -78,6 +78,7 @@ from synapse.rest.client import ( read_marker, receipts, room, + room_batch, room_keys, sendtodevice, sync, @@ -309,6 +310,7 @@ class GenericWorkerServer(HomeServer): room.register_servlets(self, resource, is_worker=True) room.register_deprecated_servlets(self, resource) initial_sync.register_servlets(self, resource) + room_batch.register_servlets(self, resource) room_keys.register_servlets(self, resource) tags.register_servlets(self, resource) account_data.register_servlets(self, resource) From 6aeee9a19deb68ed071ddde7150609826bfa4988 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 25 May 2022 11:19:22 +0100 Subject: [PATCH 12/74] Correct typo in changelog for #12858. --- changelog.d/12858.bugfix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/12858.bugfix b/changelog.d/12858.bugfix index 7a7ddc9a13..8c95a3e3a3 100644 --- a/changelog.d/12858.bugfix +++ b/changelog.d/12858.bugfix @@ -1 +1 @@ -Fix [MSC3878](https://github.com/matrix-org/matrix-spec-proposals/pull/3787) rooms being omitted from room directory, room summary and space hierarchy responses. +Fix [MSC3787](https://github.com/matrix-org/matrix-spec-proposals/pull/3787) rooms being omitted from room directory, room summary and space hierarchy responses. From 4cbcd4a99959a4aaa04c023812f02d9c27e4945f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 25 May 2022 07:49:12 -0400 Subject: [PATCH 13/74] Misc clean-up of push rules datastore (#12856) --- changelog.d/12856.misc | 1 + synapse/storage/databases/main/push_rule.py | 16 +++++----------- 2 files changed, 6 insertions(+), 11 deletions(-) create mode 100644 changelog.d/12856.misc diff --git a/changelog.d/12856.misc b/changelog.d/12856.misc new file mode 100644 index 0000000000..19ecefd9af --- /dev/null +++ b/changelog.d/12856.misc @@ -0,0 +1 @@ +Clean-up the push rules datastore. 
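For orientation, a minimal sketch of the `@cached`/`@cachedList` pairing that the diff below tidies up (hypothetical store class, not real Synapse code): the list variant reuses the single-item method's cache, so `num_args` can be omitted whenever it matches the decorated function's own signature.

    class ExampleStore:
        @cached()
        async def get_push_rules_for_user(self, user_id):
            ...

        # Batch variant: each user_id is looked up in the cache above and
        # only the misses are fetched; num_args need not be repeated.
        @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
        async def bulk_get_push_rules(self, user_ids):
            ...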
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 4adabc88cc..d5aefe02b6 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -174,7 +174,7 @@ class PushRulesWorkerStore( "conditions", "actions", ), - desc="get_push_rules_enabled_for_user", + desc="get_push_rules_for_user", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) @@ -188,10 +188,10 @@ class PushRulesWorkerStore( results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, - retcols=("user_name", "rule_id", "enabled"), + retcols=("rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) - return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} + return {r["rule_id"]: bool(r["enabled"]) for r in results} async def have_push_rules_changed_for_user( self, user_id: str, last_id: int @@ -213,11 +213,7 @@ class PushRulesWorkerStore( "have_push_rules_changed", have_push_rules_changed_txn ) - @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - ) + @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") async def bulk_get_push_rules( self, user_ids: Collection[str] ) -> Dict[str, List[JsonDict]]: @@ -249,9 +245,7 @@ class PushRulesWorkerStore( return results @cachedList( - cached_method_name="get_push_rules_enabled_for_user", - list_name="user_ids", - num_args=1, + cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids" ) async def bulk_get_push_rules_enabled( self, user_ids: Collection[str] From 759f9c09e1b2019b772f6baf6a40e74f79df9017 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 25 May 2022 07:49:54 -0400 Subject: [PATCH 14/74] Fix caching behavior for relations push rules. (#12859) By always returning all requested values from the function wrapped by cachedList. Otherwise implicit None values get added into the cache, which are unexpected. --- changelog.d/12859.feature | 1 + synapse/storage/databases/main/relations.py | 5 +++-- synapse/util/caches/descriptors.py | 15 ++++++++------- 3 files changed, 12 insertions(+), 9 deletions(-) create mode 100644 changelog.d/12859.feature diff --git a/changelog.d/12859.feature b/changelog.d/12859.feature new file mode 100644 index 0000000000..e674c31ae8 --- /dev/null +++ b/changelog.d/12859.feature @@ -0,0 +1 @@ +Experimental support for [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772): Push rule for mutually related events. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 3b1b2ce6cb..b457bc189e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,7 +13,6 @@ # limitations under the License. 
 import logging
-from collections import defaultdict
 from typing import (
     Collection,
     Dict,
@@ -810,7 +809,9 @@ class RelationsWorkerStore(SQLBaseStore):
             txn: LoggingTransaction,
         ) -> Dict[str, Set[Tuple[str, str]]]:
             txn.execute(sql, [event_id] + rel_type_args)
-            result = defaultdict(set)
+            result: Dict[str, Set[Tuple[str, str]]] = {
+                rel_type: set() for rel_type in relation_types
+            }
             for rel_type, sender, type in txn.fetchall():
                 result[rel_type].add((sender, type))
             return result
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index eda92d864d..867f315b2a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -595,13 +595,14 @@ def cached(
 def cachedList(
     *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
 ) -> Callable[[F], _CachedFunction[F]]:
-    """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
+    """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.

-    Used to do batch lookups for an already created cache. A single argument
+    Used to do batch lookups for an already created cache. One of the arguments
     is specified as a list that is iterated through to lookup keys in the
-    original cache. A new tuple consisting of the (deduplicated) keys that weren't in
-    the cache gets passed to the original function, the result of which is stored in the
-    cache.
+    original cache. A new tuple consisting of the (deduplicated) keys that weren't in
+    the cache gets passed to the original function, which is expected to result
+    in a map of key to value for each passed value. The new results are stored in the
+    original cache. Note that any missing values are cached as None.

     Args:
         cached_method_name: The name of the single-item lookup method.
@@ -614,11 +615,11 @@ def cachedList(

     Example:

         class Example:
-            @cached(num_args=2)
-            def do_something(self, first_arg):
+            @cached()
+            def do_something(self, first_arg, second_arg):
                 ...

-            @cachedList(do_something.cache, list_name="second_args", num_args=2)
+            @cachedList(cached_method_name="do_something", list_name="second_args")
             def batch_do_something(self, first_arg, second_args):
                 ...
     """

From a8db8c6eba8625f8fc224b320be6074d849ceada Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 25 May 2022 07:53:40 -0400
Subject: [PATCH 15/74] Remove user-visible groups/communities code (#12553)

Makes it so that groups/communities no longer exist from a user-POV. E.g. we remove:

* All API endpoints (including Client-Server, Server-Server, and admin).
* Documented configuration options (and the experimental flag, which is now unused).
* Special handling during room upgrades.
* The `groups` section of the `/sync` response.
--- changelog.d/12553.removal | 1 + docs/sample_config.yaml | 10 - .../configuration/config_documentation.md | 19 - synapse/api/constants.py | 5 - synapse/app/generic_worker.py | 4 - synapse/config/experimental.py | 3 - synapse/config/groups.py | 12 - .../federation/transport/server/__init__.py | 48 +- .../transport/server/groups_local.py | 115 --- .../transport/server/groups_server.py | 755 -------------- synapse/handlers/room_member.py | 11 - synapse/handlers/sync.py | 65 -- synapse/rest/__init__.py | 3 - synapse/rest/admin/__init__.py | 3 - synapse/rest/admin/groups.py | 50 - synapse/rest/client/groups.py | 962 ------------------ synapse/rest/client/sync.py | 8 - tests/rest/admin/test_admin.py | 90 +- tests/rest/client/test_groups.py | 56 - 19 files changed, 3 insertions(+), 2217 deletions(-) create mode 100644 changelog.d/12553.removal delete mode 100644 synapse/federation/transport/server/groups_local.py delete mode 100644 synapse/federation/transport/server/groups_server.py delete mode 100644 synapse/rest/admin/groups.py delete mode 100644 synapse/rest/client/groups.py delete mode 100644 tests/rest/client/test_groups.py diff --git a/changelog.d/12553.removal b/changelog.d/12553.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12553.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ee98d193cb..4388a00df1 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2521,16 +2521,6 @@ push: # "events_default": 1 -# Uncomment to allow non-server-admin users to create groups on this server -# -#enable_group_creation: true - -# If enabled, non server admins can only create groups with local parts -# starting with this prefix -# -#group_creation_prefix: "unofficial_" - - # User Directory configuration # diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 0f5bda32b9..8724bf27e8 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3145,25 +3145,6 @@ Example configuration: encryption_enabled_by_default_for_room_type: invite ``` --- -Config option: `enable_group_creation` - -Set to true to allow non-server-admin users to create groups on this server - -Example configuration: -```yaml -enable_group_creation: true -``` ---- -Config option: `group_creation_prefix` - -If enabled/present, non-server admins can only create groups with local parts -starting with this prefix. - -Example configuration: -```yaml -group_creation_prefix: "unofficial_" -``` ---- Config option: `user_directory` This setting defines options related to the user directory. 
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 330de21f6b..4a0552e7e5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -31,11 +31,6 @@ MAX_ALIAS_LENGTH = 255 # the maximum length for a user id is 255 characters MAX_USERID_LENGTH = 255 -# The maximum length for a group id is 255 characters -MAX_GROUPID_LENGTH = 255 -MAX_GROUP_CATEGORYID_LENGTH = 255 -MAX_GROUP_ROLEID_LENGTH = 255 - class Membership: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index c0d007bb79..0a6dd618f6 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -69,7 +69,6 @@ from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client import ( account_data, events, - groups, initial_sync, login, presence, @@ -323,9 +322,6 @@ class GenericWorkerServer(HomeServer): presence.register_servlets(self, resource) - if self.config.experimental.groups_enabled: - groups.register_servlets(self, resource) - resources.update({CLIENT_API_PREFIX: resource}) resources.update(build_synapse_client_resource_tree(self)) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index cc417e2fbf..f2dfd49b07 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -73,9 +73,6 @@ class ExperimentalConfig(Config): # MSC3720 (Account status endpoint) self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) - # The deprecated groups feature. - self.groups_enabled: bool = experimental.get("groups_enabled", False) - # MSC2654: Unread counts self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False) diff --git a/synapse/config/groups.py b/synapse/config/groups.py index c9b9c6daad..baa051fdd4 100644 --- a/synapse/config/groups.py +++ b/synapse/config/groups.py @@ -25,15 +25,3 @@ class GroupsConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_group_creation = config.get("enable_group_creation", False) self.group_creation_prefix = config.get("group_creation_prefix", "") - - def generate_config_section(self, **kwargs: Any) -> str: - return """\ - # Uncomment to allow non-server-admin users to create groups on this server - # - #enable_group_creation: true - - # If enabled, non server admins can only create groups with local parts - # starting with this prefix - # - #group_creation_prefix: "unofficial_" - """ diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 71b2f90eb9..50623cd385 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -27,10 +27,6 @@ from synapse.federation.transport.server.federation import ( FederationAccountStatusServlet, FederationTimestampLookupServlet, ) -from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES -from synapse.federation.transport.server.groups_server import ( - GROUP_SERVER_SERVLET_CLASSES, -) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -199,38 +195,6 @@ class PublicRoomList(BaseFederationServlet): return 200, data -class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): - """A group or user's server renews their attestation""" - - PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - 
super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_attestation_renewer() - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - # We don't need to check auth here as we check the attestation signatures - - new_content = await self.handler.on_renew_attestation( - group_id, user_id, content - ) - - return 200, new_content - - class OpenIdUserInfo(BaseFederationServlet): """ Exchange a bearer token for information about a user. @@ -292,16 +256,9 @@ class OpenIdUserInfo(BaseFederationServlet): SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = { "federation": FEDERATION_SERVLET_CLASSES, "room_list": (PublicRoomList,), - "group_server": GROUP_SERVER_SERVLET_CLASSES, - "group_local": GROUP_LOCAL_SERVLET_CLASSES, - "group_attestation": (FederationGroupsRenewAttestaionServlet,), "openid": (OpenIdUserInfo,), } -DEFAULT_SERVLET_GROUPS = ("federation", "room_list", "openid") - -GROUP_SERVLET_GROUPS = ("group_server", "group_local", "group_attestation") - def register_servlets( hs: "HomeServer", @@ -324,10 +281,7 @@ def register_servlets( Defaults to ``DEFAULT_SERVLET_GROUPS``. """ if not servlet_groups: - servlet_groups = DEFAULT_SERVLET_GROUPS - # Only allow the groups servlets if the deprecated groups feature is enabled. - if hs.config.experimental.groups_enabled: - servlet_groups = servlet_groups + GROUP_SERVLET_GROUPS + servlet_groups = SERVLET_GROUPS.keys() for servlet_group in servlet_groups: # Skip unknown servlet groups. diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py deleted file mode 100644 index 496472e1dc..0000000000 --- a/synapse/federation/transport/server/groups_local.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Tuple, Type - -from synapse.api.errors import SynapseError -from synapse.federation.transport.server._base import ( - Authenticator, - BaseFederationServlet, -) -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.types import JsonDict, get_domain_from_id -from synapse.util.ratelimitutils import FederationRateLimiter - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class BaseGroupsLocalServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups local handler. - - See BaseFederationServlet for more information. 
- """ - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_local_handler() - - -class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): - """A group server has invited a local user""" - - PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(group_id) != origin: - raise SynapseError(403, "group_id doesn't match origin") - - assert isinstance( - self.handler, GroupsLocalHandler - ), "Workers cannot handle group invites." - - new_content = await self.handler.on_invite(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): - """A group server has removed a local user""" - - PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, None]: - if get_domain_from_id(group_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - assert isinstance( - self.handler, GroupsLocalHandler - ), "Workers cannot handle group removals." - - await self.handler.user_removed_from_group(group_id, user_id, content) - - return 200, None - - -class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): - """Get roles in a group""" - - PATH = "/get_groups_publicised" - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - resp = await self.handler.bulk_get_publicised_groups( - content["user_ids"], proxy=False - ) - - return 200, resp - - -GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( - FederationGroupsLocalInviteServlet, - FederationGroupsRemoveLocalUserServlet, - FederationGroupsBulkPublicisedServlet, -) diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py deleted file mode 100644 index 851b50152e..0000000000 --- a/synapse/federation/transport/server/groups_server.py +++ /dev/null @@ -1,755 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import TYPE_CHECKING, Dict, List, Tuple, Type - -from typing_extensions import Literal - -from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH -from synapse.api.errors import Codes, SynapseError -from synapse.federation.transport.server._base import ( - Authenticator, - BaseFederationServlet, -) -from synapse.http.servlet import parse_string_from_args -from synapse.types import JsonDict, get_domain_from_id -from synapse.util.ratelimitutils import FederationRateLimiter - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class BaseGroupsServerServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups server handler. - - See BaseFederationServlet for more information. - """ - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_server_handler() - - -class FederationGroupsProfileServlet(BaseGroupsServerServlet): - """Get/set the basic profile of a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/profile" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_group_profile(group_id, requester_user_id) - - return 200, new_content - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.update_group_profile( - group_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryServlet(BaseGroupsServerServlet): - PATH = "/groups/(?P[^/]*)/summary" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_group_summary(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsRoomsServlet(BaseGroupsServerServlet): - """Get the rooms in a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/rooms" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): - """Add/remove room from group""" - - PATH = 
"/groups/(?P[^/]*)/room/(?P[^/]*)" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.add_room_to_group( - group_id, requester_user_id, room_id, content - ) - - return 200, new_content - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.remove_room_from_group( - group_id, requester_user_id, room_id - ) - - return 200, new_content - - -class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): - """Update room config in group""" - - PATH = ( - "/groups/(?P[^/]*)/room/(?P[^/]*)" - "/config/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - config_key: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - result = await self.handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content - ) - - return 200, result - - -class FederationGroupsUsersServlet(BaseGroupsServerServlet): - """Get the users in a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/users" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_users_in_group(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): - """Get the users that have been invited to a group""" - - PATH = "/groups/(?P[^/]*)/invited_users" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_invited_users_in_group( - group_id, requester_user_id - ) - - return 200, new_content - - -class FederationGroupsInviteServlet(BaseGroupsServerServlet): - """Ask a group server to invite someone to the group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if 
get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.invite_to_group( - group_id, user_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): - """Accept an invitation from the group server""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(user_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - new_content = await self.handler.accept_invite(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsJoinServlet(BaseGroupsServerServlet): - """Attempt to join a group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(user_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - new_content = await self.handler.join_group(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): - """Leave or kick a user from the group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.remove_user_from_group( - group_id, user_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): - """Add/remove a room from the group summary, with optional category. - - Matches both: - - /groups/:group/summary/rooms/:room_id - - /groups/:group/summary/categories/:category/rooms/:room_id - """ - - PATH = ( - "/groups/(?P[^/]*)/summary" - "(/categories/(?P[^/]+))?" 
- "/rooms/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError( - 400, "category_id cannot be empty string", Codes.INVALID_PARAM - ) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - resp = await self.handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): - """Get all categories for a group""" - - PATH = "/groups/(?P[^/]*)/categories/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_categories(group_id, requester_user_id) - - return 200, resp - - -class FederationGroupsCategoryServlet(BaseGroupsServerServlet): - """Add/remove/get a category in a group""" - - PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_category( - group_id, requester_user_id, category_id - ) - - return 200, resp - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = 
await self.handler.upsert_group_category( - group_id, requester_user_id, category_id, content - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - resp = await self.handler.delete_group_category( - group_id, requester_user_id, category_id - ) - - return 200, resp - - -class FederationGroupsRolesServlet(BaseGroupsServerServlet): - """Get roles in a group""" - - PATH = "/groups/(?P[^/]*)/roles/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_roles(group_id, requester_user_id) - - return 200, resp - - -class FederationGroupsRoleServlet(BaseGroupsServerServlet): - """Add/remove/get a role in a group""" - - PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_role(group_id, requester_user_id, role_id) - - return 200, resp - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError( - 400, "role_id cannot be empty string", Codes.INVALID_PARAM - ) - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_role( - group_id, requester_user_id, role_id, content - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - resp = await self.handler.delete_group_role( - group_id, requester_user_id, role_id - ) - - return 200, resp - - -class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): - """Add/remove a user from the group summary, with optional role. 
- - Matches both: - - /groups/:group/summary/users/:user_id - - /groups/:group/summary/roles/:role/users/:user_id - """ - - PATH = ( - "/groups/(?P[^/]*)/summary" - "(/roles/(?P[^/]+))?" - "/users/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_summary_user( - group_id, - requester_user_id, - user_id=user_id, - role_id=role_id, - content=content, - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - resp = await self.handler.delete_group_summary_user( - group_id, requester_user_id, user_id=user_id, role_id=role_id - ) - - return 200, resp - - -class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): - """Sets whether a group is joinable without an invite or knock""" - - PATH = "/groups/(?P[^/]*)/settings/m.join_policy" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.set_group_join_policy( - group_id, requester_user_id, content - ) - - return 200, new_content - - -GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] 
= ( - FederationGroupsProfileServlet, - FederationGroupsSummaryServlet, - FederationGroupsRoomsServlet, - FederationGroupsUsersServlet, - FederationGroupsInvitedUsersServlet, - FederationGroupsInviteServlet, - FederationGroupsAcceptInviteServlet, - FederationGroupsJoinServlet, - FederationGroupsRemoveUserServlet, - FederationGroupsSummaryRoomsServlet, - FederationGroupsCategoriesServlet, - FederationGroupsCategoryServlet, - FederationGroupsRolesServlet, - FederationGroupsRoleServlet, - FederationGroupsSummaryUsersServlet, - FederationGroupsAddRoomsServlet, - FederationGroupsAddRoomsConfigServlet, - FederationGroupsSettingJoinPolicyServlet, -) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ea876c168d..00662dc961 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1081,17 +1081,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Transfer alias mappings in the room directory await self.store.update_aliases_for_room(old_room_id, room_id) - # Check if any groups we own contain the predecessor room - local_group_ids = await self.store.get_local_groups_for_room(old_room_id) - for group_id in local_group_ids: - # Add new the new room to those groups - await self.store.add_room_to_group( - group_id, room_id, old_room is not None and old_room["is_public"] - ) - - # Remove the old room from those groups - await self.store.remove_room_from_group(group_id, old_room_id) - async def copy_user_state_on_room_upgrade( self, old_room_id: str, new_room_id: str, user_ids: Iterable[str] ) -> None: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 59b5d497be..dcbb5ce921 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -166,16 +166,6 @@ class KnockedSyncResult: return True -@attr.s(slots=True, frozen=True, auto_attribs=True) -class GroupsSyncResult: - join: JsonDict - invite: JsonDict - leave: JsonDict - - def __bool__(self) -> bool: - return bool(self.join or self.invite or self.leave) - - @attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined @@ -206,7 +196,6 @@ class SyncResult: for this device device_unused_fallback_key_types: List of key types that have an unused fallback key - groups: Group updates, if any """ next_batch: StreamToken @@ -220,7 +209,6 @@ class SyncResult: device_lists: DeviceListUpdates device_one_time_keys_count: JsonDict device_unused_fallback_key_types: List[str] - groups: Optional[GroupsSyncResult] def __bool__(self) -> bool: """Make the result appear empty if there are no updates. 
This is used @@ -236,7 +224,6 @@ class SyncResult: or self.account_data or self.to_device or self.device_lists - or self.groups ) @@ -1157,10 +1144,6 @@ class SyncHandler: await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) - if self.hs_config.experimental.groups_enabled: - logger.debug("Fetching group data") - await self._generate_sync_entry_for_groups(sync_result_builder) - num_events = 0 # debug for https://github.com/matrix-org/synapse/issues/9424 @@ -1184,57 +1167,11 @@ class SyncHandler: archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, - groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) - @measure_func("_generate_sync_entry_for_groups") - async def _generate_sync_entry_for_groups( - self, sync_result_builder: "SyncResultBuilder" - ) -> None: - user_id = sync_result_builder.sync_config.user.to_string() - since_token = sync_result_builder.since_token - now_token = sync_result_builder.now_token - - if since_token and since_token.groups_key: - results = await self.store.get_groups_changes_for_user( - user_id, since_token.groups_key, now_token.groups_key - ) - else: - results = await self.store.get_all_groups_for_user( - user_id, now_token.groups_key - ) - - invited = {} - joined = {} - left = {} - for result in results: - membership = result["membership"] - group_id = result["group_id"] - gtype = result["type"] - content = result["content"] - - if membership == "join": - if gtype == "membership": - # TODO: Add profile - content.pop("membership", None) - joined[group_id] = content["content"] - else: - joined.setdefault(group_id, {})[gtype] = content - elif membership == "invite": - if gtype == "membership": - content.pop("membership", None) - invited[group_id] = content["content"] - else: - if gtype == "membership": - left[group_id] = content["content"] - - sync_result_builder.groups = GroupsSyncResult( - join=joined, invite=invited, leave=left - ) - @measure_func("_generate_sync_entry_for_device_list") async def _generate_sync_entry_for_device_list( self, @@ -2333,7 +2270,6 @@ class SyncResultBuilder: invited knocked archived - groups to_device """ @@ -2349,7 +2285,6 @@ class SyncResultBuilder: invited: List[InvitedSyncResult] = attr.Factory(list) knocked: List[KnockedSyncResult] = attr.Factory(list) archived: List[ArchivedSyncResult] = attr.Factory(list) - groups: Optional[GroupsSyncResult] = None to_device: List[JsonDict] = attr.Factory(list) def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]: diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 57c4773edc..b712215112 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -26,7 +26,6 @@ from synapse.rest.client import ( directory, events, filter, - groups, initial_sync, keys, knock, @@ -118,8 +117,6 @@ class ClientRestResource(JsonResource): thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) user_directory.register_servlets(hs, client_resource) - if hs.config.experimental.groups_enabled: - groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) room_batch.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index cb4d55c89d..1aa08f8d95 100644 --- 
a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -47,7 +47,6 @@ from synapse.rest.admin.federation import ( DestinationRestServlet, ListDestinationsRestServlet, ) -from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.registration_tokens import ( ListRegistrationTokensRestServlet, @@ -293,8 +292,6 @@ def register_servlets_for_client_rest_resource( ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) - if hs.config.experimental.groups_enabled: - DeleteGroupAdminRestServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) # Load the media repo ones if we're using them. Otherwise load the servlets which diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py deleted file mode 100644 index cd697e180e..0000000000 --- a/synapse/rest/admin/groups.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple - -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet -from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_user_is_admin -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class DeleteGroupAdminRestServlet(RestServlet): - """Allows deleting of local groups""" - - PATTERNS = admin_patterns("/delete_group/(?P[^/]*)$") - - def __init__(self, hs: "HomeServer"): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - async def on_POST( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) - - if not self.is_mine_id(group_id): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups") - - await self.group_server.delete_group(group_id, requester.user.to_string()) - return HTTPStatus.OK, {} diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py deleted file mode 100644 index 7e1149c7f4..0000000000 --- a/synapse/rest/client/groups.py +++ /dev/null @@ -1,962 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from functools import wraps -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple - -from twisted.web.server import Request - -from synapse.api.constants import ( - MAX_GROUP_CATEGORYID_LENGTH, - MAX_GROUP_ROLEID_LENGTH, - MAX_GROUPID_LENGTH, -) -from synapse.api.errors import Codes, SynapseError -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.http.server import HttpServer -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) -from synapse.http.site import SynapseRequest -from synapse.types import GroupID, JsonDict - -from ._base import client_patterns - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -def _validate_group_id( - f: Callable[..., Awaitable[Tuple[int, JsonDict]]] -) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]: - """Wrapper to validate the form of the group ID. - - Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. - """ - - @wraps(f) - def wrapper( - self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any - ) -> Awaitable[Tuple[int, JsonDict]]: - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) - - return f(self, request, group_id, *args, **kwargs) - - return wrapper - - -class GroupServlet(RestServlet): - """Get the group profile""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - group_description = await self.groups_handler.get_group_profile( - group_id, requester_user_id - ) - - return 200, group_description - - @_validate_group_id - async def on_POST( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert_params_in_dict( - content, ("name", "avatar_url", "short_description", "long_description") - ) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot create group profiles." 
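-        # All four profile fields were checked for above; hand them to the
-        # handler to persist on behalf of the requester.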
- await self.groups_handler.update_group_profile( - group_id, requester_user_id, content - ) - - return 200, {} - - -class GroupSummaryServlet(RestServlet): - """Get the full group summary""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - get_group_summary = await self.groups_handler.get_group_summary( - group_id, requester_user_id - ) - - return 200, get_group_summary - - -class GroupSummaryRoomsCatServlet(RestServlet): - """Update/delete a rooms entry in the summary. - - Matches both: - - /groups/:group/summary/rooms/:room_id - - /groups/:group/summary/categories/:category/rooms/:room_id - """ - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/summary" - "(/categories/(?P[^/]+))?" - "/rooms/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, - request: SynapseRequest, - group_id: str, - category_id: Optional[str], - room_id: str, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) - - if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, category_id: str, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group profiles." 
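-        # Remove the room's entry from the summary; category_id comes from an
-        # optional path group, so it is absent for the short /summary/rooms form.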
- resp = await self.groups_handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class GroupCategoryServlet(RestServlet): - """Get/add/update/delete a group category""" - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/categories/(?P[^/]+)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_category( - group_id, requester_user_id, category_id=category_id - ) - - return 200, category - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if not category_id: - raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - resp = await self.groups_handler.update_group_category( - group_id, requester_user_id, category_id=category_id, content=content - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." 
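-        # Remove the category definition from the group.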
- resp = await self.groups_handler.delete_group_category( - group_id, requester_user_id, category_id=category_id - ) - - return 200, resp - - -class GroupCategoriesServlet(RestServlet): - """Get all group categories""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_categories( - group_id, requester_user_id - ) - - return 200, category - - -class GroupRoleServlet(RestServlet): - """Get/add/update/delete a group role""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_role( - group_id, requester_user_id, role_id=role_id - ) - - return 200, category - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if not role_id: - raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group roles." - resp = await self.groups_handler.update_group_role( - group_id, requester_user_id, role_id=role_id, content=content - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group roles." 
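-        # Remove the role definition from the group.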
- resp = await self.groups_handler.delete_group_role( - group_id, requester_user_id, role_id=role_id - ) - - return 200, resp - - -class GroupRolesServlet(RestServlet): - """Get all group roles""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_roles( - group_id, requester_user_id - ) - - return 200, category - - -class GroupSummaryUsersRoleServlet(RestServlet): - """Update/delete a user's entry in the summary. - - Matches both: - - /groups/:group/summary/users/:room_id - - /groups/:group/summary/roles/:role/users/:user_id - """ - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/summary" - "(/roles/(?P[^/]+))?" - "/users/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, - request: SynapseRequest, - group_id: str, - role_id: Optional[str], - user_id: str, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) - - if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.update_group_summary_user( - group_id, - requester_user_id, - user_id=user_id, - role_id=role_id, - content=content, - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, role_id: str, user_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." 
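-        # Remove the user's summary entry (scoped to role_id, if one was given).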
- resp = await self.groups_handler.delete_group_summary_user( - group_id, requester_user_id, user_id=user_id, role_id=role_id - ) - - return 200, resp - - -class GroupRoomServlet(RestServlet): - """Get all rooms in a group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_rooms_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupUsersServlet(RestServlet): - """Get all users in a group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_users_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupInvitedUsersServlet(RestServlet): - """Get users invited to a group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_invited_users_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupSettingJoinPolicyServlet(RestServlet): - """Set group join policy""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group join policy." 
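-        # Persist the new join policy for the group.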
-        result = await self.groups_handler.set_group_join_policy(
-            group_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupCreateServlet(RestServlet):
-    """Create a group"""
-
-    PATTERNS = client_patterns("/create_group$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-        self.server_name = hs.hostname
-
-    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        # TODO: Create group on remote server
-        content = parse_json_object_from_request(request)
-        localpart = content.pop("localpart")
-        group_id = GroupID(localpart, self.server_name).to_string()
-
-        if not localpart:
-            raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
-
-        if len(group_id) > MAX_GROUPID_LENGTH:
-            raise SynapseError(
-                400,
-                "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot create groups."
-        result = await self.groups_handler.create_group(
-            group_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupAdminRoomsServlet(RestServlet):
-    """Add a room to the group"""
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify rooms in a group."
-        result = await self.groups_handler.add_room_to_group(
-            group_id, requester_user_id, room_id, content
-        )
-
-        return 200, result
-
-    @_validate_group_id
-    async def on_DELETE(
-        self, request: SynapseRequest, group_id: str, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify rooms in a group."
-        result = await self.groups_handler.remove_room_from_group(
-            group_id, requester_user_id, room_id
-        )
-
-        return 200, result
-
-
-class GroupAdminRoomsConfigServlet(RestServlet):
-    """Update the config of a room in a group"""
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
-        "/config/(?P<config_key>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify rooms in a group."
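-        # Write the value for this room under the given config_key.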
-        result = await self.groups_handler.update_room_in_group(
-            group_id, requester_user_id, room_id, config_key, content
-        )
-
-        return 200, result
-
-
-class GroupAdminUsersInviteServlet(RestServlet):
-    """Invite a user to the group"""
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-        self.store = hs.get_datastores().main
-        self.is_mine_id = hs.is_mine_id
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str, user_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        config = content.get("config", {})
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot invite users to a group."
-        result = await self.groups_handler.invite(
-            group_id, user_id, requester_user_id, config
-        )
-
-        return 200, result
-
-
-class GroupAdminUsersKickServlet(RestServlet):
-    """Kick a user from the group"""
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str, user_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot kick users from a group."
-        result = await self.groups_handler.remove_user_from_group(
-            group_id, user_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupSelfLeaveServlet(RestServlet):
-    """Leave a joined group"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot leave a group for a user."
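-        # Leaving is implemented as removing yourself from the group, so the
-        # requester and the target user are the same here.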
- result = await self.groups_handler.remove_user_from_group( - group_id, requester_user_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfJoinServlet(RestServlet): - """Attempt to join a group, or knock""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot join a user to a group." - result = await self.groups_handler.join_group( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfAcceptInviteServlet(RestServlet): - """Accept a group invite""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot accept an invite to a group." - result = await self.groups_handler.accept_invite( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfUpdatePublicityServlet(RestServlet): - """Update whether we publicise a users membership of a group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - publicise = content["publicise"] - await self.store.update_group_publicity(group_id, requester_user_id, publicise) - - return 200, {} - - -class PublicisedGroupsForUserServlet(RestServlet): - """Get the list of groups a user is advertising""" - - PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - self.groups_handler = hs.get_groups_local_handler() - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) - - result = await self.groups_handler.get_publicised_groups_for_user(user_id) - - return 200, result - - -class PublicisedGroupsForUsersServlet(RestServlet): - """Get the list of groups a user is advertising""" - - PATTERNS = client_patterns("/publicised_groups$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = 
hs.get_datastores().main - self.groups_handler = hs.get_groups_local_handler() - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) - - content = parse_json_object_from_request(request) - user_ids = content["user_ids"] - - result = await self.groups_handler.bulk_get_publicised_groups(user_ids) - - return 200, result - - -class GroupsForUserServlet(RestServlet): - """Get all groups the logged in user is joined to""" - - PATTERNS = client_patterns("/joined_groups$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_joined_groups(requester_user_id) - - return 200, result - - -def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - GroupServlet(hs).register(http_server) - GroupSummaryServlet(hs).register(http_server) - GroupInvitedUsersServlet(hs).register(http_server) - GroupUsersServlet(hs).register(http_server) - GroupRoomServlet(hs).register(http_server) - GroupSettingJoinPolicyServlet(hs).register(http_server) - GroupCreateServlet(hs).register(http_server) - GroupAdminRoomsServlet(hs).register(http_server) - GroupAdminRoomsConfigServlet(hs).register(http_server) - GroupAdminUsersInviteServlet(hs).register(http_server) - GroupAdminUsersKickServlet(hs).register(http_server) - GroupSelfLeaveServlet(hs).register(http_server) - GroupSelfJoinServlet(hs).register(http_server) - GroupSelfAcceptInviteServlet(hs).register(http_server) - GroupsForUserServlet(hs).register(http_server) - GroupCategoryServlet(hs).register(http_server) - GroupCategoriesServlet(hs).register(http_server) - GroupSummaryRoomsCatServlet(hs).register(http_server) - GroupRoleServlet(hs).register(http_server) - GroupRolesServlet(hs).register(http_server) - GroupSelfUpdatePublicityServlet(hs).register(http_server) - GroupSummaryUsersRoleServlet(hs).register(http_server) - PublicisedGroupsForUserServlet(hs).register(http_server) - PublicisedGroupsForUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index e8772f86e7..f596b792fa 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -298,14 +298,6 @@ class SyncRestServlet(RestServlet): if archived: response["rooms"][Membership.LEAVE] = archived - if sync_result.groups is not None: - if sync_result.groups.join: - response["groups"][Membership.JOIN] = sync_result.groups.join - if sync_result.groups.invite: - response["groups"][Membership.INVITE] = sync_result.groups.invite - if sync_result.groups.leave: - response["groups"][Membership.LEAVE] = sync_result.groups.leave - return response @staticmethod diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 40571b753a..82ac5991e6 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -14,7 +14,6 @@ import urllib.parse from http import HTTPStatus -from typing import List from parameterized import parameterized @@ -23,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.http.server import JsonResource from synapse.rest.admin import VersionServlet -from synapse.rest.client import groups, 
login, room +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.util import Clock @@ -49,93 +48,6 @@ class VersionTestCase(unittest.HomeserverTestCase): ) -class DeleteGroupTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - groups.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - @unittest.override_config({"experimental_features": {"groups_enabled": True}}) - def test_delete_group(self) -> None: - # Create a new group - channel = self.make_request( - "POST", - b"/create_group", - access_token=self.admin_user_tok, - content={"localpart": "test"}, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - group_id = channel.json_body["group_id"] - - self._check_group(group_id, expect_code=HTTPStatus.OK) - - # Invite/join another user - - url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) - channel = self.make_request( - "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - url = "/groups/%s/self/accept_invite" % (group_id,) - channel = self.make_request( - "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} - ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - # Check other user knows they're in the group - self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) - self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - - # Now delete the group - url = "/_synapse/admin/v1/delete_group/" + group_id - channel = self.make_request( - "POST", - url.encode("ascii"), - access_token=self.admin_user_tok, - content={"localpart": "test"}, - ) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - # Check group returns HTTPStatus.NOT_FOUND - self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND) - - # Check users don't think they're in the group - self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) - self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - - def _check_group(self, group_id: str, expect_code: int) -> None: - """Assert that trying to fetch the given group results in the given - HTTP status code - """ - - url = "/groups/%s/profile" % (group_id,) - channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - - self.assertEqual(expect_code, channel.code, msg=channel.json_body) - - def _get_groups_user_is_in(self, access_token: str) -> List[str]: - """Returns the list of groups the user is in (given their access token)""" - channel = self.make_request("GET", b"/joined_groups", access_token=access_token) - - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - - return channel.json_body["groups"] - - class QuarantineMediaTestCase(unittest.HomeserverTestCase): """Test /quarantine_media admin API.""" diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py deleted file mode 100644 index e067cf825c..0000000000 --- a/tests/rest/client/test_groups.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 
2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest.client import groups, room - -from tests import unittest -from tests.unittest import override_config - - -class GroupsTestCase(unittest.HomeserverTestCase): - user_id = "@alice:test" - room_creator_user_id = "@bob:test" - - servlets = [room.register_servlets, groups.register_servlets] - - @override_config({"enable_group_creation": True}) - def test_rooms_limited_by_visibility(self) -> None: - group_id = "+spqr:test" - - # Alice creates a group - channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) - self.assertEqual(channel.code, 200, msg=channel.text_body) - self.assertEqual(channel.json_body, {"group_id": group_id}) - - # Bob creates a private room - room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) - self.helper.auth_user_id = self.room_creator_user_id - self.helper.send_state( - room_id, "m.room.name", {"name": "bob's secret room"}, tok=None - ) - self.helper.auth_user_id = self.user_id - - # Alice adds the room to her group. - channel = self.make_request( - "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} - ) - self.assertEqual(channel.code, 200, msg=channel.text_body) - self.assertEqual(channel.json_body, {}) - - # Alice now tries to retrieve the room list of the space. - channel = self.make_request("GET", f"/groups/{group_id}/rooms") - self.assertEqual(channel.code, 200, msg=channel.text_body) - self.assertEqual( - channel.json_body, {"chunk": [], "total_room_count_estimate": 0} - ) From 4660d9fdcffc833ae4774ac7d162e63769373dc5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 May 2022 12:59:04 +0100 Subject: [PATCH 16/74] Fix up `state_store` naming (#12871) --- changelog.d/12871.misc | 1 + synapse/handlers/admin.py | 4 ++-- synapse/handlers/device.py | 6 ++++-- synapse/handlers/federation.py | 6 ++++-- synapse/handlers/federation_event.py | 10 +++++----- synapse/handlers/initial_sync.py | 6 +++--- synapse/handlers/message.py | 10 +++++----- synapse/handlers/pagination.py | 4 ++-- synapse/handlers/room.py | 4 ++-- synapse/handlers/room_batch.py | 4 ++-- synapse/handlers/search.py | 4 ++-- synapse/handlers/sync.py | 24 ++++++++++++++---------- synapse/push/mailer.py | 6 +++--- synapse/state/__init__.py | 14 +++++++------- tests/handlers/test_federation.py | 4 ++-- 15 files changed, 58 insertions(+), 49 deletions(-) create mode 100644 changelog.d/12871.misc diff --git a/changelog.d/12871.misc b/changelog.d/12871.misc new file mode 100644 index 0000000000..94bd6c4974 --- /dev/null +++ b/changelog.d/12871.misc @@ -0,0 +1 @@ +Fix up the variable `state_store` naming. 
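For readers skimming the rename below: a minimal sketch of the convention this patch enforces. The handler class here is a toy; only the `hs` accessors and the method name are taken from the diff itself. `hs.get_storage().state` returns the higher-level state *storage* wrapper, not a raw database store, so `state_storage` is the accurate attribute name.

```python
# Toy handler illustrating the naming convention; not Synapse's real code.
class ExampleHandler:
    def __init__(self, hs):
        self.store = hs.get_datastores().main        # low-level database store
        self.state_storage = hs.get_storage().state  # state storage wrapper

    async def state_at(self, event_id: str):
        # The old attribute name, `state_store`, wrongly suggested this was a
        # database store; the storage wrapper is what actually serves reads.
        return await self.state_storage.get_state_for_event(event_id)
```
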
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 96376963f2..50e34743b7 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -31,7 +31,7 @@ class AdminHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.storage = hs.get_storage() - self.state_store = self.storage.state + self.state_storage = self.storage.state async def get_whois(self, user: UserID) -> JsonDict: connections = [] @@ -233,7 +233,7 @@ class AdminHandler: for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = await self.state_store.get_state_for_event(event_id) + state = await self.state_storage.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index e59937fd75..b21e469865 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -70,7 +70,7 @@ class DeviceWorkerHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() - self.state_store = hs.get_storage().state + self.state_storage = hs.get_storage().state self._auth_handler = hs.get_auth_handler() self.server_name = hs.hostname @@ -203,7 +203,9 @@ class DeviceWorkerHandler: continue # mapping from event_id -> state_dict - prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids) + prev_state_ids = await self.state_storage.get_state_ids_for_events( + event_ids + ) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0386d0a07b..c8233270d7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,7 +126,7 @@ class FederationHandler: self.store = hs.get_datastores().main self.storage = hs.get_storage() - self.state_store = self.storage.state + self.state_storage = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -1027,7 +1027,9 @@ class FederationHandler: if event.internal_metadata.outlier: raise NotFoundError("State not known at event %s" % (event_id,)) - state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) + state_groups = await self.state_storage.get_state_groups_ids( + room_id, [event_id] + ) # get_state_groups_ids should return exactly one result assert len(state_groups) == 1 diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index ca82df8a6d..8ce7187bef 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -99,7 +99,7 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self._store = hs.get_datastores().main self._storage = hs.get_storage() - self._state_store = self._storage.state + self._state_storage = self._storage.state self._state_handler = hs.get_state_handler() self._event_creation_handler = hs.get_event_creation_handler() @@ -533,7 +533,7 @@ class FederationEventHandler: ) return await self._store.update_state_for_partial_state_event(event, context) - self._state_store.notify_event_un_partial_stated(event.event_id) + self._state_storage.notify_event_un_partial_stated(event.event_id) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] @@ -832,7 +832,7 @@ class FederationEventHandler: event_map = {event_id: event} 
try: # Get the state of the events we know about - ours = await self._state_store.get_state_groups_ids(room_id, seen) + ours = await self._state_storage.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps: List[StateMap[str]] = list(ours.values()) @@ -1626,7 +1626,7 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_store.get_state_groups( + state_sets_d = await self._state_storage.get_state_groups( event.room_id, extrem_ids ) state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) @@ -1895,7 +1895,7 @@ class FederationEventHandler: # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = await self._state_store.store_state_group( + state_group = await self._state_storage.store_state_group( event.event_id, event.room_id, prev_group=prev_group, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d79248ad90..c06932a41a 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -68,7 +68,7 @@ class InitialSyncHandler: ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() - self.state_store = self.storage.state + self.state_storage = self.storage.state async def snapshot_all_rooms( self, @@ -198,7 +198,7 @@ class InitialSyncHandler: event.stream_ordering, ) deferred_room_state = run_in_background( - self.state_store.get_state_for_events, [event.event_id] + self.state_storage.get_state_for_events, [event.event_id] ).addCallback( lambda states: cast(StateMap[EventBase], states[event.event_id]) ) @@ -355,7 +355,7 @@ class InitialSyncHandler: member_event_id: str, is_peeking: bool, ) -> JsonDict: - room_state = await self.state_store.get_state_for_event(member_event_id) + room_state = await self.state_storage.get_state_for_event(member_event_id) limit = pagin_config.limit if pagin_config else None if limit is None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index cb1bc4c06f..9501e7f1b7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -78,7 +78,7 @@ class MessageHandler: self.state = hs.get_state_handler() self.store = hs.get_datastores().main self.storage = hs.get_storage() - self.state_store = self.storage.state + self.state_storage = self.storage.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages @@ -125,7 +125,7 @@ class MessageHandler: assert ( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" - room_state = await self.state_store.get_state_for_events( + room_state = await self.state_storage.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -186,7 +186,7 @@ class MessageHandler: # check whether the user is in the room at that time to determine # whether they should be treated as peeking. 
- state_map = await self.state_store.get_state_for_event( + state_map = await self.state_storage.get_state_for_event( last_event.event_id, StateFilter.from_types([(EventTypes.Member, user_id)]), ) @@ -207,7 +207,7 @@ class MessageHandler: ) if visible_events: - room_state_events = await self.state_store.get_state_for_events( + room_state_events = await self.state_storage.get_state_for_events( [last_event.event_id], state_filter=state_filter ) room_state: Mapping[Any, EventBase] = room_state_events[ @@ -237,7 +237,7 @@ class MessageHandler: assert ( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" - room_state_events = await self.state_store.get_state_for_events( + room_state_events = await self.state_storage.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state_events[membership_event_id] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 19a4407050..6f4820c240 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -130,7 +130,7 @@ class PaginationHandler: self.auth = hs.get_auth() self.store = hs.get_datastores().main self.storage = hs.get_storage() - self.state_store = self.storage.state + self.state_storage = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() @@ -539,7 +539,7 @@ class PaginationHandler: (EventTypes.Member, event.sender) for event in events ) - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self.state_storage.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 92e1de0500..e2775b34f1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1193,7 +1193,7 @@ class RoomContextHandler: self.auth = hs.get_auth() self.store = hs.get_datastores().main self.storage = hs.get_storage() - self.state_store = self.storage.state + self.state_storage = self.storage.state self._relations_handler = hs.get_relations_handler() async def get_event_context( @@ -1293,7 +1293,7 @@ class RoomContextHandler: # first? Shouldn't we be consistent with /sync? 
# https://github.com/matrix-org/matrix-doc/issues/687 - state = await self.state_store.get_state_for_events( + state = await self.state_storage.get_state_for_events( [last_event_id], state_filter=state_filter ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index fbfd748406..7ce32f2e9c 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -17,7 +17,7 @@ class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main - self.state_store = hs.get_storage().state + self.state_storage = hs.get_storage().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -141,7 +141,7 @@ class RoomBatchHandler: ) = await self.store.get_max_depth_of(event_ids) # mapping from (type, state_key) -> state_event_id assert most_recent_event_id is not None - prev_state_map = await self.state_store.get_state_ids_for_event( + prev_state_map = await self.state_storage.get_state_ids_for_event( most_recent_event_id ) # List of state event ID's diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index cd1c47dae8..e02c915248 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -56,7 +56,7 @@ class SearchHandler: self._event_serializer = hs.get_event_client_serializer() self._relations_handler = hs.get_relations_handler() self.storage = hs.get_storage() - self.state_store = self.storage.state + self.state_storage = self.storage.state self.auth = hs.get_auth() async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: @@ -677,7 +677,7 @@ class SearchHandler: [(EventTypes.Member, sender) for sender in senders] ) - state = await self.state_store.get_state_for_event( + state = await self.state_storage.get_state_for_event( last_event_id, state_filter ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dcbb5ce921..c5c538e0c3 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -239,7 +239,7 @@ class SyncHandler: self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() - self.state_store = self.storage.state + self.state_storage = self.storage.state # TODO: flush cache entries on subsequent sync request. # Once we get the next /sync request (ie, one with the same access token @@ -630,7 +630,7 @@ class SyncHandler: event: event of interest state_filter: The state filter used to fetch state from the database. 
""" - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self.state_storage.get_state_ids_for_event( event.event_id, state_filter=state_filter or StateFilter.all() ) if event.is_state(): @@ -710,7 +710,7 @@ class SyncHandler: return None last_event = last_events[-1] - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self.state_storage.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -888,11 +888,13 @@ class SyncHandler: if full_state: if batch: - current_state_ids = await self.state_store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + current_state_ids = ( + await self.state_storage.get_state_ids_for_event( + batch.events[-1].event_id, state_filter=state_filter + ) ) - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self.state_storage.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) @@ -913,7 +915,7 @@ class SyncHandler: elif batch.limited: if batch: state_at_timeline_start = ( - await self.state_store.get_state_ids_for_event( + await self.state_storage.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) ) @@ -947,8 +949,10 @@ class SyncHandler: ) if batch: - current_state_ids = await self.state_store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + current_state_ids = ( + await self.state_storage.get_state_ids_for_event( + batch.events[-1].event_id, state_filter=state_filter + ) ) else: # Its not clear how we get here, but empirically we do @@ -978,7 +982,7 @@ class SyncHandler: # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self.state_storage.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5ccdd88364..84124af965 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -114,7 +114,7 @@ class Mailer: self.send_email_handler = hs.get_send_email_handler() self.store = self.hs.get_datastores().main - self.state_store = self.hs.get_storage().state + self.state_storage = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() @@ -494,7 +494,7 @@ class Mailer: ) else: # Attempt to check the historical state for the room. - historical_state = await self.state_store.get_state_for_event( + historical_state = await self.state_storage.get_state_for_event( event.event_id, StateFilter.from_types((type_state_key,)) ) sender_state_event = historical_state.get(type_state_key) @@ -767,7 +767,7 @@ class Mailer: member_event_ids.append(sender_state_event_id) else: # Attempt to check the historical state for the room. 
- historical_state = await self.state_store.get_state_for_event( + historical_state = await self.state_storage.get_state_for_event( event_id, StateFilter.from_types((type_state_key,)) ) sender_state_event = historical_state.get(type_state_key) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4b4ed42cff..536564b7ff 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -127,7 +127,7 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state_store = hs.get_storage().state + self.state_storage = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() self._storage = hs.get_storage() @@ -339,7 +339,7 @@ class StateHandler: # if not state_group_before_event: - state_group_before_event = await self.state_store.store_state_group( + state_group_before_event = await self.state_storage.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event_prev_group, @@ -384,7 +384,7 @@ class StateHandler: state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = await self.state_store.store_state_group( + state_group_after_event = await self.state_storage.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event, @@ -418,7 +418,7 @@ class StateHandler: """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_store.get_state_group_for_events(event_ids) + state_groups = await self.state_storage.get_state_group_for_events(event_ids) state_group_ids = state_groups.values() @@ -426,8 +426,8 @@ class StateHandler: state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self.state_store.get_state_for_groups(state_group_ids_set) - prev_group, delta_ids = await self.state_store.get_state_group_delta( + state = await self.state_storage.get_state_for_groups(state_group_ids_set) + prev_group, delta_ids = await self.state_storage.get_state_group_delta( state_group_id ) return _StateCacheEntry( @@ -441,7 +441,7 @@ class StateHandler: room_version = await self.store.get_room_version_id(room_id) - state_to_resolve = await self.state_store.get_state_for_groups( + state_to_resolve = await self.state_storage.get_state_for_groups( state_group_ids_set ) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index e95dfdce20..bef6c2b776 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_store = hs.get_storage().state + self.state_storage = hs.get_storage().state self._event_auth_handler = hs.get_event_auth_handler() return hs @@ -334,7 +334,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # mapping from (type, state_key) -> state_event_id assert most_recent_prev_event_id is not None prev_state_map = self.get_success( - self.state_store.get_state_ids_for_event(most_recent_prev_event_id) + self.state_storage.get_state_ids_for_event(most_recent_prev_event_id) ) # List of state event ID's prev_state_ids = list(prev_state_map.values()) From 1b338476afbcb83918c5df285975878032bbce75 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff 
<1389908+richvdh@users.noreply.github.com> Date: Wed, 25 May 2022 23:24:28 +0200 Subject: [PATCH 17/74] Allow bigger responses to `/federation/v1/state` (#12877) * Refactor HTTP response size limits Rather than passing a separate `max_response_size` down the stack, make it an attribute of the `parser`. * Allow bigger responses on `federation/v1/state` `/state` can return huge responses, so we need to handle that. --- changelog.d/12877.bugfix | 1 + synapse/federation/transport/client.py | 15 ++++++------- synapse/http/matrixfederationclient.py | 29 +++++++------------------- tests/http/test_fedclient.py | 6 +++--- 4 files changed, 19 insertions(+), 32 deletions(-) create mode 100644 changelog.d/12877.bugfix diff --git a/changelog.d/12877.bugfix b/changelog.d/12877.bugfix new file mode 100644 index 0000000000..1ecf448baf --- /dev/null +++ b/changelog.d/12877.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.54 which could sometimes cause exceptions when handling federated traffic. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9ce06dfa28..25df1905c6 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -49,11 +49,6 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) -# Send join responses can be huge, so we set a separate limit here. The response -# is parsed in a streaming manner, which helps alleviate the issue of memory -# usage a bit. -MAX_RESPONSE_SIZE_SEND_JOIN = 500 * 1024 * 1024 - class TransportLayerClient: """Sends federation HTTP requests to other servers""" @@ -349,7 +344,6 @@ class TransportLayerClient: path=path, data=content, parser=SendJoinParser(room_version, v1_api=True), - max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) async def send_join_v2( @@ -372,7 +366,6 @@ class TransportLayerClient: args=query_params, data=content, parser=SendJoinParser(room_version, v1_api=False), - max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) async def send_leave_v1( @@ -1360,6 +1353,11 @@ class SendJoinParser(ByteParser[SendJoinResponse]): CONTENT_TYPE = "application/json" + # /send_join responses can be huge, so we override the size limit here. The response + # is parsed in a streaming manner, which helps alleviate the issue of memory + # usage a bit. + MAX_RESPONSE_SIZE = 500 * 1024 * 1024 + def __init__(self, room_version: RoomVersion, v1_api: bool): self._response = SendJoinResponse([], [], event_dict={}) self._room_version = room_version @@ -1427,6 +1425,9 @@ class _StateParser(ByteParser[StateRequestResponse]): CONTENT_TYPE = "application/json" + # As with /send_join, /state responses can be huge. + MAX_RESPONSE_SIZE = 500 * 1024 * 1024 + def __init__(self, room_version: RoomVersion): self._response = StateRequestResponse([], []) self._room_version = room_version diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 0b9475debd..db44721ef5 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -92,9 +92,6 @@ incoming_responses_counter = Counter( "synapse_http_matrixfederationclient_responses", "", ["method", "code"] ) -# a federation response can be rather large (eg a big state_ids is 50M or so), so we -# need a generous limit here. -MAX_RESPONSE_SIZE = 100 * 1024 * 1024 MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 @@ -116,6 +113,11 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC): the content type doesn't match we fail the request. 
""" + # a federation response can be rather large (eg a big state_ids is 50M or so), so we + # need a generous limit here. + MAX_RESPONSE_SIZE: int = 100 * 1024 * 1024 + """The largest response this parser will accept.""" + @abc.abstractmethod def finish(self) -> T: """Called when response has finished streaming and the parser should @@ -203,7 +205,6 @@ async def _handle_response( response: IResponse, start_ms: int, parser: ByteParser[T], - max_response_size: Optional[int] = None, ) -> T: """ Reads the body of a response with a timeout and sends it to a parser @@ -215,15 +216,12 @@ async def _handle_response( response: response to the request start_ms: Timestamp when request was made parser: The parser for the response - max_response_size: The maximum size to read from the response, if None - uses the default. Returns: The parsed response """ - if max_response_size is None: - max_response_size = MAX_RESPONSE_SIZE + max_response_size = parser.MAX_RESPONSE_SIZE try: check_content_type_is(response.headers, parser.CONTENT_TYPE) @@ -240,7 +238,7 @@ async def _handle_response( "{%s} [%s] JSON response exceeded max size %i - %s %s", request.txn_id, request.destination, - MAX_RESPONSE_SIZE, + max_response_size, request.method, request.uri.decode("ascii"), ) @@ -772,7 +770,6 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: ... @@ -790,7 +787,6 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, - max_response_size: Optional[int] = None, ) -> T: ... @@ -807,7 +803,6 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser] = None, - max_response_size: Optional[int] = None, ): """Sends the specified json data using PUT @@ -843,8 +838,6 @@ class MatrixFederationHttpClient: enabled. parser: The parser to use to decode the response. Defaults to parsing as JSON. - max_response_size: The maximum size to read from the response, if None - uses the default. Returns: Succeeds when we get a 2xx HTTP response. The @@ -895,7 +888,6 @@ class MatrixFederationHttpClient: response, start_ms, parser=parser, - max_response_size=max_response_size, ) return body @@ -984,7 +976,6 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: ... @@ -999,7 +990,6 @@ class MatrixFederationHttpClient: ignore_backoff: bool = ..., try_trailing_slash_on_400: bool = ..., parser: ByteParser[T] = ..., - max_response_size: Optional[int] = ..., ) -> T: ... @@ -1013,7 +1003,6 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser] = None, - max_response_size: Optional[int] = None, ): """GETs some json from the given host homeserver and path @@ -1043,9 +1032,6 @@ class MatrixFederationHttpClient: parser: The parser to use to decode the response. Defaults to parsing as JSON. - max_response_size: The maximum size to read from the response. If None, - uses the default. - Returns: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. 
@@ -1090,7 +1076,6 @@ class MatrixFederationHttpClient: response, start_ms, parser=parser, - max_response_size=max_response_size, ) return body diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 638babae69..006dbab093 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel from synapse.api.errors import RequestSendFailed from synapse.http.matrixfederationclient import ( - MAX_RESPONSE_SIZE, + JsonParser, MatrixFederationHttpClient, MatrixFederationRequest, ) @@ -609,9 +609,9 @@ class FederationClientTests(HomeserverTestCase): while not test_d.called: protocol.dataReceived(b"a" * chunk_size) sent += chunk_size - self.assertLessEqual(sent, MAX_RESPONSE_SIZE) + self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE) - self.assertEqual(sent, MAX_RESPONSE_SIZE) + self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE) f = self.failureResultOf(test_d) self.assertIsInstance(f.value, RequestSendFailed) From b83bc5fab57b37f75a79d02213d6032c586fd36e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 May 2022 10:48:12 +0100 Subject: [PATCH 18/74] Pull out less state when handling gaps mk2 (#12852) --- changelog.d/12852.misc | 1 + synapse/handlers/federation_event.py | 178 +++++++++++------------- synapse/handlers/message.py | 40 +++++- synapse/state/__init__.py | 22 ++- synapse/storage/databases/main/state.py | 59 ++++++++ tests/handlers/test_federation.py | 6 +- tests/storage/test_events.py | 43 ++++-- tests/test_state.py | 14 +- 8 files changed, 236 insertions(+), 127 deletions(-) create mode 100644 changelog.d/12852.misc diff --git a/changelog.d/12852.misc b/changelog.d/12852.misc new file mode 100644 index 0000000000..afca32471f --- /dev/null +++ b/changelog.d/12852.misc @@ -0,0 +1 @@ +Pull out less state when handling gaps in room DAG. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 8ce7187bef..a1361af272 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -274,7 +274,7 @@ class FederationEventHandler: affected=pdu.event_id, ) - await self._process_received_pdu(origin, pdu, state=None) + await self._process_received_pdu(origin, pdu, state_ids=None) async def on_send_membership_event( self, origin: str, event: EventBase @@ -463,7 +463,9 @@ class FederationEventHandler: with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in state + }, partial_state=partial_state, ) @@ -512,12 +514,12 @@ class FederationEventHandler: # # This is the same operation as we do when we receive a regular event # over federation. 
- state = await self._resolve_state_at_missing_prevs(destination, event) + state_ids = await self._resolve_state_at_missing_prevs(destination, event) # build a new state group for it if need be context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event=state_ids, ) if context.partial_state: # this can happen if some or all of the event's prev_events still have @@ -767,11 +769,12 @@ class FederationEventHandler: return try: - state = await self._resolve_state_at_missing_prevs(origin, event) + state_ids = await self._resolve_state_at_missing_prevs(origin, event) # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does # not return partial state + await self._process_received_pdu( - origin, event, state=state, backfilled=backfilled + origin, event, state_ids=state_ids, backfilled=backfilled ) except FederationError as e: if e.code == 403: @@ -781,7 +784,7 @@ class FederationEventHandler: async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: + ) -> Optional[StateMap[str]]: """Calculate the state at an event with missing prev_events. This is used when we have pulled a batch of events from a remote server, and @@ -808,8 +811,8 @@ class FederationEventHandler: event: an event to check for missing prevs. Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. + if we already had all the prev events, `None`. Otherwise, returns + the event ids of the state at `event`. """ room_id = event.room_id event_id = event.event_id @@ -829,7 +832,7 @@ class FederationEventHandler: ) # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. - event_map = {event_id: event} + try: # Get the state of the events we know about ours = await self._state_storage.get_state_groups_ids(room_id, seen) @@ -849,40 +852,23 @@ class FederationEventHandler: # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - remote_state = await self._get_state_after_missing_prev_event( - dest, room_id, p + remote_state_map = ( + await self._get_state_ids_after_missing_prev_event( + dest, room_id, p + ) ) - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } state_maps.append(remote_state_map) - for x in remote_state: - event_map[x.event_id] = x - room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, - event_map, + event_map={event_id: event}, state_res_store=StateResolutionStore(self._store), ) - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. 
- evs = await self._store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.as_is, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "Error attempting to resolve state at missing prev_events", @@ -894,14 +880,14 @@ class FederationEventHandler: "We can't get valid state history.", affected=event_id, ) - return state + return state_map - async def _get_state_after_missing_prev_event( + async def _get_state_ids_after_missing_prev_event( self, destination: str, room_id: str, event_id: str, - ) -> List[EventBase]: + ) -> StateMap[str]: """Requests all of the room state at a given event from a remote homeserver. Args: @@ -910,7 +896,7 @@ class FederationEventHandler: event_id: The id of the event we want the state at. Returns: - A list of events in the state, including the event itself + The event ids of the state *after* the given event. """ ( state_event_ids, @@ -925,19 +911,17 @@ class FederationEventHandler: len(auth_event_ids), ) - # start by just trying to fetch the events from the store + # Start by checking events we already have in the DB desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self._store.get_events( - desired_events, allow_rejected=True - ) + have_events = await self._store.have_seen_events(room_id, desired_events) - missing_desired_events = desired_events - fetched_events.keys() + missing_desired_events = desired_events - have_events logger.debug( "We are missing %i events (got %i)", len(missing_desired_events), - len(fetched_events), + len(have_events), ) # We probably won't need most of the auth events, so let's just check which @@ -948,7 +932,7 @@ class FederationEventHandler: # already have a bunch of the state events. It would be nice if the # federation api gave us a way of finding out which we actually need. - missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events = set(auth_event_ids) - have_events missing_auth_events.difference_update( await self._store.have_seen_events(room_id, missing_auth_events) ) @@ -974,47 +958,51 @@ class FederationEventHandler: destination=destination, room_id=room_id, event_ids=missing_events ) - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self._store.get_events(missing_desired_events, allow_rejected=True) - ) + # We now need to fill out the state map, which involves fetching the + # type and state key for each event ID in the state. + state_map = {} - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B + event_metadata = await self._store.get_metadata_for_events(state_event_ids) + for state_event_id, metadata in event_metadata.items(): + if metadata.room_id != room_id: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. 
+ # + # This can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + state_event_id, + metadata.room_id, + room_id, + ) + continue - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] + if metadata.state_key is None: + logger.warning( + "Remote server gave us non-state event in state: %s", state_event_id + ) + continue - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) - - del fetched_events[bad_event_id] + state_map[(metadata.event_type, metadata.state_key)] = state_event_id # if we couldn't get the prev event in question, that's a problem. - remote_event = fetched_events.get(event_id) + remote_event = await self._store.get_event( + event_id, + allow_none=True, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) if not remote_event: raise Exception("Unable to get missing prev_event %s" % (event_id,)) # missing state at that event is a warning, not a blocker # XXX: this doesn't sound right? it means that we'll end up with incomplete # state. - failed_to_fetch = desired_events - fetched_events.keys() + failed_to_fetch = desired_events - event_metadata.keys() if failed_to_fetch: logger.warning( "Failed to fetch missing state events for %s %s", @@ -1022,14 +1010,12 @@ class FederationEventHandler: failed_to_fetch, ) - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) + state_map[ + (remote_event.type, remote_event.state_key) + ] = remote_event.event_id - return remote_state + return state_map async def _get_state_and_persist( self, destination: str, room_id: str, event_id: str @@ -1056,7 +1042,7 @@ class FederationEventHandler: self, origin: str, event: EventBase, - state: Optional[Iterable[EventBase]], + state_ids: Optional[StateMap[str]], backfilled: bool = False, ) -> None: """Called when we have a new non-outlier event. @@ -1078,7 +1064,7 @@ class FederationEventHandler: event: event to be persisted - state: Normally None, but if we are handling a gap in the graph + state_ids: Normally None, but if we are handling a gap in the graph (ie, we are missing one or more prev_events), the resolved state at the event @@ -1090,7 +1076,8 @@ class FederationEventHandler: try: context = await self._state_handler.compute_event_context( - event, old_state=state + event, + state_ids_before_event=state_ids, ) context = await self._check_event_auth( origin, @@ -1107,7 +1094,7 @@ class FederationEventHandler: # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. 
- await self._check_for_soft_fail(event, state, origin=origin) + await self._check_for_soft_fail(event, state_ids, origin=origin) await self._run_push_actions_and_persist_event(event, context, backfilled) @@ -1589,7 +1576,7 @@ class FederationEventHandler: async def _check_for_soft_fail( self, event: EventBase, - state: Optional[Iterable[EventBase]], + state_ids: Optional[StateMap[str]], origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1597,7 +1584,7 @@ class FederationEventHandler: Args: event - state: The state at the event if we don't have all the event's prev events + state_ids: The state at the event if we don't have all the event's prev events origin: The host the event originates from. """ extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) @@ -1613,7 +1600,7 @@ class FederationEventHandler: room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". - if state is not None: + if state_ids is not None: # If we're explicitly given the state then we won't have all the # prev events, and so we have a gap in the graph. In this case # we want to be a little careful as we might have been down for @@ -1626,17 +1613,20 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_storage.get_state_groups( + state_sets_d = await self._state_storage.get_state_groups_ids( event.room_id, extrem_ids ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) - state_sets.append(state) - current_states = await self._state_handler.resolve_events( - room_version, state_sets, event + state_sets: List[StateMap[str]] = list(state_sets_d.values()) + state_sets.append(state_ids) + current_state_ids = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_sets, + event_map=None, + state_res_store=StateResolutionStore(self._store), + ) ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } else: current_state_ids = await self._state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9501e7f1b7..7ca126dbd1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester +from synapse.types import ( + MutableStateMap, + Requester, + RoomAlias, + StreamToken, + UserID, + create_requester, +) from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache @@ -1022,8 +1029,35 @@ class EventCreationHandler: # # TODO(faster_joins): figure out how this works, and make sure that the # old state is complete. 
- old_state = await self.store.get_events_as_list(state_event_ids) - context = await self.state.compute_event_context(event, old_state=old_state) + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map_for_event: MutableStateMap[str] = {} + for state_id in state_event_ids: + data = metadata.get(state_id) + if data is None: + # We're trying to persist a new historical batch of events + # with the given state, e.g. via + # `RoomBatchSendEventRestServlet`. The state can be inferred + # by Synapse or set directly by the client. + # + # Either way, we should have persisted all the state before + # getting here. + raise Exception( + f"State event {state_id} not found in DB," + " Synapse should have persisted it before using it." + ) + + if data.state_key is None: + raise Exception( + f"Trying to set non-state event {state_id} as state" + ) + + state_map_for_event[(data.event_type, data.state_key)] = state_id + + context = await self.state.compute_event_context( + event, + state_ids_before_event=state_map_for_event, + ) else: context = await self.state.compute_event_context(event) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 536564b7ff..9c9d946f38 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -261,7 +261,7 @@ class StateHandler: async def compute_event_context( self, event: EventBase, - old_state: Optional[Iterable[EventBase]] = None, + state_ids_before_event: Optional[StateMap[str]] = None, partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,12 +273,12 @@ class StateHandler: Args: event: - old_state: The state at the event if it can't be - calculated from existing events. This is normally only specified - when receiving an event from federation where we don't have the - prev events for, e.g. when backfilling. - partial_state: True if `old_state` is partial and omits non-critical - membership events + state_ids_before_event: The event ids of the state before the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events, e.g. when backfilling. + partial_state: True if `state_ids_before_event` is partial and omits + non-critical membership events Returns: The event context. """ @@ -286,13 +286,11 @@ class StateHandler: assert not event.internal_metadata.is_outlier() # - # first of all, figure out the state before the event + # first of all, figure out the state before the event, unless we + # already have it. 
# - if old_state: + if state_ids_before_event: # if we're given the state before the event, then we use that - state_ids_before_event: StateMap[str] = { - (s.type, s.state_key): s.event_id for s in old_state - } state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 18ae8aee29..ea5cbdac08 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -16,6 +16,8 @@ import collections.abc import logging from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple +import attr + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -26,6 +28,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore @@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -43,6 +47,15 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: Optional[str] + + def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: @@ -133,6 +146,52 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return room_version + async def get_metadata_for_events( + self, event_ids: Collection[str] + ) -> Dict[str, EventMetadata]: + """Get some metadata (room_id, type, state_key) for the given events. + + This method is a faster alternative than fetching the full events from + the DB, and should be used when the full event is not needed. + + Returns metadata for rejected and redacted events. Events that have not + been persisted are omitted from the returned dict. + """ + + def get_metadata_for_events_txn( + txn: LoggingTransaction, + batch_ids: Collection[str], + ) -> Dict[str, EventMetadata]: + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", batch_ids + ) + + sql = f""" + SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e + LEFT JOIN state_events USING (event_id) + WHERE {clause} + """ + + txn.execute(sql, args) + return { + event_id: EventMetadata( + room_id=room_id, event_type=event_type, state_key=state_key + ) + for event_id, room_id, event_type, state_key in txn + } + + result_map: Dict[str, EventMetadata] = {} + for batch_ids in batch_iter(event_ids, 1000): + result_map.update( + await self.db_pool.runInteraction( + "get_metadata_for_events", + get_metadata_for_events_txn, + batch_ids=batch_ids, + ) + ) + + return result_map + async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. 
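The handler changes earlier in this patch (for example in `synapse/handlers/message.py`) consume the new `get_metadata_for_events` by building a `(type, state_key) -> event_id` map instead of loading full events. A hedged sketch of that calling pattern, with `store` standing in for the real datastore object:

```python
from typing import Collection, Dict, Tuple


async def state_map_from_event_ids(
    store, state_event_ids: Collection[str]
) -> Dict[Tuple[str, str], str]:
    # One metadata lookup replaces fetching the full events from the DB.
    metadata = await store.get_metadata_for_events(state_event_ids)

    state_map: Dict[Tuple[str, str], str] = {}
    for event_id in state_event_ids:
        data = metadata.get(event_id)
        if data is None:
            raise Exception(f"State event {event_id} has not been persisted")
        if data.state_key is None:
            raise Exception(f"Event {event_id} is not a state event")
        state_map[(data.event_type, data.state_key)] = event_id
    return state_map
```
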
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index bef6c2b776..ec00900621 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -276,7 +276,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # federation handler wanting to backfill the fake event. self.get_success( federation_event_handler._process_received_pdu( - self.OTHER_SERVER_NAME, event, state=current_state + self.OTHER_SERVER_NAME, + event, + state_ids={ + (e.type, e.state_key): e.event_id for e in current_state + }, ) ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index ef5e25873c..aaa3189b16 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def persist_event(self, event, state=None): """Persist the event, with optional state""" context = self.get_success( - self.state.compute_event_context(event, old_state=state) + self.state.compute_event_context(event, state_ids_before_event=state) ) self.get_success(self.persistence.persist_event(event, context)) @@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase): # setting. The state resolution across the old and new event will then # include it, and so the resolved state won't match the new state. state_before_gap = dict( - self.get_success(self.state.get_current_state(self.room_id)) + self.get_success(self.state.get_current_state_ids(self.room_id)) ) state_before_gap.pop(("m.room.history_visibility", "")) context = self.get_success( self.state.compute_event_context( - remote_event_2, old_state=state_before_gap.values() + remote_event_2, + state_ids_before_event=state_before_gap, ) ) @@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. 
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id, local_message_event_id]) @@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([local_message_event_id, remote_event_2.event_id]) diff --git a/tests/test_state.py b/tests/test_state.py index c6baea3d76..84694d368d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase): ] context = yield defer.ensureDeferred( - self.state.compute_event_context(event, old_state=old_state) + self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) ) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) @@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase): ] context = yield defer.ensureDeferred( - self.state.compute_event_context(event, old_state=old_state) + self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) ) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) From b5707ceabad79267928b1f5e0bff582b09488847 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 May 2022 07:09:16 -0400 Subject: [PATCH 19/74] Avoid attempting to delete push actions for remote users. (#12879) Remote users will never have push actions, so we can avoid a database round-trip/transaction completely. --- changelog.d/12879.misc | 1 + synapse/federation/sender/per_destination_queue.py | 2 +- synapse/storage/databases/main/event_push_actions.py | 2 +- synapse/storage/databases/main/receipts.py | 5 ++++- synapse/storage/persist_events.py | 2 +- 5 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 changelog.d/12879.misc diff --git a/changelog.d/12879.misc b/changelog.d/12879.misc new file mode 100644 index 0000000000..24fa0d0de0 --- /dev/null +++ b/changelog.d/12879.misc @@ -0,0 +1 @@ +Avoid running queries which will never result in deletions. 
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index d80f0ac5e8..8983b5a53d 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -223,7 +223,7 @@ class PerDestinationQueue: """Marks that the destination has new data to send, without starting a new transaction. - If a transaction loop is already in progress then a new transcation will + If a transaction loop is already in progress then a new transaction will be attempted when the current one finishes. """ diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b7c4c62222..b019979350 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -938,7 +938,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): users can still get a list of recent highlights. Args: - txn: The transcation + txn: The transaction room_id: Room ID to delete from user_id: user ID to delete for stream_ordering: The lowest stream ordering which will diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index d035969a31..cfa4d4924d 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -673,8 +673,11 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) + # When updating a local users read receipt, remove any push actions + # which resulted from the receipt's event and all earlier events. if ( - receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) + self.hs.is_mine_id(user_id) + and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) and stream_ordering is not None ): self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 0fc282866b..a21dea91c8 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -313,7 +313,7 @@ class EventsPersistenceStorage: List of events persisted, the current position room stream position. The list of events persisted may not be the same as those passed in if they were deduplicated due to an event already existing that - matched the transcation ID; the existing event is returned in such + matched the transaction ID; the existing event is returned in such a case. """ partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {} From 1885ee011395f9c1f121f8045ac6d47a74c4cc24 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 May 2022 07:10:28 -0400 Subject: [PATCH 20/74] Remove unstable APIs for /hierarchy. (#12851) Removes the unstable endpoint as well as a duplicated field which was modified during stabilization. --- changelog.d/12851.misc | 1 + docs/workers.md | 6 +++--- .../federation/transport/server/federation.py | 5 ----- synapse/handlers/room_summary.py | 5 +---- synapse/rest/client/room.py | 7 +------ tests/handlers/test_room_summary.py | 20 +++++++++---------- 6 files changed, 16 insertions(+), 28 deletions(-) create mode 100644 changelog.d/12851.misc diff --git a/changelog.d/12851.misc b/changelog.d/12851.misc new file mode 100644 index 0000000000..ca6f48c369 --- /dev/null +++ b/changelog.d/12851.misc @@ -0,0 +1 @@ +Remove the unstable `/hierarchy` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). 
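For context on the servlet pattern rewritten later in this patch: request paths are matched against a compiled regex whose named capture group carries the room ID. A standalone illustration of the stable v1 pattern (the example path is made up):

    import re

    # The stable pattern this patch keeps; (?P<room_id>...) names the capture
    # group so the handler can pull the room ID out of the path by name.
    PATTERN = re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/hierarchy$")

    match = PATTERN.match("/_matrix/client/v1/rooms/!abc:example.org/hierarchy")
    assert match is not None
    assert match.group("room_id") == "!abc:example.org"

    # The removed unstable prefix no longer matches anything:
    assert PATTERN.match(
        "/_matrix/client/unstable/org.matrix.msc2946/rooms/!abc:example.org/hierarchy"
    ) is None
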
diff --git a/docs/workers.md b/docs/workers.md index 6a76f43fa1..78973a498c 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -193,7 +193,7 @@ information. ^/_matrix/federation/v1/user/devices/ ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query - ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/ + ^/_matrix/federation/v1/hierarchy/ # Inbound federation transaction request ^/_matrix/federation/v1/send/ @@ -205,8 +205,8 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ - ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ - ^/_matrix/client/(v1|unstable/org.matrix.msc2716)/rooms/.*/batch_send$ + ^/_matrix/client/v1/rooms/.*/hierarchy$ + ^/_matrix/client/unstable/org.matrix.msc2716/rooms/.*/batch_send$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ ^/_matrix/client/(r0|v3|unstable)/account/3pid$ ^/_matrix/client/(r0|v3|unstable)/account/whoami$ diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 6fbc7b5f15..57e8fb21b0 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -650,10 +650,6 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): ) -class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - - class RoomComplexityServlet(BaseFederationServlet): """ Indicates to other servers how complex (and therefore likely @@ -752,7 +748,6 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationVersionServlet, RoomComplexityServlet, FederationRoomHierarchyServlet, - FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, FederationAccountStatusServlet, diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 1dd74912fa..75aee6a111 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -662,7 +662,7 @@ class RoomSummaryHandler: # The API doesn't return the room version so assume that a # join rule of knock is valid. if ( - room.get("join_rules") + room.get("join_rule") in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED) or room.get("world_readable") is True ): @@ -714,9 +714,6 @@ class RoomSummaryHandler: "canonical_alias": stats["canonical_alias"], "num_joined_members": stats["joined_members"], "avatar_url": stats["avatar"], - # plural join_rules is a documentation error but kept for historical - # purposes. Should match /publicRooms. 
-            "join_rules": stats["join_rules"],
             "join_rule": stats["join_rules"],
             "world_readable": (
                 stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 5a2361a2e6..7a5ce8ad0e 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -1193,12 +1193,7 @@ class TimestampLookupRestServlet(RestServlet):


 class RoomHierarchyRestServlet(RestServlet):
-    PATTERNS = (
-        re.compile(
-            "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
-            "/rooms/(?P<room_id>[^/]*)/hierarchy$"
-        ),
-    )
+    PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/hierarchy$"),)

     def __init__(self, hs: "HomeServer"):
         super().__init__()
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index e74eb71774..0546655690 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -179,7 +179,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             result_children_ids.append(
                 [
                     (cs["room_id"], cs["state_key"])
-                    for cs in result_room.get("children_state")
+                    for cs in result_room["children_state"]
                 ]
             )

@@ -772,7 +772,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": public_room,
                 "world_readable": False,
-                "join_rules": JoinRules.PUBLIC,
+                "join_rule": JoinRules.PUBLIC,
             },
         ),
         (
@@ -780,7 +780,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": knock_room,
                 "world_readable": False,
-                "join_rules": JoinRules.KNOCK,
+                "join_rule": JoinRules.KNOCK,
             },
         ),
         (
@@ -788,7 +788,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": not_invited_room,
                 "world_readable": False,
-                "join_rules": JoinRules.INVITE,
+                "join_rule": JoinRules.INVITE,
             },
         ),
         (
@@ -796,7 +796,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": invited_room,
                 "world_readable": False,
-                "join_rules": JoinRules.INVITE,
+                "join_rule": JoinRules.INVITE,
             },
         ),
         (
@@ -804,7 +804,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": restricted_room,
                 "world_readable": False,
-                "join_rules": JoinRules.RESTRICTED,
+                "join_rule": JoinRules.RESTRICTED,
                 "allowed_room_ids": [],
             },
         ),
@@ -813,7 +813,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": restricted_accessible_room,
                 "world_readable": False,
-                "join_rules": JoinRules.RESTRICTED,
+                "join_rule": JoinRules.RESTRICTED,
                 "allowed_room_ids": [self.room],
             },
         ),
@@ -822,7 +822,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": world_readable_room,
                 "world_readable": True,
-                "join_rules": JoinRules.INVITE,
+                "join_rule": JoinRules.INVITE,
             },
         ),
         (
@@ -830,7 +830,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": joined_room,
                 "world_readable": False,
-                "join_rules": JoinRules.INVITE,
+                "join_rule": JoinRules.INVITE,
             },
         ),
     )
@@ -911,7 +911,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             {
                 "room_id": fed_room,
                 "world_readable": False,
-                "join_rules": JoinRules.INVITE,
+                "join_rule": JoinRules.INVITE,
             },
         )

From e76864436867deba7fc6b740d1f8d80f4717f44b Mon Sep 17 00:00:00 2001
From: reivilibre
Date: Thu, 26 May 2022 12:19:01 +0100
Subject: [PATCH 21/74] Fix ambiguous column name that would prevent use of
 MSC2716 History Import when using Postgres as a database.
(#12843) --- changelog.d/12843.bugfix | 1 + synapse/storage/databases/main/event_federation.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12843.bugfix diff --git a/changelog.d/12843.bugfix b/changelog.d/12843.bugfix new file mode 100644 index 0000000000..f87c0799a0 --- /dev/null +++ b/changelog.d/12843.bugfix @@ -0,0 +1 @@ +Fix bug where servers using a Postgres database would fail to backfill from an insertion event when MSC2716 is enabled (`experimental_features.msc2716_enabled`). diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index dcfe8caf47..562dcbe94d 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1057,7 +1057,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas INNER JOIN batch_events AS c ON i.next_batch_id = c.batch_id /* Get the depth of the batch start event from the events table */ - INNER JOIN events AS e USING (event_id) + INNER JOIN events AS e ON c.event_id = e.event_id /* Find an insertion event which matches the given event_id */ WHERE i.event_id = ? LIMIT ? From 1cba285a7971eb88f41139ff466918332a98b479 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 26 May 2022 12:42:21 +0000 Subject: [PATCH 22/74] Bump pyjwt from 2.3.0 to 2.4.0 (#12865) Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.3.0 to 2.4.0. - [Release notes](https://github.com/jpadilla/pyjwt/releases) - [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst) - [Commits](https://github.com/jpadilla/pyjwt/compare/2.3.0...2.4.0) --- updated-dependencies: - dependency-name: pyjwt dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- changelog.d/12865.misc | 1 + poetry.lock | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12865.misc diff --git a/changelog.d/12865.misc b/changelog.d/12865.misc new file mode 100644 index 0000000000..d982ca7622 --- /dev/null +++ b/changelog.d/12865.misc @@ -0,0 +1 @@ +Update `pyjwt` dependency to [2.4.0](https://github.com/jpadilla/pyjwt/releases/tag/2.4.0). 
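Returning briefly to the ambiguous-column fix above (#12843): by the time the query joins `events`, two tables in scope already expose an `event_id` column, so `USING (event_id)` leaves Postgres unable to pick a side. A reduced illustration, with hypothetical tables `a` and `b` standing in for the insertion/batch tables:

    -- a AS i and b AS c both carry an event_id column, as in the query above.
    -- Postgres rejects the USING form because the join's left-hand side now
    -- contains event_id more than once:
    SELECT e.depth
    FROM a AS i
    INNER JOIN b AS c ON i.next_batch_id = c.batch_id
    INNER JOIN events AS e USING (event_id);  -- ambiguous: i.event_id or c.event_id?

    -- The patch resolves the ambiguity by naming the join column explicitly:
    SELECT e.depth
    FROM a AS i
    INNER JOIN b AS c ON i.next_batch_id = c.batch_id
    INNER JOIN events AS e ON c.event_id = e.event_id;
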
diff --git a/poetry.lock b/poetry.lock index 49a912a589..f64d70941e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -813,7 +813,7 @@ python-versions = ">=3.5" [[package]] name = "pyjwt" -version = "2.3.0" +version = "2.4.0" description = "JSON Web Token implementation in Python" category = "main" optional = false @@ -2264,8 +2264,8 @@ pygments = [ {file = "Pygments-2.11.2.tar.gz", hash = "sha256:4e426f72023d88d03b2fa258de560726ce890ff3b630f88c21cbb8b2503b8c6a"}, ] pyjwt = [ - {file = "PyJWT-2.3.0-py3-none-any.whl", hash = "sha256:e0c4bb8d9f0af0c7f5b1ec4c5036309617d03d56932877f2f7a0beeb5318322f"}, - {file = "PyJWT-2.3.0.tar.gz", hash = "sha256:b888b4d56f06f6dcd777210c334e69c737be74755d3e5e9ee3fe67dc18a0ee41"}, + {file = "PyJWT-2.4.0-py3-none-any.whl", hash = "sha256:72d1d253f32dbd4f5c88eaf1fdc62f3a19f676ccbadb9dbc5d07e951b2b26daf"}, + {file = "PyJWT-2.4.0.tar.gz", hash = "sha256:d42908208c699b3b973cbeb01a969ba6a96c821eefb1c5bfe4c390c01d67abba"}, ] pymacaroons = [ {file = "pymacaroons-0.13.0-py2.py3-none-any.whl", hash = "sha256:3e14dff6a262fdbf1a15e769ce635a8aea72e6f8f91e408f9a97166c53b91907"}, From 49f06866e4db2e19467a3733b2909ba397da265e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 May 2022 09:04:34 -0400 Subject: [PATCH 23/74] Remove backing code for groups/communities (#12558) Including handlers, configuration code, appservice support, and the GroupID construct. --- changelog.d/12558.removal | 1 + synapse/appservice/__init__.py | 43 +- synapse/config/_base.pyi | 2 - synapse/config/groups.py | 27 - synapse/config/homeserver.py | 2 - synapse/groups/__init__.py | 0 synapse/groups/attestations.py | 218 ------ synapse/groups/groups_server.py | 1019 --------------------------- synapse/handlers/groups_local.py | 503 ------------- synapse/server.py | 39 +- synapse/types.py | 23 - tests/appservice/test_appservice.py | 2 +- tests/test_types.py | 21 +- 13 files changed, 6 insertions(+), 1894 deletions(-) create mode 100644 changelog.d/12558.removal delete mode 100644 synapse/config/groups.py delete mode 100644 synapse/groups/__init__.py delete mode 100644 synapse/groups/attestations.py delete mode 100644 synapse/groups/groups_server.py delete mode 100644 synapse/handlers/groups_local.py diff --git a/changelog.d/12558.removal b/changelog.d/12558.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12558.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. 
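After this removal, an appservice namespace is just an exclusivity flag plus a compiled regex, as the appservice diff below shows. A minimal sketch of matching a user ID against such namespaces (standalone: it mirrors but does not import Synapse's classes, and needs the attrs package that Synapse itself uses):

    import re
    from typing import List, Pattern

    import attr


    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class Namespace:
        # The post-removal shape: no group_id field any more.
        exclusive: bool
        regex: Pattern[str]


    def is_exclusively_claimed(user_id: str, namespaces: List[Namespace]) -> bool:
        """Whether any exclusive namespace regex matches the user ID."""
        return any(ns.exclusive and ns.regex.match(user_id) for ns in namespaces)


    namespaces = [Namespace(True, re.compile(r"@irc_.*:example\.org"))]
    assert is_exclusively_claimed("@irc_alice:example.org", namespaces)
    assert not is_exclusively_claimed("@alice:example.org", namespaces)
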
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index a610fb785d..ed92c2e910 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -23,13 +23,7 @@ from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import ( - DeviceListUpdates, - GroupID, - JsonDict, - UserID, - get_domain_from_id, -) +from synapse.types import DeviceListUpdates, JsonDict, UserID from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -55,7 +49,6 @@ class ApplicationServiceState(Enum): @attr.s(slots=True, frozen=True, auto_attribs=True) class Namespace: exclusive: bool - group_id: Optional[str] regex: Pattern[str] @@ -141,30 +134,13 @@ class ApplicationService: exclusive = regex_obj.get("exclusive") if not isinstance(exclusive, bool): raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) - group_id = regex_obj.get("group_id") - if group_id: - if not isinstance(group_id, str): - raise ValueError( - "Expected string for 'group_id' in ns '%s'" % ns - ) - try: - GroupID.from_string(group_id) - except Exception: - raise ValueError( - "Expected valid group ID for 'group_id' in ns '%s'" % ns - ) - - if get_domain_from_id(group_id) != self.server_name: - raise ValueError( - "Expected 'group_id' to be this host in ns '%s'" % ns - ) regex = regex_obj.get("regex") if not isinstance(regex, str): raise ValueError("Expected string for 'regex' in ns '%s'" % ns) # Pre-compile regex. - result[ns].append(Namespace(exclusive, group_id, re.compile(regex))) + result[ns].append(Namespace(exclusive, re.compile(regex))) return result @@ -369,21 +345,6 @@ class ApplicationService: if namespace.exclusive ] - def get_groups_for_user(self, user_id: str) -> Iterable[str]: - """Get the groups that this user is associated with by this AS - - Args: - user_id: The ID of the user. - - Returns: - An iterable that yields group_id strings. - """ - return ( - namespace.group_id - for namespace in self.namespaces[ApplicationService.NS_USERS] - if namespace.group_id and namespace.regex.match(user_id) - ) - def is_rate_limited(self) -> bool: return self.rate_limited diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 71d6655fda..01ea2b4dab 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -32,7 +32,6 @@ from synapse.config import ( emailconfig, experimental, federation, - groups, jwt, key, logger, @@ -107,7 +106,6 @@ class RootConfig: push: push.PushConfig spamchecker: spam_checker.SpamCheckerConfig room: room.RoomConfig - groups: groups.GroupsConfig userdirectory: user_directory.UserDirectoryConfig consent: consent.ConsentConfig stats: stats.StatsConfig diff --git a/synapse/config/groups.py b/synapse/config/groups.py deleted file mode 100644 index baa051fdd4..0000000000 --- a/synapse/config/groups.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any - -from synapse.types import JsonDict - -from ._base import Config - - -class GroupsConfig(Config): - section = "groups" - - def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.enable_group_creation = config.get("enable_group_creation", False) - self.group_creation_prefix = config.get("group_creation_prefix", "") diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index a4ec706908..4d2b298a70 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -25,7 +25,6 @@ from .database import DatabaseConfig from .emailconfig import EmailConfig from .experimental import ExperimentalConfig from .federation import FederationConfig -from .groups import GroupsConfig from .jwt import JWTConfig from .key import KeyConfig from .logger import LoggingConfig @@ -89,7 +88,6 @@ class HomeServerConfig(RootConfig): PushConfig, SpamCheckerConfig, RoomConfig, - GroupsConfig, UserDirectoryConfig, ConsentConfig, StatsConfig, diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py deleted file mode 100644 index ed26d6a6ce..0000000000 --- a/synapse/groups/attestations.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Attestations ensure that users and groups can't lie about their memberships. - -When a user joins a group the HS and GS swap attestations, which allow them -both to independently prove to third parties their membership.These -attestations have a validity period so need to be periodically renewed. - -If a user leaves (or gets kicked out of) a group, either side can still use -their attestation to "prove" their membership, until the attestation expires. -Therefore attestations shouldn't be relied on to prove membership in important -cases, but can for less important situations, e.g. showing a users membership -of groups on their profile, showing flairs, etc. 
- -An attestation is a signed blob of json that looks like: - - { - "user_id": "@foo:a.example.com", - "group_id": "+bar:b.example.com", - "valid_until_ms": 1507994728530, - "signatures":{"matrix.org":{"ed25519:auto":"..."}} - } -""" - -import logging -import random -from typing import TYPE_CHECKING, Optional, Tuple - -from signedjson.sign import sign_json - -from twisted.internet.defer import Deferred - -from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import JsonDict, get_domain_from_id - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -# Default validity duration for new attestations we create -DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000 - -# We add some jitter to the validity duration of attestations so that if we -# add lots of users at once we don't need to renew them all at once. -# The jitter is a multiplier picked randomly between the first and second number -DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) - -# Start trying to update our attestations when they come this close to expiring -UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 - - -class GroupAttestationSigning: - """Creates and verifies group attestations.""" - - def __init__(self, hs: "HomeServer"): - self.keyring = hs.get_keyring() - self.clock = hs.get_clock() - self.server_name = hs.hostname - self.signing_key = hs.signing_key - - async def verify_attestation( - self, - attestation: JsonDict, - group_id: str, - user_id: str, - server_name: Optional[str] = None, - ) -> None: - """Verifies that the given attestation matches the given parameters. - - An optional server_name can be supplied to explicitly set which server's - signature is expected. Otherwise assumes that either the group_id or user_id - is local and uses the other's server as the one to check. - """ - - if not server_name: - if get_domain_from_id(group_id) == self.server_name: - server_name = get_domain_from_id(user_id) - elif get_domain_from_id(user_id) == self.server_name: - server_name = get_domain_from_id(group_id) - else: - raise Exception("Expected either group_id or user_id to be local") - - if user_id != attestation["user_id"]: - raise SynapseError(400, "Attestation has incorrect user_id") - - if group_id != attestation["group_id"]: - raise SynapseError(400, "Attestation has incorrect group_id") - valid_until_ms = attestation["valid_until_ms"] - - # TODO: We also want to check that *new* attestations that people give - # us to store are valid for at least a little while. - now = self.clock.time_msec() - if valid_until_ms < now: - raise SynapseError(400, "Attestation expired") - - assert server_name is not None - await self.keyring.verify_json_for_server( - server_name, - attestation, - now, - ) - - def create_attestation(self, group_id: str, user_id: str) -> JsonDict: - """Create an attestation for the group_id and user_id with default - validity length. 
- """ - validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform( - *DEFAULT_ATTESTATION_JITTER - ) - valid_until_ms = int(self.clock.time_msec() + validity_period) - - return sign_json( - { - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": valid_until_ms, - }, - self.server_name, - self.signing_key, - ) - - -class GroupAttestionRenewer: - """Responsible for sending and receiving attestation updates.""" - - def __init__(self, hs: "HomeServer"): - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - self.assestations = hs.get_groups_attestation_signing() - self.transport_client = hs.get_federation_transport_client() - self.is_mine_id = hs.is_mine_id - self.attestations = hs.get_groups_attestation_signing() - - if not hs.config.worker.worker_app: - self._renew_attestations_loop = self.clock.looping_call( - self._start_renew_attestations, 30 * 60 * 1000 - ) - - async def on_renew_attestation( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """When a remote updates an attestation""" - attestation = content["attestation"] - - if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): - raise SynapseError(400, "Neither user not group are on this server") - - await self.attestations.verify_attestation( - attestation, user_id=user_id, group_id=group_id - ) - - await self.store.update_remote_attestion(group_id, user_id, attestation) - - return {} - - def _start_renew_attestations(self) -> "Deferred[None]": - return run_as_background_process("renew_attestations", self._renew_attestations) - - async def _renew_attestations(self) -> None: - """Called periodically to check if we need to update any of our attestations""" - - now = self.clock.time_msec() - - rows = await self.store.get_attestations_need_renewals( - now + UPDATE_ATTESTATION_TIME_MS - ) - - async def _renew_attestation(group_user: Tuple[str, str]) -> None: - group_id, user_id = group_user - try: - if not self.is_mine_id(group_id): - destination = get_domain_from_id(group_id) - elif not self.is_mine_id(user_id): - destination = get_domain_from_id(user_id) - else: - logger.warning( - "Incorrectly trying to do attestations for user: %r in %r", - user_id, - group_id, - ) - await self.store.remove_attestation_renewal(group_id, user_id) - return - - attestation = self.attestations.create_attestation(group_id, user_id) - - await self.transport_client.renew_group_attestation( - destination, group_id, user_id, content={"attestation": attestation} - ) - - await self.store.update_attestation_renewal( - group_id, user_id, attestation - ) - except (RequestSendFailed, HttpResponseException) as e: - logger.warning( - "Failed to renew attestation of %r in %r: %s", user_id, group_id, e - ) - except Exception: - logger.exception( - "Error renewing attestation of %r in %r", user_id, group_id - ) - - for row in rows: - await _renew_attestation((row["group_id"], row["user_id"])) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py deleted file mode 100644 index dfd24af695..0000000000 --- a/synapse/groups/groups_server.py +++ /dev/null @@ -1,1019 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com> -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING, Optional - -from synapse.api.errors import Codes, SynapseError -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN -from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id -from synapse.util.async_helpers import concurrently_execute - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -# TODO: Allow users to "knock" or simply join depending on rules -# TODO: Federation admin APIs -# TODO: is_privileged flag to users and is_public to users and rooms -# TODO: Audit log for admins (profile updates, membership changes, users who tried -# to join but were rejected, etc) -# TODO: Flairs - - -# Note that the maximum lengths are somewhat arbitrary. -MAX_SHORT_DESC_LEN = 1000 -MAX_LONG_DESC_LEN = 10000 - - -class GroupsServerWorkerHandler: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.store = hs.get_datastores().main - self.room_list_handler = hs.get_room_list_handler() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.keyring = hs.get_keyring() - self.is_mine_id = hs.is_mine_id - self.signing_key = hs.signing_key - self.server_name = hs.hostname - self.attestations = hs.get_groups_attestation_signing() - self.transport_client = hs.get_federation_transport_client() - self.profile_handler = hs.get_profile_handler() - - async def check_group_is_ours( - self, - group_id: str, - requester_user_id: str, - and_exists: bool = False, - and_is_admin: Optional[str] = None, - ) -> Optional[dict]: - """Check that the group is ours, and optionally if it exists. - - If group does exist then return group. - - Args: - group_id: The group ID to check. - requester_user_id: The user ID of the requester. - and_exists: whether to also check if group exists - and_is_admin: whether to also check if given str is a user_id - that is an admin - """ - if not self.is_mine_id(group_id): - raise SynapseError(400, "Group not on this server") - - group = await self.store.get_group(group_id) - if and_exists and not group: - raise SynapseError(404, "Unknown group") - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - if group and not is_user_in_group and not group["is_public"]: - raise SynapseError(404, "Unknown group") - - if and_is_admin: - is_admin = await self.store.is_user_admin_in_group(group_id, and_is_admin) - if not is_admin: - raise SynapseError(403, "User is not admin in group") - - return group - - async def get_group_summary( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the summary for a group as seen by requester_user_id. - - The group summary consists of the profile of the room, and a curated - list of users and rooms. These list *may* be organised by role/category. - The roles/categories are ordered, and so are the users/rooms within them. - - A user/room may appear in multiple roles/categories. 
- """ - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - - profile = await self.get_group_profile(group_id, requester_user_id) - - users, roles = await self.store.get_users_for_summary_by_role( - group_id, include_private=is_user_in_group - ) - - # TODO: Add profiles to users - - rooms, categories = await self.store.get_rooms_for_summary_by_category( - group_id, include_private=is_user_in_group - ) - - for room_entry in rooms: - room_id = room_entry["room_id"] - joined_users = await self.store.get_users_in_room(room_id) - entry = await self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True - ) - if entry is None: - continue - entry = dict(entry) # so we don't change what's cached - entry.pop("room_id", None) - - room_entry["profile"] = entry - - rooms.sort(key=lambda e: e.get("order", 0)) - - for user in users: - user_id = user["user_id"] - - if not self.is_mine_id(requester_user_id): - attestation = await self.store.get_remote_attestation(group_id, user_id) - if not attestation: - continue - - user["attestation"] = attestation - else: - user["attestation"] = self.attestations.create_attestation( - group_id, user_id - ) - - user_profile = await self.profile_handler.get_profile_from_cache(user_id) - user.update(user_profile) - - users.sort(key=lambda e: e.get("order", 0)) - - membership_info = await self.store.get_users_membership_info_in_group( - group_id, requester_user_id - ) - - return { - "profile": profile, - "users_section": { - "users": users, - "roles": roles, - "total_user_count_estimate": 0, # TODO - }, - "rooms_section": { - "rooms": rooms, - "categories": categories, - "total_room_count_estimate": 0, # TODO - }, - "user": membership_info, - } - - async def get_group_categories( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get all categories in a group (as seen by user)""" - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - categories = await self.store.get_group_categories(group_id=group_id) - return {"categories": categories} - - async def get_group_category( - self, group_id: str, requester_user_id: str, category_id: str - ) -> JsonDict: - """Get a specific category in a group (as seen by user)""" - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - return await self.store.get_group_category( - group_id=group_id, category_id=category_id - ) - - async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict: - """Get all roles in a group (as seen by user)""" - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - roles = await self.store.get_group_roles(group_id=group_id) - return {"roles": roles} - - async def get_group_role( - self, group_id: str, requester_user_id: str, role_id: str - ) -> JsonDict: - """Get a specific role in a group (as seen by user)""" - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - return await self.store.get_group_role(group_id=group_id, role_id=role_id) - - async def get_group_profile( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the group profile as seen by requester_user_id""" - - await self.check_group_is_ours(group_id, requester_user_id) - - group = await self.store.get_group(group_id) - - if group: - cols = [ - "name", - "short_description", - "long_description", - "avatar_url", - 
"is_public", - ] - group_description = {key: group[key] for key in cols} - group_description["is_openly_joinable"] = group["join_policy"] == "open" - - return group_description - else: - raise SynapseError(404, "Unknown group") - - async def get_users_in_group( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the users in group as seen by requester_user_id. - - The ordering is arbitrary at the moment - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - - user_results = await self.store.get_users_in_group( - group_id, include_private=is_user_in_group - ) - - chunk = [] - for user_result in user_results: - g_user_id = user_result["user_id"] - is_public = user_result["is_public"] - is_privileged = user_result["is_admin"] - - entry = {"user_id": g_user_id} - - profile = await self.profile_handler.get_profile_from_cache(g_user_id) - entry.update(profile) - - entry["is_public"] = bool(is_public) - entry["is_privileged"] = bool(is_privileged) - - if not self.is_mine_id(g_user_id): - attestation = await self.store.get_remote_attestation( - group_id, g_user_id - ) - if not attestation: - continue - - entry["attestation"] = attestation - else: - entry["attestation"] = self.attestations.create_attestation( - group_id, g_user_id - ) - - chunk.append(entry) - - # TODO: If admin add lists of users whose attestations have timed out - - return {"chunk": chunk, "total_user_count_estimate": len(user_results)} - - async def get_invited_users_in_group( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the users that have been invited to a group as seen by requester_user_id. - - The ordering is arbitrary at the moment - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - - if not is_user_in_group: - raise SynapseError(403, "User not in group") - - invited_users = await self.store.get_invited_users_in_group(group_id) - - user_profiles = [] - - for user_id in invited_users: - user_profile = {"user_id": user_id} - try: - profile = await self.profile_handler.get_profile_from_cache(user_id) - user_profile.update(profile) - except Exception as e: - logger.warning("Error getting profile for %s: %s", user_id, e) - user_profiles.append(user_profile) - - return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} - - async def get_rooms_in_group( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the rooms in group as seen by requester_user_id - - This returns rooms in order of decreasing number of joined users - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - - # Note! room_results["is_public"] is about whether the room is considered - # public from the group's point of view. (i.e. whether non-group members - # should be able to see the room is in the group). - # This is not the same as whether the room itself is public (in the sense - # of being visible in the room directory). - # As such, room_results["is_public"] itself is not sufficient to determine - # whether any given user is permitted to see the room's metadata. 
- room_results = await self.store.get_rooms_in_group( - group_id, include_private=is_user_in_group - ) - - chunk = [] - for room_result in room_results: - room_id = room_result["room_id"] - - joined_users = await self.store.get_users_in_room(room_id) - - # check the user is actually allowed to see the room before showing it to them - allow_private = requester_user_id in joined_users - - entry = await self.room_list_handler.generate_room_entry( - room_id, - len(joined_users), - with_alias=False, - allow_private=allow_private, - ) - - if not entry: - continue - - entry["is_public"] = bool(room_result["is_public"]) - - chunk.append(entry) - - chunk.sort(key=lambda e: -e["num_joined_members"]) - - return {"chunk": chunk, "total_room_count_estimate": len(chunk)} - - -class GroupsServerHandler(GroupsServerWorkerHandler): - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - # Ensure attestations get renewed - hs.get_groups_attestation_renewer() - - async def update_group_summary_room( - self, - group_id: str, - requester_user_id: str, - room_id: str, - category_id: str, - content: JsonDict, - ) -> JsonDict: - """Add/update a room to the group summary""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - RoomID.from_string(room_id) # Ensure valid room id - - order = content.get("order", None) - - is_public = _parse_visibility_from_contents(content) - - await self.store.add_room_to_summary( - group_id=group_id, - room_id=room_id, - category_id=category_id, - order=order, - is_public=is_public, - ) - - return {} - - async def delete_group_summary_room( - self, group_id: str, requester_user_id: str, room_id: str, category_id: str - ) -> JsonDict: - """Remove a room from the summary""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_room_from_summary( - group_id=group_id, room_id=room_id, category_id=category_id - ) - - return {} - - async def set_group_join_policy( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Sets the group join policy. - - Currently supported policies are: - - "invite": an invite must be received and accepted in order to join. - - "open": anyone can join. 
- """ - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - join_policy = _parse_join_policy_from_contents(content) - if join_policy is None: - raise SynapseError(400, "No value specified for 'm.join_policy'") - - await self.store.set_group_join_policy(group_id, join_policy=join_policy) - - return {} - - async def update_group_category( - self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict - ) -> JsonDict: - """Add/Update a group category""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - is_public = _parse_visibility_from_contents(content) - profile = content.get("profile") - - await self.store.upsert_group_category( - group_id=group_id, - category_id=category_id, - is_public=is_public, - profile=profile, - ) - - return {} - - async def delete_group_category( - self, group_id: str, requester_user_id: str, category_id: str - ) -> JsonDict: - """Delete a group category""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_group_category( - group_id=group_id, category_id=category_id - ) - - return {} - - async def update_group_role( - self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict - ) -> JsonDict: - """Add/update a role in a group""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - is_public = _parse_visibility_from_contents(content) - - profile = content.get("profile") - - await self.store.upsert_group_role( - group_id=group_id, role_id=role_id, is_public=is_public, profile=profile - ) - - return {} - - async def delete_group_role( - self, group_id: str, requester_user_id: str, role_id: str - ) -> JsonDict: - """Remove role from group""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_group_role(group_id=group_id, role_id=role_id) - - return {} - - async def update_group_summary_user( - self, - group_id: str, - requester_user_id: str, - user_id: str, - role_id: str, - content: JsonDict, - ) -> JsonDict: - """Add/update a users entry in the group summary""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - order = content.get("order", None) - - is_public = _parse_visibility_from_contents(content) - - await self.store.add_user_to_summary( - group_id=group_id, - user_id=user_id, - role_id=role_id, - order=order, - is_public=is_public, - ) - - return {} - - async def delete_group_summary_user( - self, group_id: str, requester_user_id: str, user_id: str, role_id: str - ) -> JsonDict: - """Remove a user from the group summary""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_user_from_summary( - group_id=group_id, user_id=user_id, role_id=role_id - ) - - return {} - - async def update_group_profile( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> None: - """Update the group profile""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - profile = {} - for keyname, max_length in ( - ("name", MAX_DISPLAYNAME_LEN), - ("avatar_url", MAX_AVATAR_URL_LEN), - ("short_description", 
MAX_SHORT_DESC_LEN), - ("long_description", MAX_LONG_DESC_LEN), - ): - if keyname in content: - value = content[keyname] - if not isinstance(value, str): - raise SynapseError( - 400, - "%r value is not a string" % (keyname,), - errcode=Codes.INVALID_PARAM, - ) - if len(value) > max_length: - raise SynapseError( - 400, - "Invalid %s parameter" % (keyname,), - errcode=Codes.INVALID_PARAM, - ) - profile[keyname] = value - - await self.store.update_group_profile(group_id, profile) - - async def add_room_to_group( - self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict - ) -> JsonDict: - """Add room to group""" - RoomID.from_string(room_id) # Ensure valid room id - - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - is_public = _parse_visibility_from_contents(content) - - await self.store.add_room_to_group(group_id, room_id, is_public=is_public) - - return {} - - async def update_room_in_group( - self, - group_id: str, - requester_user_id: str, - room_id: str, - config_key: str, - content: JsonDict, - ) -> JsonDict: - """Update room in group""" - RoomID.from_string(room_id) # Ensure valid room id - - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - if config_key == "m.visibility": - is_public = _parse_visibility_dict(content) - - await self.store.update_room_in_group_visibility( - group_id, room_id, is_public=is_public - ) - else: - raise SynapseError(400, "Unknown config option") - - return {} - - async def remove_room_from_group( - self, group_id: str, requester_user_id: str, room_id: str - ) -> JsonDict: - """Remove room from group""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_room_from_group(group_id, room_id) - - return {} - - async def invite_to_group( - self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Invite user to group""" - - group = await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - if not group: - raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE) - - # TODO: Check if user knocked - - invited_users = await self.store.get_invited_users_in_group(group_id) - if user_id in invited_users: - raise SynapseError( - 400, "User already invited to group", errcode=Codes.BAD_STATE - ) - - user_results = await self.store.get_users_in_group( - group_id, include_private=True - ) - if user_id in (user_result["user_id"] for user_result in user_results): - raise SynapseError(400, "User already in group") - - content = { - "profile": {"name": group["name"], "avatar_url": group["avatar_url"]}, - "inviter": requester_user_id, - } - - if self.hs.is_mine_id(user_id): - groups_local = self.hs.get_groups_local_handler() - assert isinstance( - groups_local, GroupsLocalHandler - ), "Workers cannot invites users to groups." 
- res = await groups_local.on_invite(group_id, user_id, content) - local_attestation = None - else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content.update({"attestation": local_attestation}) - - res = await self.transport_client.invite_to_group_notification( - get_domain_from_id(user_id), group_id, user_id, content - ) - - user_profile = res.get("user_profile", {}) - await self.store.add_remote_profile_cache( - user_id, - displayname=user_profile.get("displayname"), - avatar_url=user_profile.get("avatar_url"), - ) - - if res["state"] == "join": - if not self.hs.is_mine_id(user_id): - remote_attestation = res["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, user_id=user_id, group_id=group_id - ) - else: - remote_attestation = None - - await self.store.add_user_to_group( - group_id, - user_id, - is_admin=False, - is_public=False, # TODO - local_attestation=local_attestation, - remote_attestation=remote_attestation, - ) - return {"state": "join"} - elif res["state"] == "invite": - await self.store.add_group_invite(group_id, user_id) - return {"state": "invite"} - elif res["state"] == "reject": - return {"state": "reject"} - else: - raise SynapseError(502, "Unknown state returned by HS") - - async def _add_user( - self, group_id: str, user_id: str, content: JsonDict - ) -> Optional[JsonDict]: - """Add a user to a group based on a content dict. - - See accept_invite, join_group. - """ - if not self.hs.is_mine_id(user_id): - local_attestation: Optional[ - JsonDict - ] = self.attestations.create_attestation(group_id, user_id) - - remote_attestation = content["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, user_id=user_id, group_id=group_id - ) - else: - local_attestation = None - remote_attestation = None - - is_public = _parse_visibility_from_contents(content) - - await self.store.add_user_to_group( - group_id, - user_id, - is_admin=False, - is_public=is_public, - local_attestation=local_attestation, - remote_attestation=remote_attestation, - ) - - return local_attestation - - async def accept_invite( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """User tries to accept an invite to the group. - - This is different from them asking to join, and so should error if no - invite exists (and they're not a member of the group) - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_invited = await self.store.is_user_invited_to_local_group( - group_id, requester_user_id - ) - if not is_invited: - raise SynapseError(403, "User not invited to group") - - local_attestation = await self._add_user(group_id, requester_user_id, content) - - return {"state": "join", "attestation": local_attestation} - - async def join_group( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """User tries to join the group. 
- - This will error if the group requires an invite/knock to join - """ - - group_info = await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True - ) - if not group_info: - raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND) - if group_info["join_policy"] != "open": - raise SynapseError(403, "Group is not publicly joinable") - - local_attestation = await self._add_user(group_id, requester_user_id, content) - - return {"state": "join", "attestation": local_attestation} - - async def remove_user_from_group( - self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Remove a user from the group; either a user is leaving or an admin - kicked them. - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_kick = False - if requester_user_id != user_id: - is_admin = await self.store.is_user_admin_in_group( - group_id, requester_user_id - ) - if not is_admin: - raise SynapseError(403, "User is not admin in group") - - is_kick = True - - await self.store.remove_user_from_group(group_id, user_id) - - if is_kick: - if self.hs.is_mine_id(user_id): - groups_local = self.hs.get_groups_local_handler() - assert isinstance( - groups_local, GroupsLocalHandler - ), "Workers cannot remove users from groups." - await groups_local.user_removed_from_group(group_id, user_id, {}) - else: - await self.transport_client.remove_user_from_group_notification( - get_domain_from_id(user_id), group_id, user_id, {} - ) - - if not self.hs.is_mine_id(user_id): - await self.store.maybe_delete_remote_profile_cache(user_id) - - # Delete group if the last user has left - users = await self.store.get_users_in_group(group_id, include_private=True) - if not users: - await self.store.delete_group(group_id) - - return {} - - async def create_group( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - logger.info("Attempting to create group with ID: %r", group_id) - - # parsing the id into a GroupID validates it. 
- group_id_obj = GroupID.from_string(group_id) - - group = await self.check_group_is_ours(group_id, requester_user_id) - if group: - raise SynapseError(400, "Group already exists") - - is_admin = await self.auth.is_server_admin( - UserID.from_string(requester_user_id) - ) - if not is_admin: - if not self.hs.config.groups.enable_group_creation: - raise SynapseError( - 403, "Only a server admin can create groups on this server" - ) - localpart = group_id_obj.localpart - if not localpart.startswith(self.hs.config.groups.group_creation_prefix): - raise SynapseError( - 400, - "Can only create groups with prefix %r on this server" - % (self.hs.config.groups.group_creation_prefix,), - ) - - profile = content.get("profile", {}) - name = profile.get("name") - avatar_url = profile.get("avatar_url") - short_description = profile.get("short_description") - long_description = profile.get("long_description") - user_profile = content.get("user_profile", {}) - - await self.store.create_group( - group_id, - requester_user_id, - name=name, - avatar_url=avatar_url, - short_description=short_description, - long_description=long_description, - ) - - if not self.hs.is_mine_id(requester_user_id): - remote_attestation = content["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, user_id=requester_user_id, group_id=group_id - ) - - local_attestation: Optional[ - JsonDict - ] = self.attestations.create_attestation(group_id, requester_user_id) - else: - local_attestation = None - remote_attestation = None - - await self.store.add_user_to_group( - group_id, - requester_user_id, - is_admin=True, - is_public=True, # TODO - local_attestation=local_attestation, - remote_attestation=remote_attestation, - ) - - if not self.hs.is_mine_id(requester_user_id): - await self.store.add_remote_profile_cache( - requester_user_id, - displayname=user_profile.get("displayname"), - avatar_url=user_profile.get("avatar_url"), - ) - - return {"group_id": group_id} - - async def delete_group(self, group_id: str, requester_user_id: str) -> None: - """Deletes a group, kicking out all current members. - - Only group admins or server admins can call this request - - Args: - group_id: The group ID to delete. - requester_user_id: The user requesting to delete the group. - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - # Only server admins or group admins can delete groups. - - is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) - - if not is_admin: - is_admin = await self.auth.is_server_admin( - UserID.from_string(requester_user_id) - ) - - if not is_admin: - raise SynapseError(403, "User is not an admin") - - # Before deleting the group lets kick everyone out of it - users = await self.store.get_users_in_group(group_id, include_private=True) - - async def _kick_user_from_group(user_id: str) -> None: - if self.hs.is_mine_id(user_id): - groups_local = self.hs.get_groups_local_handler() - assert isinstance( - groups_local, GroupsLocalHandler - ), "Workers cannot kick users from groups." - await groups_local.user_removed_from_group(group_id, user_id, {}) - else: - await self.transport_client.remove_user_from_group_notification( - get_domain_from_id(user_id), group_id, user_id, {} - ) - await self.store.maybe_delete_remote_profile_cache(user_id) - - # We kick users out in the order of: - # 1. Non-admins - # 2. Other admins - # 3. 
The requester - # - # This is so that if the deletion fails for some reason other admins or - # the requester still has auth to retry. - non_admins = [] - admins = [] - for u in users: - if u["user_id"] == requester_user_id: - continue - if u["is_admin"]: - admins.append(u["user_id"]) - else: - non_admins.append(u["user_id"]) - - await concurrently_execute(_kick_user_from_group, non_admins, 10) - await concurrently_execute(_kick_user_from_group, admins, 10) - await _kick_user_from_group(requester_user_id) - - await self.store.delete_group(group_id) - - -def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]: - """Given a content for a request, return the specified join policy or None""" - - join_policy_dict = content.get("m.join_policy") - if join_policy_dict: - return _parse_join_policy_dict(join_policy_dict) - else: - return None - - -def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str: - """Given a dict for the "m.join_policy" config return the join policy specified""" - join_policy_type = join_policy_dict.get("type") - if not join_policy_type: - return "invite" - - if join_policy_type not in ("invite", "open"): - raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule") - return join_policy_type - - -def _parse_visibility_from_contents(content: JsonDict) -> bool: - """Given a content for a request parse out whether the entity should be - public or not - """ - - visibility = content.get("m.visibility") - if visibility: - return _parse_visibility_dict(visibility) - else: - is_public = True - - return is_public - - -def _parse_visibility_dict(visibility: JsonDict) -> bool: - """Given a dict for the "m.visibility" config return if the entity should - be public or not - """ - vis_type = visibility.get("type") - if not vis_type: - return True - - if vis_type not in ("public", "private"): - raise SynapseError(400, "Synapse only supports 'public'/'private' visibility") - return vis_type == "public" diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py deleted file mode 100644 index e7a399787b..0000000000 --- a/synapse/handlers/groups_local.py +++ /dev/null @@ -1,503 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set - -from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import GroupID, JsonDict, get_domain_from_id - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]: - """Returns an async function that looks at the group id and calls the function - on federation or the local group server if the group is local - """ - - async def f( - self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any - ) -> JsonDict: - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) - - if self.is_mine_id(group_id): - return await getattr(self.groups_server_handler, func_name)( - group_id, *args, **kwargs - ) - else: - destination = get_domain_from_id(group_id) - - try: - return await getattr(self.transport_client, func_name)( - destination, group_id, *args, **kwargs - ) - except HttpResponseException as e: - # Capture errors returned by the remote homeserver and - # re-throw specific errors as SynapseErrors. This is so - # when the remote end responds with things like 403 Not - # In Group, we can communicate that to the client instead - # of a 500. - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - return f - - -class GroupsLocalWorkerHandler: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.store = hs.get_datastores().main - self.room_list_handler = hs.get_room_list_handler() - self.groups_server_handler = hs.get_groups_server_handler() - self.transport_client = hs.get_federation_transport_client() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.keyring = hs.get_keyring() - self.is_mine_id = hs.is_mine_id - self.signing_key = hs.signing_key - self.server_name = hs.hostname - self.notifier = hs.get_notifier() - self.attestations = hs.get_groups_attestation_signing() - - self.profile_handler = hs.get_profile_handler() - - # The following functions merely route the query to the local groups server - # or federation depending on if the group is local or remote - - get_group_profile = _create_rerouter("get_group_profile") - get_rooms_in_group = _create_rerouter("get_rooms_in_group") - get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") - get_group_category = _create_rerouter("get_group_category") - get_group_categories = _create_rerouter("get_group_categories") - get_group_role = _create_rerouter("get_group_role") - get_group_roles = _create_rerouter("get_group_roles") - - async def get_group_summary( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the group summary for a group. - - If the group is remote we check that the users have valid attestations. - """ - if self.is_mine_id(group_id): - res = await self.groups_server_handler.get_group_summary( - group_id, requester_user_id - ) - else: - try: - res = await self.transport_client.get_group_summary( - get_domain_from_id(group_id), group_id, requester_user_id - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - group_server_name = get_domain_from_id(group_id) - - # Loop through the users and validate the attestations. 
- chunk = res["users_section"]["users"] - valid_users = [] - for entry in chunk: - g_user_id = entry["user_id"] - attestation = entry.pop("attestation", {}) - try: - if get_domain_from_id(g_user_id) != group_server_name: - await self.attestations.verify_attestation( - attestation, - group_id=group_id, - user_id=g_user_id, - server_name=get_domain_from_id(g_user_id), - ) - valid_users.append(entry) - except Exception as e: - logger.info("Failed to verify user is in group: %s", e) - - res["users_section"]["users"] = valid_users - - res["users_section"]["users"].sort(key=lambda e: e.get("order", 0)) - res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0)) - - # Add `is_publicised` flag to indicate whether the user has publicised their - # membership of the group on their profile - result = await self.store.get_publicised_groups_for_user(requester_user_id) - is_publicised = group_id in result - - res.setdefault("user", {})["is_publicised"] = is_publicised - - return res - - async def get_users_in_group( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get users in a group""" - if self.is_mine_id(group_id): - return await self.groups_server_handler.get_users_in_group( - group_id, requester_user_id - ) - - group_server_name = get_domain_from_id(group_id) - - try: - res = await self.transport_client.get_users_in_group( - get_domain_from_id(group_id), group_id, requester_user_id - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - chunk = res["chunk"] - valid_entries = [] - for entry in chunk: - g_user_id = entry["user_id"] - attestation = entry.pop("attestation", {}) - try: - if get_domain_from_id(g_user_id) != group_server_name: - await self.attestations.verify_attestation( - attestation, - group_id=group_id, - user_id=g_user_id, - server_name=get_domain_from_id(g_user_id), - ) - valid_entries.append(entry) - except Exception as e: - logger.info("Failed to verify user is in group: %s", e) - - res["chunk"] = valid_entries - - return res - - async def get_joined_groups(self, user_id: str) -> JsonDict: - group_ids = await self.store.get_joined_groups(user_id) - return {"groups": group_ids} - - async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: - if self.hs.is_mine_id(user_id): - result = await self.store.get_publicised_groups_for_user(user_id) - - # Check AS associated groups for this user - this depends on the - # RegExps in the AS registration file (under `users`) - for app_service in self.store.get_app_services(): - result.extend(app_service.get_groups_for_user(user_id)) - - return {"groups": result} - else: - try: - bulk_result = await self.transport_client.bulk_get_publicised_groups( - get_domain_from_id(user_id), [user_id] - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - result = bulk_result.get("users", {}).get(user_id) - # TODO: Verify attestations - return {"groups": result} - - async def bulk_get_publicised_groups( - self, user_ids: Iterable[str], proxy: bool = True - ) -> JsonDict: - destinations: Dict[str, Set[str]] = {} - local_users = set() - - for user_id in user_ids: - if self.hs.is_mine_id(user_id): - local_users.add(user_id) - else: - destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id) - - if not proxy and destinations: - raise SynapseError(400, "Some user_ids are not local") - - results 
= {} - failed_results: List[str] = [] - for destination, dest_user_ids in destinations.items(): - try: - r = await self.transport_client.bulk_get_publicised_groups( - destination, list(dest_user_ids) - ) - results.update(r["users"]) - except Exception: - failed_results.extend(dest_user_ids) - - for uid in local_users: - results[uid] = await self.store.get_publicised_groups_for_user(uid) - - # Check AS associated groups for this user - this depends on the - # RegExps in the AS registration file (under `users`) - for app_service in self.store.get_app_services(): - results[uid].extend(app_service.get_groups_for_user(uid)) - - return {"users": results} - - -class GroupsLocalHandler(GroupsLocalWorkerHandler): - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - # Ensure attestations get renewed - hs.get_groups_attestation_renewer() - - # The following functions merely route the query to the local groups server - # or federation depending on if the group is local or remote - - update_group_profile = _create_rerouter("update_group_profile") - - add_room_to_group = _create_rerouter("add_room_to_group") - update_room_in_group = _create_rerouter("update_room_in_group") - remove_room_from_group = _create_rerouter("remove_room_from_group") - - update_group_summary_room = _create_rerouter("update_group_summary_room") - delete_group_summary_room = _create_rerouter("delete_group_summary_room") - - update_group_category = _create_rerouter("update_group_category") - delete_group_category = _create_rerouter("delete_group_category") - - update_group_summary_user = _create_rerouter("update_group_summary_user") - delete_group_summary_user = _create_rerouter("delete_group_summary_user") - - update_group_role = _create_rerouter("update_group_role") - delete_group_role = _create_rerouter("delete_group_role") - - set_group_join_policy = _create_rerouter("set_group_join_policy") - - async def create_group( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Create a group""" - - logger.info("Asking to create group with ID: %r", group_id) - - if self.is_mine_id(group_id): - res = await self.groups_server_handler.create_group( - group_id, user_id, content - ) - local_attestation = None - remote_attestation = None - else: - raise SynapseError(400, "Unable to create remote groups") - - is_publicised = content.get("publicise", False) - token = await self.store.register_user_group_membership( - group_id, - user_id, - membership="join", - is_admin=True, - local_attestation=local_attestation, - remote_attestation=remote_attestation, - is_publicised=is_publicised, - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - - return res - - async def join_group( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Request to join a group""" - if self.is_mine_id(group_id): - await self.groups_server_handler.join_group(group_id, user_id, content) - local_attestation = None - remote_attestation = None - else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - try: - res = await self.transport_client.join_group( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - 
server_name=get_domain_from_id(group_id), - ) - - # TODO: Check that the group is public and we're being added publicly - is_publicised = content.get("publicise", False) - - token = await self.store.register_user_group_membership( - group_id, - user_id, - membership="join", - is_admin=False, - local_attestation=local_attestation, - remote_attestation=remote_attestation, - is_publicised=is_publicised, - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - - return {} - - async def accept_invite( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Accept an invite to a group""" - if self.is_mine_id(group_id): - await self.groups_server_handler.accept_invite(group_id, user_id, content) - local_attestation = None - remote_attestation = None - else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - try: - res = await self.transport_client.accept_group_invite( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - server_name=get_domain_from_id(group_id), - ) - - # TODO: Check that the group is public and we're being added publicly - is_publicised = content.get("publicise", False) - - token = await self.store.register_user_group_membership( - group_id, - user_id, - membership="join", - is_admin=False, - local_attestation=local_attestation, - remote_attestation=remote_attestation, - is_publicised=is_publicised, - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - - return {} - - async def invite( - self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict - ) -> JsonDict: - """Invite a user to a group""" - content = {"requester_user_id": requester_user_id, "config": config} - if self.is_mine_id(group_id): - res = await self.groups_server_handler.invite_to_group( - group_id, user_id, requester_user_id, content - ) - else: - try: - res = await self.transport_client.invite_to_group( - get_domain_from_id(group_id), - group_id, - user_id, - requester_user_id, - content, - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - return res - - async def on_invite( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """One of our users were invited to a group""" - # TODO: Support auto join and rejection - - if not self.is_mine_id(user_id): - raise SynapseError(400, "User not on this server") - - local_profile = {} - if "profile" in content: - if "name" in content["profile"]: - local_profile["name"] = content["profile"]["name"] - if "avatar_url" in content["profile"]: - local_profile["avatar_url"] = content["profile"]["avatar_url"] - - token = await self.store.register_user_group_membership( - group_id, - user_id, - membership="invite", - content={"profile": local_profile, "inviter": content["inviter"]}, - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - try: - user_profile = await self.profile_handler.get_profile(user_id) - except Exception as e: - logger.warning("No profile for user %s: %s", user_id, e) - user_profile = {} - - return {"state": "invite", "user_profile": user_profile} - - 
async def remove_user_from_group( - self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Remove a user from a group""" - if user_id == requester_user_id: - token = await self.store.register_user_group_membership( - group_id, user_id, membership="leave" - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - - # TODO: Should probably remember that we tried to leave so that we can - # retry if the group server is currently down. - - if self.is_mine_id(group_id): - res = await self.groups_server_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content - ) - else: - content["requester_user_id"] = requester_user_id - try: - res = await self.transport_client.remove_user_from_group( - get_domain_from_id(group_id), - group_id, - requester_user_id, - user_id, - content, - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - return res - - async def user_removed_from_group( - self, group_id: str, user_id: str, content: JsonDict - ) -> None: - """One of our users was removed/kicked from a group""" - # TODO: Check if user in group - token = await self.store.register_user_group_membership( - group_id, user_id, membership="leave" - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) diff --git a/synapse/server.py b/synapse/server.py index ee60cce8eb..3fd23aaf52 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -21,17 +21,7 @@ import abc import functools import logging -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.tcp import Port @@ -60,8 +50,6 @@ from synapse.federation.federation_server import ( from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.federation.transport.client import TransportLayerClient -from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer -from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler from synapse.handlers.account import AccountHandler from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler @@ -79,7 +67,6 @@ from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler from synapse.handlers.federation_event import FederationEventHandler -from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler from synapse.handlers.identity import IdentityHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler @@ -651,30 +638,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_user_directory_handler(self) -> UserDirectoryHandler: return UserDirectoryHandler(self) - @cache_in_self - def get_groups_local_handler( - self, - ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]: - if self.config.worker.worker_app: - return GroupsLocalWorkerHandler(self) - else: - return GroupsLocalHandler(self) - - @cache_in_self - def get_groups_server_handler(self): - if 
self.config.worker.worker_app: - return GroupsServerWorkerHandler(self) - else: - return GroupsServerHandler(self) - - @cache_in_self - def get_groups_attestation_signing(self) -> GroupAttestationSigning: - return GroupAttestationSigning(self) - - @cache_in_self - def get_groups_attestation_renewer(self) -> GroupAttestionRenewer: - return GroupAttestionRenewer(self) - @cache_in_self def get_stats_handler(self) -> StatsHandler: return StatsHandler(self) diff --git a/synapse/types.py b/synapse/types.py index 6f7128ddd6..091cc611ab 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -320,29 +320,6 @@ class EventID(DomainSpecificString): SIGIL = "$" -@attr.s(slots=True, frozen=True, repr=False) -class GroupID(DomainSpecificString): - """Structure representing a group ID.""" - - SIGIL = "+" - - @classmethod - def from_string(cls: Type[DS], s: str) -> DS: - group_id: DS = super().from_string(s) # type: ignore - - if not group_id.localpart: - raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) - - if contains_invalid_mxid_characters(group_id.localpart): - raise SynapseError( - 400, - "Group ID can only contain characters a-z, 0-9, or '=_-./'", - Codes.INVALID_PARAM, - ) - - return group_id - - mxid_localpart_allowed_characters = set( "_-./=" + string.ascii_lowercase + string.digits ) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index edc584d0cf..7135362f76 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -23,7 +23,7 @@ from tests.test_utils import simple_async_mock def _regex(regex: str, exclusive: bool = True) -> Namespace: - return Namespace(exclusive, None, re.compile(regex)) + return Namespace(exclusive, re.compile(regex)) class ApplicationServiceTestCase(unittest.TestCase): diff --git a/tests/test_types.py b/tests/test_types.py index 80888a744d..0b10dae848 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.errors import SynapseError -from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart +from synapse.types import RoomAlias, UserID, map_username_to_mxid_localpart from tests import unittest @@ -62,25 +62,6 @@ class RoomAliasTestCase(unittest.HomeserverTestCase): self.assertFalse(RoomAlias.is_valid(id_string)) -class GroupIDTestCase(unittest.TestCase): - def test_parse(self): - group_id = GroupID.from_string("+group/=_-.123:my.domain") - self.assertEqual("group/=_-.123", group_id.localpart) - self.assertEqual("my.domain", group_id.domain) - - def test_validate(self): - bad_ids = ["$badsigil:domain", "+:empty"] + [ - "+group" + c + ":domain" for c in "A%?æ£" - ] - for id_string in bad_ids: - try: - GroupID.from_string(id_string) - self.fail("Parsing '%s' should raise exception" % id_string) - except SynapseError as exc: - self.assertEqual(400, exc.code) - self.assertEqual("M_INVALID_PARAM", exc.errcode) - - class MapUsernameTestCase(unittest.TestCase): def testPassThrough(self): self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234") From bc1beebc27befccdbec5d199c61228930ded8143 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 27 May 2022 11:27:33 +0200 Subject: [PATCH 24/74] Refactor have_seen_events to reduce OOMs (#12886) My server is currently OOMing in the middle of have_seen_events, so let's try to fix that. 
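The shape of the fix, as a minimal self-contained sketch (assumptions: `lookup` stands in for the cached `_have_seen_events_dict` call, and this `batch_iter` mirrors the `batch_iter` helper the diff below uses):

    from itertools import islice
    from typing import Callable, Dict, Iterable, Iterator, Set, Tuple

    def batch_iter(iterable: Iterable[str], size: int) -> Iterator[Tuple[str, ...]]:
        # Yield fixed-size tuples so no single call materialises the
        # whole input at once.
        it = iter(iterable)
        while batch := tuple(islice(it, size)):
            yield batch

    def have_seen_events(
        event_ids: Iterable[str],
        lookup: Callable[[Tuple[str, ...]], Dict[str, bool]],
    ) -> Set[str]:
        # Each batch costs one index scan, so batches are kept as large
        # as is reasonable (500 in the patch); peak memory now scales
        # with the batch size rather than with len(event_ids).
        results: Set[str] = set()
        for chunk in batch_iter(event_ids, 500):
            results.update(eid for eid, seen in lookup(chunk).items() if seen)
        return results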
--- changelog.d/12886.misc | 1 + .../storage/databases/main/events_worker.py | 42 +++++++++++-------- 2 files changed, 25 insertions(+), 18 deletions(-) create mode 100644 changelog.d/12886.misc diff --git a/changelog.d/12886.misc b/changelog.d/12886.misc new file mode 100644 index 0000000000..3dd08f74ba --- /dev/null +++ b/changelog.d/12886.misc @@ -0,0 +1 @@ +Refactor `have_seen_events` to reduce memory consumed when processing federation traffic. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5b22d6b452..a97d7e1664 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1356,14 +1356,23 @@ class EventsWorkerStore(SQLBaseStore): Returns: The set of events we have already seen. """ - res = await self._have_seen_events_dict( - (room_id, event_id) for event_id in event_ids - ) - return {eid for ((_rid, eid), have_event) in res.items() if have_event} + + # @cachedList chomps lots of memory if you call it with a big list, so + # we break it down. However, each batch requires its own index scan, so we make + # the batches as big as possible. + + results: Set[str] = set() + for chunk in batch_iter(event_ids, 500): + r = await self._have_seen_events_dict( + [(room_id, event_id) for event_id in chunk] + ) + results.update(eid for ((_rid, eid), have_event) in r.items() if have_event) + + return results @cachedList(cached_method_name="have_seen_event", list_name="keys") async def _have_seen_events_dict( - self, keys: Iterable[Tuple[str, str]] + self, keys: Collection[Tuple[str, str]] ) -> Dict[Tuple[str, str], bool]: """Helper for have_seen_events @@ -1375,11 +1384,12 @@ class EventsWorkerStore(SQLBaseStore): cache_results = { (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) } - results = {x: True for x in cache_results} + results = dict.fromkeys(cache_results, True) + remaining = [k for k in keys if k not in cache_results] + if not remaining: + return results - def have_seen_events_txn( - txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] - ) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> None: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1387,21 +1397,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk] + txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining] ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} - # ... and then we can update the results for each row in the batch - results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk}) - - # each batch requires its own index scan, so we make the batches as big as - # possible. - for chunk in batch_iter((k for k in keys if k not in cache_results), 500): - await self.db_pool.runInteraction( - "have_seen_events", have_seen_events_txn, chunk + # ... 
and then we can update the results for each key + results.update( + {(rid, eid): (eid in found_events) for (rid, eid) in remaining} ) + await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) return results @cached(max_entries=100000, tree=True) From f1605b7447196d2b13cfbaf70483aa9f2f1a34b4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 27 May 2022 11:31:08 +0200 Subject: [PATCH 25/74] Fix room deletion (#12889) * Fix room deletion ae7858f broke room deletion by attempting to delete the entry from `rooms` before the tables that reference it. * faster_joins: remove database rows on purge --- changelog.d/12889.bugfix | 1 + .../storage/databases/main/purge_events.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) create mode 100644 changelog.d/12889.bugfix diff --git a/changelog.d/12889.bugfix b/changelog.d/12889.bugfix new file mode 100644 index 0000000000..582b2f0642 --- /dev/null +++ b/changelog.d/12889.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.59.0 which caused room deletion to fail with a foreign key violation. diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index c94d5f9f81..2353c120e9 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -322,12 +322,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: - # We *immediately* delete the room from the rooms table. This ensures - # that we don't race when persisting events (as that transaction checks - # that the room exists). - txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,)) - - # Next, we fetch all the state groups that should be deleted, before + # First, fetch all the state groups that should be deleted, before # we delete that information. txn.execute( """ @@ -387,7 +382,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): (room_id,), ) - # and finally, the tables with an index on room_id (or no useful index) + # next, the tables with an index on room_id (or no useful index) for table in ( "current_state_events", "destination_rooms", @@ -395,8 +390,13 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_forward_extremities", "event_push_actions", "event_search", + "partial_state_events", "events", + "federation_inbound_events_staging", "group_rooms", + "local_current_membership", + "partial_state_rooms_servers", + "partial_state_rooms", "receipts_graph", "receipts_linearized", "room_aliases", @@ -416,8 +416,9 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "group_summary_rooms", "room_account_data", "room_tags", - "local_current_membership", - "federation_inbound_events_staging", + # "rooms" happens last, to keep the foreign keys in the other tables + # happy + "rooms", ): logger.info("[purge] removing %s from %s", room_id, table) txn.execute("DELETE FROM %s WHERE room_id=?" 
% (table,), (room_id,)) From 3503f42741d438b67490bc774ee2c3a856f6bc81 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 27 May 2022 11:17:33 +0100 Subject: [PATCH 26/74] Easy type hints in synapse.logging.opentracing (#12894) --- changelog.d/12894.misc | 1 + synapse/config/tracer.py | 6 +- synapse/logging/opentracing.py | 114 ++++++++++-------- synapse/metrics/background_process_metrics.py | 9 +- 4 files changed, 73 insertions(+), 57 deletions(-) create mode 100644 changelog.d/12894.misc diff --git a/changelog.d/12894.misc b/changelog.d/12894.misc new file mode 100644 index 0000000000..646a62fccb --- /dev/null +++ b/changelog.d/12894.misc @@ -0,0 +1 @@ +Add type annotations to `synapse.logging.opentracing`. diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 3472a9a01b..ae68a3dd1a 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Set +from typing import Any, List, Set from synapse.types import JsonDict from synapse.util.check_dependencies import DependencyException, check_requirements @@ -49,7 +49,9 @@ class TracerConfig(Config): # The tracer is enabled so sanitize the config - self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", []) + self.opentracer_whitelist: List[str] = opentracing_config.get( + "homeserver_whitelist", [] + ) if not isinstance(self.opentracer_whitelist, list): raise ConfigError("Tracer homeserver_whitelist config is malformed") diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index a02b5bf6bd..903ec40c86 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -168,9 +168,24 @@ import inspect import logging import re from functools import wraps -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + Generator, + Iterable, + List, + Optional, + Pattern, + Type, + TypeVar, + Union, +) import attr +from typing_extensions import ParamSpec from twisted.internet import defer from twisted.web.http import Request @@ -256,7 +271,7 @@ try: def set_process(self, *args, **kwargs): return self._reporter.set_process(*args, **kwargs) - def report_span(self, span): + def report_span(self, span: "opentracing.Span") -> None: try: return self._reporter.report_span(span) except Exception: @@ -307,15 +322,19 @@ _homeserver_whitelist: Optional[Pattern[str]] = None Sentinel = object() -def only_if_tracing(func): +P = ParamSpec("P") +R = TypeVar("R") + + +def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]: """Executes the function only if we're tracing. 
Otherwise returns None.""" @wraps(func) - def _only_if_tracing_inner(*args, **kwargs): + def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: if opentracing: return func(*args, **kwargs) else: - return + return None return _only_if_tracing_inner @@ -356,17 +375,10 @@ def ensure_active_span(message, ret=None): return ensure_active_span_inner_1 -@contextlib.contextmanager -def noop_context_manager(*args, **kwargs): - """Does exactly what it says on the tin""" - # TODO: replace with contextlib.nullcontext once we drop support for Python 3.6 - yield - - # Setup -def init_tracer(hs: "HomeServer"): +def init_tracer(hs: "HomeServer") -> None: """Set the whitelists and initialise the JaegerClient tracer""" global opentracing if not hs.config.tracing.opentracer_enabled: @@ -408,11 +420,11 @@ def init_tracer(hs: "HomeServer"): @only_if_tracing -def set_homeserver_whitelist(homeserver_whitelist): +def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None: """Sets the homeserver whitelist Args: - homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers + homeserver_whitelist: regexes specifying whitelisted homeservers """ global _homeserver_whitelist if homeserver_whitelist: @@ -423,15 +435,15 @@ def set_homeserver_whitelist(homeserver_whitelist): @only_if_tracing -def whitelisted_homeserver(destination): +def whitelisted_homeserver(destination: str) -> bool: """Checks if a destination matches the whitelist Args: - destination (str) + destination """ if _homeserver_whitelist: - return _homeserver_whitelist.match(destination) + return _homeserver_whitelist.match(destination) is not None return False @@ -457,11 +469,11 @@ def start_active_span( Args: See opentracing.tracer Returns: - scope (Scope) or noop_context_manager + scope (Scope) or contextlib.nullcontext """ if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] if tracer is None: # use the global tracer by default @@ -505,7 +517,7 @@ def start_active_span_follows_from( tracer: override the opentracing tracer. By default the global tracer is used. """ if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] references = [opentracing.follows_from(context) for context in contexts] scope = start_active_span( @@ -525,19 +537,19 @@ def start_active_span_follows_from( def start_active_span_from_edu( - edu_content, - operation_name, - references: Optional[list] = None, - tags=None, - start_time=None, - ignore_active_span=False, - finish_on_close=True, -): + edu_content: Dict[str, Any], + operation_name: str, + references: Optional[List["opentracing.Reference"]] = None, + tags: Optional[Dict] = None, + start_time: Optional[float] = None, + ignore_active_span: bool = False, + finish_on_close: bool = True, +) -> "opentracing.Scope": """ Extracts a span context from an edu and uses it to start a new active span Args: - edu_content (dict): and edu_content with a `context` field whose value is + edu_content: an edu_content with a `context` field whose value is canonical json for a dict which contains opentracing information. 
For the other args see opentracing.tracer @@ -545,7 +557,7 @@ def start_active_span_from_edu( references = references or [] if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] carrier = json_decoder.decode(edu_content.get("context", "{}")).get( "opentracing", {} @@ -578,27 +590,27 @@ def start_active_span_from_edu( # Opentracing setters for tags, logs, etc @only_if_tracing -def active_span(): +def active_span() -> Optional["opentracing.Span"]: """Get the currently active span, if any""" return opentracing.tracer.active_span @ensure_active_span("set a tag") -def set_tag(key, value): +def set_tag(key: str, value: Union[str, bool, int, float]) -> None: """Sets a tag on the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_tag(key, value) @ensure_active_span("log") -def log_kv(key_values, timestamp=None): +def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None: """Log to the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.log_kv(key_values, timestamp) @ensure_active_span("set the traces operation name") -def set_operation_name(operation_name): +def set_operation_name(operation_name: str) -> None: """Sets the operation name of the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_operation_name(operation_name) @@ -624,7 +636,9 @@ def force_tracing(span=Sentinel) -> None: span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1") -def is_context_forced_tracing(span_context) -> bool: +def is_context_forced_tracing( + span_context: Optional["opentracing.SpanContext"], +) -> bool: """Check if sampling has been force for the given span context.""" if span_context is None: return False @@ -696,13 +710,13 @@ def inject_response_headers(response_headers: Headers) -> None: @ensure_active_span("get the active span context as a dict", ret={}) -def get_active_span_text_map(destination=None): +def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]: """ Gets a span context as a dict. This can be used instead of manually injecting a span into an empty carrier. Args: - destination (str): the name of the remote server. + destination: the name of the remote server. Returns: dict: the active span's context if opentracing is enabled, otherwise empty. @@ -721,7 +735,7 @@ def get_active_span_text_map(destination=None): @ensure_active_span("get the span context as a string.", ret={}) -def active_span_context_as_string(): +def active_span_context_as_string() -> str: """ Returns: The active span context encoded as a string. @@ -750,21 +764,21 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon @only_if_tracing -def span_context_from_string(carrier): +def span_context_from_string(carrier: str) -> Optional["opentracing.SpanContext"]: """ Returns: The active span context decoded from a string. """ - carrier = json_decoder.decode(carrier) - return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) + payload: Dict[str, str] = json_decoder.decode(carrier) + return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, payload) @only_if_tracing -def extract_text_map(carrier): +def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanContext"]: """ Wrapper method for opentracing's tracer.extract for TEXT_MAP. Args: - carrier (dict): a dict possibly containing a span context. 
+ carrier: a dict possibly containing a span context. Returns: The active span context extracted from carrier. @@ -843,7 +857,7 @@ def trace(func=None, opname=None): return decorator -def tag_args(func): +def tag_args(func: Callable[P, R]) -> Callable[P, R]: """ Tags all of the args to the active span. """ @@ -852,11 +866,11 @@ def tag_args(func): return func @wraps(func) - def _tag_args_inner(*args, **kwargs): + def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R: argspec = inspect.getfullargspec(func) for i, arg in enumerate(argspec.args[1:]): - set_tag("ARG_" + arg, args[i]) - set_tag("args", args[len(argspec.args) :]) + set_tag("ARG_" + arg, args[i]) # type: ignore[index] + set_tag("args", args[len(argspec.args) :]) # type: ignore[index] set_tag("kwargs", kwargs) return func(*args, **kwargs) @@ -864,7 +878,9 @@ def tag_args(func): @contextlib.contextmanager -def trace_servlet(request: "SynapseRequest", extract_context: bool = False): +def trace_servlet( + request: "SynapseRequest", extract_context: bool = False +) -> Generator[None, None, None]: """Returns a context manager which traces a request. It starts a span with some servlet specific tags such as the request metrics name and request information. diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 298809742a..eef3462e10 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -14,6 +14,7 @@ import logging import threading +from contextlib import nullcontext from functools import wraps from types import TracebackType from typing import ( @@ -41,11 +42,7 @@ from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, ) -from synapse.logging.opentracing import ( - SynapseTags, - noop_context_manager, - start_active_span, -) +from synapse.logging.opentracing import SynapseTags, start_active_span from synapse.metrics._types import Collector if TYPE_CHECKING: @@ -238,7 +235,7 @@ def run_as_background_process( f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)} ) else: - ctx = noop_context_manager() + ctx = nullcontext() with ctx: return await func(*args, **kwargs) except Exception: From a7da00d4f74b0c614971da0978a0f0d6c316fa8b Mon Sep 17 00:00:00 2001 From: Matt C <96466754+buffless-matt@users.noreply.github.com> Date: Fri, 27 May 2022 20:25:57 +1000 Subject: [PATCH 27/74] Add storage and module API methods to get monthly active users and their appservices (#12838) --- changelog.d/12838.feature | 1 + synapse/module_api/__init__.py | 20 +++++ .../databases/main/monthly_active_users.py | 45 ++++++++++ tests/storage/test_monthly_active_users.py | 83 +++++++++++++++++++ 4 files changed, 149 insertions(+) create mode 100644 changelog.d/12838.feature diff --git a/changelog.d/12838.feature b/changelog.d/12838.feature new file mode 100644 index 0000000000..b24489aaad --- /dev/null +++ b/changelog.d/12838.feature @@ -0,0 +1 @@ +Add storage and module API methods to get monthly active users (and their corresponding appservices) within an optionally specified time range. 
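A usage sketch for the module-API method added below (assumptions: `module_api` is the `ModuleApi` instance a Synapse module receives at construction time, and the cutoff timestamp is illustrative; the "native" sentinel for non-appservice users comes from the storage docstring):

    from typing import List, Tuple

    async def count_recent_native_mau(module_api, since_ms: int) -> int:
        # With start_timestamp given, only users whose first activity is
        # at or after `since_ms` are returned.
        rows: List[Tuple[str, str]] = (
            await module_api.get_monthly_active_users_by_service(
                start_timestamp=since_ms
            )
        )
        # Users that did not come from an appservice are reported under
        # the sentinel appservice ID "native".
        return sum(1 for appservice_id, _user_id in rows if appservice_id == "native")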
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 95f3b27927..edcf59aa0b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1429,6 +1429,26 @@ class ModuleApi: user_id, spec, {"actions": actions} ) + async def get_monthly_active_users_by_service( + self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None + ) -> List[Tuple[str, str]]: + """Generates list of monthly active users and their services. + Please see corresponding storage docstring for more details. + + Arguments: + start_timestamp: If specified, only include users that were first active + at or after this point + end_timestamp: If specified, only include users that were first active + at or before this point + + Returns: + A list of tuples (appservice_id, user_id) + + """ + return await self._store.get_monthly_active_users_by_service( + start_timestamp, end_timestamp + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 5beb8f1d4b..9a63f953fb 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -122,6 +122,51 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): "count_users_by_service", _count_users_by_service ) + async def get_monthly_active_users_by_service( + self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None + ) -> List[Tuple[str, str]]: + """Generates list of monthly active users and their services. + Please see "get_monthly_active_count_by_service" docstring for more details + about services. + + Arguments: + start_timestamp: If specified, only include users that were first active + at or after this point + end_timestamp: If specified, only include users that were first active + at or before this point + + Returns: + A list of tuples (appservice_id, user_id). "native" is emitted as the + appservice for users that don't come from appservices (i.e. native Matrix + users). + + """ + if start_timestamp is not None and end_timestamp is not None: + where_clause = 'WHERE "timestamp" >= ? and "timestamp" <= ?' + query_params = [start_timestamp, end_timestamp] + elif start_timestamp is not None: + where_clause = 'WHERE "timestamp" >= ?' + query_params = [start_timestamp] + elif end_timestamp is not None: + where_clause = 'WHERE "timestamp" <= ?' 
+ query_params = [end_timestamp] + else: + where_clause = "" + query_params = [] + + def _list_users(txn: LoggingTransaction) -> List[Tuple[str, str]]: + sql = f""" + SELECT COALESCE(appservice_id, 'native'), user_id + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + {where_clause}; + """ + + txn.execute(sql, query_params) + return cast(List[Tuple[str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction("list_users", _list_users) + async def get_registered_reserved_users(self) -> List[str]: """Of the reserved threepids defined in config, retrieve those that are associated with registered users diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 4c29ad79b6..e8b4a5644b 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -407,3 +407,86 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.assertEqual(result[service1], 2) self.assertEqual(result[service2], 1) self.assertEqual(result[native], 1) + + def test_get_monthly_active_users_by_service(self): + # (No users, no filtering) -> empty result + result = self.get_success(self.store.get_monthly_active_users_by_service()) + + self.assertEqual(len(result), 0) + + # (Some users, no filtering) -> non-empty result + appservice1_user1 = "@appservice1_user1:example.com" + appservice2_user1 = "@appservice2_user1:example.com" + service1 = "service1" + service2 = "service2" + self.get_success( + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + ) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user1)) + self.get_success( + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) + ) + self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) + + result = self.get_success(self.store.get_monthly_active_users_by_service()) + + self.assertEqual(len(result), 2) + self.assertIn((service1, appservice1_user1), result) + self.assertIn((service2, appservice2_user1), result) + + # (Some users, end-timestamp filtering) -> non-empty result + appservice1_user2 = "@appservice1_user2:example.com" + timestamp1 = self.reactor.seconds() + self.reactor.advance(5) + timestamp2 = self.reactor.seconds() + self.get_success( + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) + ) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user2)) + + result = self.get_success( + self.store.get_monthly_active_users_by_service( + end_timestamp=round(timestamp1 * 1000) + ) + ) + + self.assertEqual(len(result), 2) + self.assertNotIn((service1, appservice1_user2), result) + + # (Some users, start-timestamp filtering) -> non-empty result + result = self.get_success( + self.store.get_monthly_active_users_by_service( + start_timestamp=round(timestamp2 * 1000) + ) + ) + + self.assertEqual(len(result), 1) + self.assertIn((service1, appservice1_user2), result) + + # (Some users, full-timestamp filtering) -> non-empty result + native_user1 = "@native_user1:example.com" + native = "native" + timestamp3 = self.reactor.seconds() + self.reactor.advance(100) + self.get_success( + self.store.register_user( + user_id=native_user1, password_hash=None, appservice_id=native + ) + ) + self.get_success(self.store.upsert_monthly_active_user(native_user1)) + + result = self.get_success( + 
self.store.get_monthly_active_users_by_service( + start_timestamp=round(timestamp2 * 1000), + end_timestamp=round(timestamp3 * 1000), + ) + ) + + self.assertEqual(len(result), 1) + self.assertIn((service1, appservice1_user2), result) From d9f092285b28f0cdac0d985813a1cabd8ea990b6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 May 2022 07:13:58 -0400 Subject: [PATCH 28/74] Remove federation client code for groups. (#12563) --- changelog.d/12563.removal | 1 + synapse/federation/transport/client.py | 483 ------------------------- 2 files changed, 1 insertion(+), 483 deletions(-) create mode 100644 changelog.d/12563.removal diff --git a/changelog.d/12563.removal b/changelog.d/12563.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12563.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9da80176a5..9e84bd677e 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -17,7 +17,6 @@ import logging import urllib from typing import ( Any, - Awaitable, Callable, Collection, Dict, @@ -681,488 +680,6 @@ class TransportLayerClient: timeout=timeout, ) - async def get_group_profile( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get a group profile""" - path = _create_v1_path("/groups/%s/profile", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def update_group_profile( - self, destination: str, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Update a remote group profile - - Args: - destination - group_id - requester_user_id - content: The new profile of the group - """ - path = _create_v1_path("/groups/%s/profile", group_id) - - return self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def get_group_summary( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get a group summary""" - path = _create_v1_path("/groups/%s/summary", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_rooms_in_group( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get all rooms in a group""" - path = _create_v1_path("/groups/%s/rooms", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def add_room_to_group( - self, - destination: str, - group_id: str, - requester_user_id: str, - room_id: str, - content: JsonDict, - ) -> JsonDict: - """Add a room to a group""" - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def update_room_in_group( - self, - destination: str, - group_id: str, - requester_user_id: str, - room_id: str, - config_key: str, - content: JsonDict, - ) -> JsonDict: - """Update room in group""" - path = _create_v1_path( - "/groups/%s/room/%s/config/%s", 
group_id, room_id, config_key - ) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def remove_room_from_group( - self, destination: str, group_id: str, requester_user_id: str, room_id: str - ) -> JsonDict: - """Remove a room from a group""" - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_users_in_group( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get users in a group""" - path = _create_v1_path("/groups/%s/users", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_invited_users_in_group( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get users that have been invited to a group""" - path = _create_v1_path("/groups/%s/invited_users", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def accept_group_invite( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Accept a group invite""" - path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - def join_group( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> Awaitable[JsonDict]: - """Attempts to join a group""" - path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) - - return self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - async def invite_to_group( - self, - destination: str, - group_id: str, - user_id: str, - requester_user_id: str, - content: JsonDict, - ) -> JsonDict: - """Invite a user to a group""" - path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def invite_to_group_notification( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Sent by group server to inform a user's server that they have been - invited. 
- """ - - path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - async def remove_user_from_group( - self, - destination: str, - group_id: str, - requester_user_id: str, - user_id: str, - content: JsonDict, - ) -> JsonDict: - """Remove a user from a group""" - path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def remove_user_from_group_notification( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Sent by group server to inform a user's server that they have been - kicked from the group. - """ - - path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - async def renew_group_attestation( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Sent by either a group server or a user's server to periodically update - the attestations - """ - - path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - async def update_group_summary_room( - self, - destination: str, - group_id: str, - user_id: str, - room_id: str, - category_id: str, - content: JsonDict, - ) -> JsonDict: - """Update a room entry in a group summary""" - if category_id: - path = _create_v1_path( - "/groups/%s/summary/categories/%s/rooms/%s", - group_id, - category_id, - room_id, - ) - else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": user_id}, - data=content, - ignore_backoff=True, - ) - - async def delete_group_summary_room( - self, - destination: str, - group_id: str, - user_id: str, - room_id: str, - category_id: str, - ) -> JsonDict: - """Delete a room entry in a group summary""" - if category_id: - path = _create_v1_path( - "/groups/%s/summary/categories/%s/rooms/%s", - group_id, - category_id, - room_id, - ) - else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": user_id}, - ignore_backoff=True, - ) - - async def get_group_categories( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get all categories in a group""" - path = _create_v1_path("/groups/%s/categories", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_group_category( - self, destination: str, group_id: str, requester_user_id: str, category_id: str - ) -> JsonDict: - """Get category info in a group""" - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def update_group_category( - self, - destination: str, - group_id: str, - 
requester_user_id: str, - category_id: str, - content: JsonDict, - ) -> JsonDict: - """Update a category in a group""" - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def delete_group_category( - self, destination: str, group_id: str, requester_user_id: str, category_id: str - ) -> JsonDict: - """Delete a category in a group""" - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_group_roles( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get all roles in a group""" - path = _create_v1_path("/groups/%s/roles", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_group_role( - self, destination: str, group_id: str, requester_user_id: str, role_id: str - ) -> JsonDict: - """Get a roles info""" - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def update_group_role( - self, - destination: str, - group_id: str, - requester_user_id: str, - role_id: str, - content: JsonDict, - ) -> JsonDict: - """Update a role in a group""" - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def delete_group_role( - self, destination: str, group_id: str, requester_user_id: str, role_id: str - ) -> JsonDict: - """Delete a role in a group""" - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def update_group_summary_user( - self, - destination: str, - group_id: str, - requester_user_id: str, - user_id: str, - role_id: str, - content: JsonDict, - ) -> JsonDict: - """Update a users entry in a group""" - if role_id: - path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id - ) - else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def set_group_join_policy( - self, destination: str, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Sets the join policy for a group""" - path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) - - return await self.client.put_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def delete_group_summary_user( - self, - destination: str, - group_id: str, - requester_user_id: str, - user_id: str, - role_id: str, - ) -> JsonDict: - """Delete a users entry in a group""" - if role_id: - path = 
_create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id - ) - else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def bulk_get_publicised_groups( - self, destination: str, user_ids: Iterable[str] - ) -> JsonDict: - """Get the groups a list of users are publicising""" - - path = _create_v1_path("/get_groups_publicised") - - content = {"user_ids": user_ids} - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict: """ Args: From c52abc1cfdd9e5480cdb4a03d626fe61cacc6573 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 May 2022 07:14:36 -0400 Subject: [PATCH 29/74] Additional constants for EDU types. (#12884) Instead of hard-coding strings in many places. --- changelog.d/12884.misc | 1 + synapse/api/constants.py | 8 ++++- synapse/api/filtering.py | 4 +-- synapse/federation/federation_server.py | 2 +- .../sender/per_destination_queue.py | 7 ++-- .../federation/sender/transaction_manager.py | 6 +++- .../federation/transport/server/federation.py | 6 +++- synapse/handlers/appservice.py | 4 +-- synapse/handlers/device.py | 5 +-- synapse/handlers/devicemessage.py | 6 ++-- synapse/handlers/e2e_keys.py | 5 +-- synapse/handlers/events.py | 2 +- synapse/handlers/initial_sync.py | 4 +-- synapse/handlers/presence.py | 8 +++-- synapse/handlers/receipts.py | 6 ++-- synapse/handlers/typing.py | 11 ++++--- synapse/notifier.py | 4 +-- synapse/rest/client/sync.py | 4 +-- synapse/storage/databases/main/devices.py | 5 +-- synapse/storage/databases/main/receipts.py | 8 ++--- tests/api/test_filtering.py | 6 ++-- tests/events/test_presence_router.py | 2 +- tests/federation/test_federation_sender.py | 26 ++++++++------- tests/federation/transport/test_server.py | 4 ++- tests/handlers/test_appservice.py | 3 +- tests/handlers/test_receipts.py | 32 +++++++++--------- tests/handlers/test_typing.py | 33 +++++++++++++------ tests/module_api/test_api.py | 2 +- tests/rest/client/test_events.py | 3 +- tests/rest/client/test_rooms.py | 3 +- tests/rest/client/test_sendtodevice.py | 5 +-- tests/rest/client/test_shadow_banned.py | 4 +-- tests/rest/client/test_sync.py | 3 +- tests/rest/client/test_typing.py | 3 +- tests/storage/test_devices.py | 7 ++-- 35 files changed, 146 insertions(+), 96 deletions(-) create mode 100644 changelog.d/12884.misc diff --git a/changelog.d/12884.misc b/changelog.d/12884.misc new file mode 100644 index 0000000000..56eead9472 --- /dev/null +++ b/changelog.d/12884.misc @@ -0,0 +1 @@ +Use constants for EDU types. 
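The change itself is mechanical: every hard-coded EDU type string moves behind a `Final` constant. As a rough, self-contained sketch of the pattern (the constant values are taken from the `EduTypes` class in the diff below; the toy registry merely stands in for Synapse's real, asynchronous `FederationHandlerRegistry` and is illustrative only):

```python
from typing import Callable, Dict, Final


class EduTypes:
    # Values as defined in synapse/api/constants.py in the diff below.
    PRESENCE: Final = "m.presence"
    TYPING: Final = "m.typing"
    RECEIPT: Final = "m.receipt"
    DEVICE_LIST_UPDATE: Final = "m.device_list_update"


class ToyEduRegistry:
    """Synchronous stand-in for the real registry, for illustration only."""

    def __init__(self) -> None:
        self._handlers: Dict[str, Callable[[dict], None]] = {}

    def register_edu_handler(self, edu_type: str, handler: Callable[[dict], None]) -> None:
        self._handlers[edu_type] = handler

    def on_edu(self, edu_type: str, content: dict) -> None:
        handler = self._handlers.get(edu_type)
        if handler is not None:
            handler(content)


registry = ToyEduRegistry()
# Call sites now reference the shared constant instead of retyping "m.receipt";
# a typo becomes a loud AttributeError rather than a silently dropped EDU.
registry.register_edu_handler(EduTypes.RECEIPT, lambda content: print("receipt:", content))
registry.on_edu(EduTypes.RECEIPT, {"!room:example.org": {}})
```

The `Final` annotations also let mypy flag any accidental reassignment of the constants.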
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 4a0552e7e5..f03fdd6dae 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -137,7 +137,13 @@ class DeviceKeyAlgorithms: class EduTypes: - Presence: Final = "m.presence" + PRESENCE: Final = "m.presence" + TYPING: Final = "m.typing" + RECEIPT: Final = "m.receipt" + DEVICE_LIST_UPDATE: Final = "m.device_list_update" + SIGNING_KEY_UPDATE: Final = "m.signing_key_update" + UNSTABLE_SIGNING_KEY_UPDATE: Final = "org.matrix.signing_key_update" + DIRECT_TO_DEVICE: Final = "m.direct_to_device" class RejectedReason: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index b91ce06de7..b007147519 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -33,7 +33,7 @@ from typing import ( import jsonschema from jsonschema import FormatChecker -from synapse.api.constants import EventContentFields +from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.events import EventBase @@ -347,7 +347,7 @@ class Filter: user_id = event.user_id field_matchers = { "senders": lambda v: user_id == v, - "types": lambda v: "m.presence" == v, + "types": lambda v: EduTypes.PRESENCE == v, } return self._check_fields(field_matchers) else: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index b8232e5257..5b227b85fd 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1353,7 +1353,7 @@ class FederationHandlerRegistry: self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict) -> None: - if not self.config.server.use_presence and edu_type == EduTypes.Presence: + if not self.config.server.use_presence and edu_type == EduTypes.PRESENCE: return # Check if we have a handler on this instance diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8983b5a53d..333ca9a97f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tupl import attr from prometheus_client import Counter +from synapse.api.constants import EduTypes from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -542,7 +543,7 @@ class PerDestinationQueue: edu = Edu( origin=self._server_name, destination=self._destination, - edu_type="m.receipt", + edu_type=EduTypes.RECEIPT, content=self._pending_rrs, ) self._pending_rrs = {} @@ -592,7 +593,7 @@ class PerDestinationQueue: Edu( origin=self._server_name, destination=self._destination, - edu_type="m.direct_to_device", + edu_type=EduTypes.DIRECT_TO_DEVICE, content=content, ) for content in contents @@ -670,7 +671,7 @@ class _TransactionQueueManager: Edu( origin=self.queue._server_name, destination=self.queue._destination, - edu_type="m.presence", + edu_type=EduTypes.PRESENCE, content={ "push": [ format_user_presence_state( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 0c1cad86ab..75081810fd 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, List from prometheus_client import Gauge +from 
synapse.api.constants import EduTypes from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -126,7 +127,10 @@ class TransactionManager: len(edus), ) if issue_8631_logger.isEnabledFor(logging.DEBUG): - DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"} + DEVICE_UPDATE_EDUS = { + EduTypes.DEVICE_LIST_UPDATE, + EduTypes.SIGNING_KEY_UPDATE, + } device_list_updates = [ edu.content for edu in edus if edu.edu_type in DEVICE_UPDATE_EDUS ] diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 57e8fb21b0..7dfb890661 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -27,6 +27,7 @@ from typing import ( from matrix_common.versionstring import get_distribution_version_string from typing_extensions import Literal +from synapse.api.constants import EduTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX @@ -108,7 +109,10 @@ class FederationSendServlet(BaseFederationServerServlet): ) if issue_8631_logger.isEnabledFor(logging.DEBUG): - DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"] + DEVICE_UPDATE_EDUS = [ + EduTypes.DEVICE_LIST_UPDATE, + EduTypes.SIGNING_KEY_UPDATE, + ] device_list_updates = [ edu.get("content", {}) for edu in transaction_data.get("edus", []) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 1da7bcc85b..814553e098 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -19,7 +19,7 @@ from prometheus_client import Counter from twisted.internet import defer import synapse -from synapse.api.constants import EventTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state @@ -503,7 +503,7 @@ class ApplicationServicesHandler: time_now = self.clock.time_msec() events.extend( { - "type": "m.presence", + "type": EduTypes.PRESENCE, "sender": event.user_id, "content": format_user_presence_state( event, time_now, include_user_id=False diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b21e469865..438a549339 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -28,7 +28,7 @@ from typing import ( ) from synapse.api import errors -from synapse.api.constants import EventTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import ( Codes, FederationDeniedError, @@ -279,7 +279,8 @@ class DeviceHandler(DeviceWorkerHandler): federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler( - "m.device_list_update", self.device_list_updater.incoming_device_list_update + EduTypes.DEVICE_LIST_UPDATE, + self.device_list_updater.incoming_device_list_update, ) hs.get_distributor().observe("user_left_room", self.user_left_room) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 53668cce3b..444c08bc2e 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict -from synapse.api.constants import ToDeviceEventTypes +from synapse.api.constants import 
EduTypes, ToDeviceEventTypes from synapse.api.errors import SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background @@ -59,11 +59,11 @@ class DeviceMessageHandler: # to the appropriate worker. if hs.get_instance_name() in hs.config.worker.writers.to_device: hs.get_federation_registry().register_edu_handler( - "m.direct_to_device", self.on_direct_to_device_edu + EduTypes.DIRECT_TO_DEVICE, self.on_direct_to_device_edu ) else: hs.get_federation_registry().register_instances_for_edu( - "m.direct_to_device", + EduTypes.DIRECT_TO_DEVICE, hs.config.worker.writers.to_device, ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index e6c2cfb8c8..52bb5c9c55 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -25,6 +25,7 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer +from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace @@ -66,13 +67,13 @@ class E2eKeysHandler: # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( - "m.signing_key_update", + EduTypes.SIGNING_KEY_UPDATE, self._edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( - "org.matrix.signing_key_update", + EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, self._edu_updater.incoming_signing_key_update, ) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 82a5aac3dd..cb7e0ca7a8 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -113,7 +113,7 @@ class EventStreamHandler: states = await presence_handler.get_states(users) to_add.extend( { - "type": EduTypes.Presence, + "type": EduTypes.PRESENCE, "content": format_user_presence_state(state, time_now), } for state in states diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index c06932a41a..fbdbeeedfd 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -274,7 +274,7 @@ class InitialSyncHandler: "rooms": rooms_ret, "presence": [ { - "type": "m.presence", + "type": EduTypes.PRESENCE, "content": format_user_presence_state(event, now), } for event in presence @@ -439,7 +439,7 @@ class InitialSyncHandler: return [ { - "type": EduTypes.Presence, + "type": EduTypes.PRESENCE, "content": format_user_presence_state(s, time_now), } for s in states diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index dd84e6c88b..bf112b9e1e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -49,7 +49,7 @@ from prometheus_client import Counter from typing_extensions import ContextManager import synapse.metrics -from synapse.api.constants import EventTypes, Membership, PresenceState +from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.appservice import ApplicationService @@ -394,7 +394,7 @@ class WorkerPresenceHandler(BasePresenceHandler): # Route presence EDUs to the right worker hs.get_federation_registry().register_instances_for_edu( - "m.presence", + EduTypes.PRESENCE, 
hs.config.worker.writers.presence, ) @@ -649,7 +649,9 @@ class PresenceHandler(BasePresenceHandler): federation_registry = hs.get_federation_registry() - federation_registry.register_edu_handler("m.presence", self.incoming_presence) + federation_registry.register_edu_handler( + EduTypes.PRESENCE, self.incoming_presence + ) LaterGauge( "synapse_handlers_presence_user_to_current_state_size", diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e6a35f1d09..43d2882b0a 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import EduTypes, ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import ( @@ -52,11 +52,11 @@ class ReceiptsHandler: # to the appropriate worker. if hs.get_instance_name() in hs.config.worker.writers.receipts: hs.get_federation_registry().register_edu_handler( - "m.receipt", self._received_remote_receipt + EduTypes.RECEIPT, self._received_remote_receipt ) else: hs.get_federation_registry().register_instances_for_edu( - "m.receipt", + EduTypes.RECEIPT, hs.config.worker.writers.receipts, ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index bb00750bfd..0aeab86bbb 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr +from synapse.api.constants import EduTypes from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( @@ -68,7 +69,7 @@ class FollowerTypingHandler: if hs.get_instance_name() not in hs.config.worker.writers.typing: hs.get_federation_registry().register_instances_for_edu( - "m.typing", + EduTypes.TYPING, hs.config.worker.writers.typing, ) @@ -143,7 +144,7 @@ class FollowerTypingHandler: logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( destination=domain, - edu_type="m.typing", + edu_type=EduTypes.TYPING, content={ "room_id": member.room_id, "user_id": member.user_id, @@ -218,7 +219,9 @@ class TypingWriterHandler(FollowerTypingHandler): self.hs = hs - hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + hs.get_federation_registry().register_edu_handler( + EduTypes.TYPING, self._recv_edu + ) hs.get_distributor().observe("user_left_room", self.user_left_room) @@ -458,7 +461,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): def _make_event_for(self, room_id: str) -> JsonDict: typing = self.get_typing_handler()._room_typing[room_id] return { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": room_id, "content": {"user_ids": list(typing)}, } diff --git a/synapse/notifier.py b/synapse/notifier.py index ba23257f54..c2b66eec62 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -33,7 +33,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.api.constants import EventTypes, HistoryVisibility, Membership +from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, Membership from synapse.api.errors import AuthError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state @@ -632,7 +632,7 @@ class Notifier: now = 
self.clock.time_msec() new_events[:] = [ { - "type": "m.presence", + "type": EduTypes.PRESENCE, "content": format_user_presence_state(event, now), } for event in new_events diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f596b792fa..8bbf35148d 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -16,7 +16,7 @@ import logging from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from synapse.api.constants import Membership, PresenceState +from synapse.api.constants import EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState @@ -305,7 +305,7 @@ class SyncRestServlet(RestServlet): return { "events": [ { - "type": "m.presence", + "type": EduTypes.PRESENCE, "sender": event.user_id, "content": format_user_presence_state( event, time_now, include_user_id=False diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 2df4dd4ed4..dd43bae784 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -28,6 +28,7 @@ from typing import ( cast, ) +from synapse.api.constants import EduTypes from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( get_active_span_text_map, @@ -419,7 +420,7 @@ class DeviceWorkerStore(SQLBaseStore): # Add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - results.append(("m.signing_key_update", result)) + results.append((EduTypes.SIGNING_KEY_UPDATE, result)) # also send the unstable version # FIXME: remove this when enough servers have upgraded # and remove the length budgeting above. @@ -545,7 +546,7 @@ class DeviceWorkerStore(SQLBaseStore): else: result["deleted"] = True - results.append(("m.device_list_update", result)) + results.append((EduTypes.DEVICE_LIST_UPDATE, result)) return results diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index cfa4d4924d..f74aa1e3f3 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -26,7 +26,7 @@ from typing import ( cast, ) -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import EduTypes, ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -363,7 +363,7 @@ class ReceiptsWorkerStore(SQLBaseStore): row["user_id"] ] = db_to_json(row["data"]) - return [{"type": "m.receipt", "room_id": room_id, "content": content}] + return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}] @cachedList( cached_method_name="_get_linearized_receipts_for_room", @@ -411,7 +411,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # receipts by room, event and type. room_event = results.setdefault( row["room_id"], - {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, ) # The content is of the form: @@ -476,7 +476,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # receipts by room, event and type. 
room_event = results.setdefault( row["room_id"], - {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, ) # The content is of the form: diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 985d6e397d..a269c477fb 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -20,7 +20,7 @@ from unittest.mock import patch import jsonschema from frozendict import frozendict -from synapse.api.constants import EventContentFields +from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import make_event_from_dict @@ -85,13 +85,13 @@ class FilteringTestCase(unittest.HomeserverTestCase): "org.matrix.not_labels": ["#work"], }, "ephemeral": { - "types": ["m.receipt", "m.typing"], + "types": [EduTypes.RECEIPT, EduTypes.TYPING], "not_rooms": ["!726s6s6q:example.com"], "not_senders": ["@spam:example.com"], }, }, "presence": { - "types": ["m.presence"], + "types": [EduTypes.PRESENCE], "not_senders": ["@alice:example.com"], }, "event_format": "client", diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 3deb14c308..ffc3012a86 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -439,7 +439,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): for edu in edus: # Make sure we're only checking presence-type EDUs - if edu["edu_type"] != EduTypes.Presence: + if edu["edu_type"] != EduTypes.PRESENCE: continue # EDUs can contain multiple presence updates diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 6b26353d5e..b5be727fe4 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -19,7 +19,7 @@ from signedjson.types import BaseKey, SigningKey from twisted.internet import defer -from synapse.api.constants import RoomEncryptionAlgorithms +from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms from synapse.rest import admin from synapse.rest.client import login from synapse.types import JsonDict, ReadReceipt @@ -63,7 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): data["edus"], [ { - "edu_type": "m.receipt", + "edu_type": EduTypes.RECEIPT, "content": { "room_id": { "m.read": { @@ -103,7 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): data["edus"], [ { - "edu_type": "m.receipt", + "edu_type": EduTypes.RECEIPT, "content": { "room_id": { "m.read": { @@ -138,7 +138,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): data["edus"], [ { - "edu_type": "m.receipt", + "edu_type": EduTypes.RECEIPT, "content": { "room_id": { "m.read": { @@ -322,8 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # expect signing key update edu self.assertEqual(len(self.edus), 2) - self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") - self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") + self.assertEqual(self.edus.pop(0)["edu_type"], EduTypes.SIGNING_KEY_UPDATE) + self.assertEqual( + self.edus.pop(0)["edu_type"], EduTypes.UNSTABLE_SIGNING_KEY_UPDATE + ) # sign the devices d1_json = build_device_dict(u1, "D1", device1_signing_key) @@ -348,7 +350,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.assertEqual(len(self.edus), 2) stream_id = None # FIXME: 
there is a discontinuity in the stream IDs: see #7142 for edu in self.edus: - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] if stream_id is not None: self.assertEqual(c["prev_id"], [stream_id]) @@ -388,7 +390,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # expect three edus, in an unknown order self.assertEqual(len(self.edus), 3) for edu in self.edus: - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] self.assertGreaterEqual( c.items(), @@ -435,7 +437,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.assertEqual(len(self.edus), 3) stream_id = None for edu in self.edus: - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) if stream_id is not None: @@ -487,7 +489,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # there should be a single update for this user. self.assertEqual(len(self.edus), 1) edu = self.edus.pop(0) - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] # synapse uses an empty prev_id list to indicate "needs a full resync". @@ -544,7 +546,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # ... and we should get a single update for this user. self.assertEqual(len(self.edus), 1) edu = self.edus.pop(0) - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] # synapse uses an empty prev_id list to indicate "needs a full resync". @@ -560,7 +562,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): """Check that the given EDU is an update for the given device Returns the stream_id. """ - self.assertEqual(edu["edu_type"], "m.device_list_update") + self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) content = edu["content"] expected = { diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index 5f001c33b0..cfd550a04b 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from synapse.api.constants import EduTypes + from tests import unittest from tests.unittest import DEBUG, override_config @@ -50,7 +52,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/send/txn_id_1234/", content={ "edus": [ - {"edu_type": "m.device_list_update", "content": {"foo": "bar"}} + {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}} ], "pdus": [], }, diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 53e7a5d81b..0e100c404d 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage +from synapse.api.constants import EduTypes from synapse.appservice import ( ApplicationService, TransactionOneTimeKeyCounts, @@ -476,7 +477,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Check that the ephemeral event is a read receipt with the expected structure latest_read_receipt = all_ephemeral_events[-1] - self.assertEqual(latest_read_receipt["type"], "m.receipt") + self.assertEqual(latest_read_receipt["type"], EduTypes.RECEIPT) event_id = list(latest_read_receipt["content"].keys())[0] self.assertEqual( diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 78807cdcfc..a95868b5c0 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,7 +15,7 @@ from copy import deepcopy from typing import List -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import EduTypes, ReceiptTypes from synapse.types import JsonDict from tests import unittest @@ -39,7 +39,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [], @@ -64,7 +64,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [ @@ -79,7 +79,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) @@ -105,7 +105,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [ @@ -120,7 +120,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) @@ -140,7 +140,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [ @@ -155,7 +155,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) @@ -174,7 +174,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, }, { "content": { @@ -187,7 +187,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, }, ], [ @@ -202,7 +202,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": 
EduTypes.RECEIPT, } ], ) @@ -224,7 +224,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, }, ], [ @@ -237,7 +237,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, }, ], ) @@ -266,7 +266,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], [ @@ -291,7 +291,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ], ) @@ -310,7 +310,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): } }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", + "type": EduTypes.RECEIPT, } ] original_events = deepcopy(events) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5f2e26a5fc..057256cecd 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -21,6 +21,7 @@ from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource +from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer from synapse.server import HomeServer @@ -184,7 +185,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": ROOM_ID, "content": {"user_ids": [U_APPLE.to_string()]}, } @@ -209,7 +210,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "farm", path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( - "m.typing", + EduTypes.TYPING, content={ "room_id": ROOM_ID, "user_id": U_APPLE.to_string(), @@ -231,7 +232,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "PUT", "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( - "m.typing", + EduTypes.TYPING, content={ "room_id": ROOM_ID, "user_id": U_ONION.to_string(), @@ -254,7 +255,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": ROOM_ID, "content": {"user_ids": [U_ONION.to_string()]}, } @@ -270,7 +271,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "PUT", "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( - "m.typing", + EduTypes.TYPING, content={ "room_id": OTHER_ROOM_ID, "user_id": U_ONION.to_string(), @@ -324,7 +325,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "farm", path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( - "m.typing", + EduTypes.TYPING, content={ "room_id": ROOM_ID, "user_id": U_APPLE.to_string(), @@ -345,7 +346,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual( events[0], - [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], + [ + { + "type": EduTypes.TYPING, + "room_id": ROOM_ID, + "content": {"user_ids": []}, + } + ], ) def test_typing_timeout(self) -> None: @@ -379,7 +386,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": ROOM_ID, "content": {"user_ids": [U_APPLE.to_string()]}, } @@ -402,7 +409,13 @@ class 
TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual( events[0], - [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], + [ + { + "type": EduTypes.TYPING, + "room_id": ROOM_ID, + "content": {"user_ids": []}, + } + ], ) # SYN-230 - see if we can still set after timeout @@ -433,7 +446,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": ROOM_ID, "content": {"user_ids": [U_APPLE.to_string()]}, } diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 8bc84aaaca..169e29b590 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -399,7 +399,7 @@ class ModuleApiTestCase(HomeserverTestCase): for edu in edus: # Make sure we're only checking presence-type EDUs - if edu["edu_type"] != EduTypes.Presence: + if edu["edu_type"] != EduTypes.PRESENCE: continue # EDUs can contain multiple presence updates diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 1b1392fa2f..a9b7db9db2 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -19,6 +19,7 @@ from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin +from synapse.api.constants import EduTypes from synapse.rest.client import events, login, room from synapse.server import HomeServer from synapse.util import Clock @@ -103,7 +104,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): c for c in channel.json_body["chunk"] if not ( - c.get("type") == "m.presence" + c.get("type") == EduTypes.PRESENCE and c["content"].get("user_id") == self.user_id ) ] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index d0197aca94..f523d89b8f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -26,6 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( + EduTypes, EventContentFields, EventTypes, Membership, @@ -1412,7 +1413,7 @@ class RoomInitialSyncTestCase(RoomBase): e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) - self.assertEqual("m.presence", presence_by_user[self.user_id]["type"]) + self.assertEqual(EduTypes.PRESENCE, presence_by_user[self.user_id]["type"]) class RoomMessageListTestCase(RoomBase): diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index c3942889e1..6435800fa1 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from synapse.api.constants import EduTypes from synapse.rest import admin from synapse.rest.client import login, sendtodevice, sync @@ -139,7 +140,7 @@ class SendToDeviceTestCase(HomeserverTestCase): for i in range(3): self.get_success( federation_registry.on_edu( - "m.direct_to_device", + EduTypes.DIRECT_TO_DEVICE, "remote_server", { "sender": "@user:remote_server", @@ -172,7 +173,7 @@ class SendToDeviceTestCase(HomeserverTestCase): # and we can send more messages self.get_success( federation_registry.on_edu( - "m.direct_to_device", + EduTypes.DIRECT_TO_DEVICE, "remote_server", { "sender": "@user:remote_server", diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index ae5ada3be7..d9bd8c4a28 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -17,7 +17,7 @@ from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.rest.client import ( directory, login, @@ -226,7 +226,7 @@ class RoomTestCase(_ShadowBannedBase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": room_id, "content": {"user_ids": [self.other_user_id]}, } diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 74b6560cbc..e3efd1f1b0 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( + EduTypes, EventContentFields, EventTypes, ReceiptTypes, @@ -504,7 +505,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Checks if event is a read receipt def is_read_receipt(event: JsonDict) -> bool: - return event["type"] == "m.receipt" + return event["type"] == EduTypes.RECEIPT # Sync channel = self.make_request( diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index d6da510773..61b66d7685 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -17,6 +17,7 @@ from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EduTypes from synapse.rest.client import room from synapse.server import HomeServer from synapse.types import UserID @@ -67,7 +68,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): events[0], [ { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": self.room_id, "content": {"user_ids": [self.user_id]}, } diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index bbf079b25b..f37505b6cf 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -13,6 +13,7 @@ # limitations under the License. import synapse.api.errors +from synapse.api.constants import EduTypes from tests.unittest import HomeserverTestCase @@ -266,10 +267,12 @@ class DeviceStoreTestCase(HomeserverTestCase): # (This is a temporary arrangement for backwards compatibility!) self.assertEqual(len(device_updates), 2, device_updates) self.assertEqual( - device_updates[0][0], "m.signing_key_update", device_updates[0] + device_updates[0][0], EduTypes.SIGNING_KEY_UPDATE, device_updates[0] ) self.assertEqual( - device_updates[1][0], "org.matrix.signing_key_update", device_updates[1] + device_updates[1][0], + EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, + device_updates[1], ) # Check there are no more device updates left. 
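The test changes above show the consumer side of the same change: filter definitions and assertions now reference the identical constants used by the senders. A condensed, runnable sketch of that filtering logic (illustrative only, not Synapse code; the values match the patch):

```python
from typing import Final


class EduTypes:
    PRESENCE: Final = "m.presence"
    TYPING: Final = "m.typing"
    RECEIPT: Final = "m.receipt"


# An ephemeral-events filter like the one exercised in tests/api/test_filtering.py above.
ephemeral_filter = {"types": [EduTypes.RECEIPT, EduTypes.TYPING]}

edus = [
    {"edu_type": EduTypes.TYPING, "content": {"user_ids": []}},
    {"edu_type": EduTypes.PRESENCE, "content": {"push": []}},
    {"edu_type": EduTypes.RECEIPT, "content": {}},
]

kept = [edu for edu in edus if edu["edu_type"] in ephemeral_filter["types"]]
print([edu["edu_type"] for edu in kept])  # ['m.typing', 'm.receipt']
```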
From 724e11d62057a77aa9b43fdd803b6fcd1cbc183b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 May 2022 07:44:10 -0400 Subject: [PATCH 30/74] Clean-up some receipts code (#12888) * Properly marks private methods as private. * Adds missing docstrings. * Rework inline methods. --- changelog.d/12888.misc | 1 + synapse/storage/databases/main/receipts.py | 89 ++++++++++++---------- 2 files changed, 48 insertions(+), 42 deletions(-) create mode 100644 changelog.d/12888.misc diff --git a/changelog.d/12888.misc b/changelog.d/12888.misc new file mode 100644 index 0000000000..8ed2ea65b5 --- /dev/null +++ b/changelog.d/12888.misc @@ -0,0 +1 @@ +Refactor receipt linearization code. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index f74aa1e3f3..21e954ccc1 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) - def insert_linearized_receipt_txn( + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, room_id: str, @@ -686,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore): return rx_ts + def _graph_to_linear( + self, txn: LoggingTransaction, room_id: str, event_ids: List[str] + ) -> str: + """ + Generate a linearized event from a list of events (i.e. a list of forward + extremities in the room). + + This should allow for calculation of the correct read receipt even if + servers have different event ordering. + + Args: + txn: The transaction + room_id: The room ID the events are in. + event_ids: The list of event IDs to linearize. + + Returns: + The linearized event ID. + """ + # TODO: Make this better. + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) + + txn.execute(sql, [room_id] + list(args)) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + async def insert_receipt( self, room_id: str, @@ -712,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore): linearized_event_id = event_ids[0] else: # we need to points in graph -> linearized form. - # TODO: Make this better. - def graph_to_linear(txn: LoggingTransaction) -> str: - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", event_ids - ) - - sql = """ - SELECT event_id WHERE room_id = ? 
AND stream_ordering IN ( - SELECT max(stream_ordering) WHERE %s - ) - """ % ( - clause, - ) - - txn.execute(sql, [room_id] + list(args)) - rows = txn.fetchall() - if rows: - return rows[0][0] - else: - raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = await self.db_pool.runInteraction( - "insert_receipt_conv", graph_to_linear + "insert_receipt_conv", self._graph_to_linear, room_id, event_ids ) async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", - self.insert_linearized_receipt_txn, + self._insert_linearized_receipt_txn, room_id, receipt_type, user_id, @@ -761,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore): now - event_ts, ) - await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) - - max_persisted_id = self._receipts_id_gen.get_current_token() - - return stream_id, max_persisted_id - - async def insert_graph_receipt( - self, - room_id: str, - receipt_type: str, - user_id: str, - event_ids: List[str], - data: JsonDict, - ) -> None: - assert self._can_write_to_receipts - await self.db_pool.runInteraction( "insert_graph_receipt", - self.insert_graph_receipt_txn, + self._insert_graph_receipt_txn, room_id, receipt_type, user_id, @@ -787,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore): data, ) - def insert_graph_receipt_txn( + max_persisted_id = self._receipts_id_gen.get_current_token() + + return stream_id, max_persisted_id + + def _insert_graph_receipt_txn( self, txn: LoggingTransaction, room_id: str, From 888eb736a15035c94676eb60da6b6fabb642e252 Mon Sep 17 00:00:00 2001 From: David Teller Date: Fri, 27 May 2022 15:13:29 +0200 Subject: [PATCH 31/74] Add code M_USER_ACCOUNT_SUSPENDED, as per MSC3823. (#12845) Signed-off-by: David Teller Co-authored-by: Brendan Abolivier --- changelog.d/12845.feature | 1 + synapse/api/errors.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changelog.d/12845.feature diff --git a/changelog.d/12845.feature b/changelog.d/12845.feature new file mode 100644 index 0000000000..628fb16d08 --- /dev/null +++ b/changelog.d/12845.feature @@ -0,0 +1 @@ +Support the new error code "M_ORG_MATRIX_MSC3823_USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 6650e826d5..05e96843cf 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -79,6 +79,13 @@ class Codes(str, Enum): WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + + # The account has been suspended on the server. + # In contrast to `USER_DEACTIVATED`, this is a reversible measure + # that can possibly be appealed and reverted. + # Part of MSC3823. + USER_ACCOUNT_SUSPENDED = "M_ORG_MATRIX_MSC3823_USER_ACCOUNT_SUSPENDED" + + BAD_ALIAS = "M_BAD_ALIAS" # For restricted join rules.
UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN" From 28989cb301fecf5a669a634c09bc2b73f97fec5d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 27 May 2022 17:47:32 +0200 Subject: [PATCH 32/74] Add a background job to automatically delete stale devices (#12855) Co-authored-by: Patrick Cloke --- changelog.d/12855.feature | 1 + .../configuration/config_documentation.md | 12 ++++++ synapse/config/server.py | 11 +++++ synapse/handlers/device.py | 30 ++++++++++++- synapse/storage/databases/main/devices.py | 39 +++++++++++++++++ .../{test_device_lists.py => test_devices.py} | 43 +++++++++++++++++++ 6 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12855.feature rename tests/rest/client/{test_device_lists.py => test_devices.py} (76%) diff --git a/changelog.d/12855.feature b/changelog.d/12855.feature new file mode 100644 index 0000000000..915f008ec6 --- /dev/null +++ b/changelog.d/12855.feature @@ -0,0 +1 @@ +Add a configurable background job to delete stale devices. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index b71b09ba96..88b9e5744d 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -575,6 +575,18 @@ Example configuration: dummy_events_threshold: 5 ``` --- +Config option `delete_stale_devices_after` + +An optional duration. If set, Synapse will run a daily background task to log out and +delete any device that hasn't been accessed for more than the specified amount of time. + +Defaults to no duration, which means devices are never pruned. + +Example configuration: +```yaml +delete_stale_devices_after: 1y +``` + ## Homeserver blocking ## Useful options for Synapse admins. diff --git a/synapse/config/server.py b/synapse/config/server.py index f73d5e1f66..657322cb1f 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -679,6 +679,17 @@ class ServerConfig(Config): config.get("exclude_rooms_from_sync") or [] ) + delete_stale_devices_after: Optional[str] = ( + config.get("delete_stale_devices_after") or None + ) + + if delete_stale_devices_after is not None: + self.delete_stale_devices_after: Optional[int] = self.parse_duration( + delete_stale_devices_after + ) + else: + self.delete_stale_devices_after = None + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 438a549339..2a56473dc6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -61,6 +61,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) MAX_DEVICE_DISPLAY_NAME_LEN = 100 +DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: @@ -295,6 +296,19 @@ class DeviceHandler(DeviceWorkerHandler): # On start up check if there are any updates pending. hs.get_reactor().callWhenRunning(self._handle_new_device_update_async) + self._delete_stale_devices_after = hs.config.server.delete_stale_devices_after + + # Ideally we would run this on a worker and condition this on the + # "run_background_tasks_on" setting, but this would mean making the notification + # of device list changes over federation work on workers, which is nontrivial. 
+ if self._delete_stale_devices_after is not None: + self.clock.looping_call( + run_as_background_process, + DELETE_STALE_DEVICES_INTERVAL_MS, + "delete_stale_devices", + self._delete_stale_devices, + ) + def _check_device_name_length(self, name: Optional[str]) -> None: """ Checks whether a device name is longer than the maximum allowed length. @@ -370,6 +384,19 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") + async def _delete_stale_devices(self) -> None: + """Background task that deletes devices which haven't been accessed for more than + a configured time period. + """ + # We should only be running this job if the config option is defined. + assert self._delete_stale_devices_after is not None + now_ms = self.clock.time_msec() + since_ms = now_ms - self._delete_stale_devices_after + devices = await self.store.get_local_devices_not_accessed_since(since_ms) + + for user_id, user_devices in devices.items(): + await self.delete_devices(user_id, user_devices) + @trace async def delete_device(self, user_id: str, device_id: str) -> None: """Delete the given device @@ -692,7 +719,8 @@ class DeviceHandler(DeviceWorkerHandler): ) # TODO: when called, this isn't in a logging context. # This leads to log spam, sentry event spam, and massive - # memory usage. See #12552. + # memory usage. + # See https://github.com/matrix-org/synapse/issues/12552. # log_kv( # {"message": "sent device update to host", "host": host} # ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index dd43bae784..d900064c07 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1154,6 +1154,45 @@ class DeviceWorkerStore(SQLBaseStore): _prune_txn, ) + async def get_local_devices_not_accessed_since( + self, since_ms: int ) -> Dict[str, List[str]]: + """Retrieves local devices that haven't been accessed since a given date. + + Args: + since_ms: the timestamp to select on; every device with a last access date + before that time is returned. + + Returns: + A dictionary with an entry for each user with at least one device matching + the request, whose value is a list of the device ID(s) for the corresponding + device(s). + """ + + def get_devices_not_accessed_since_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, str]]: + sql = """ + SELECT user_id, device_id + FROM devices WHERE last_seen < ? AND hidden = FALSE + """ + txn.execute(sql, (since_ms,)) + return self.db_pool.cursor_to_dict(txn) + + rows = await self.db_pool.runInteraction( + "get_devices_not_accessed_since", + get_devices_not_accessed_since_txn, + ) + + devices: Dict[str, List[str]] = {} + for row in rows: + # Remote devices are never stale from our point of view. + if self.hs.is_mine_id(row["user_id"]): + user_devices = devices.setdefault(row["user_id"], []) + user_devices.append(row["device_id"]) + + return devices + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__( diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_devices.py similarity index 76% rename from tests/rest/client/test_device_lists.py rename to tests/rest/client/test_devices.py index a8af4e2435..aa98222434 100644 --- a/tests/rest/client/test_device_lists.py +++ b/tests/rest/client/test_devices.py @@ -13,8 +13,13 @@ # limitations under the License.
from http import HTTPStatus +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.errors import NotFoundError from synapse.rest import admin, devices, room, sync from synapse.rest.client import account, login, register +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -157,3 +162,41 @@ class DeviceListsTestCase(unittest.HomeserverTestCase): self.assertNotIn( alice_user_id, changed_device_lists, bob_sync_channel.json_body ) + + +class DevicesTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_device_handler() + + @unittest.override_config({"delete_stale_devices_after": 72000000}) + def test_delete_stale_devices(self) -> None: + """Tests that stale devices are automatically removed after a set time of + inactivity. + The configuration is set to delete devices that haven't been used in the past 20h. + """ + # Register a user and create 2 devices for them. + user_id = self.register_user("user", "password") + tok1 = self.login("user", "password", device_id="abc") + tok2 = self.login("user", "password", device_id="def") + + # Sync them so they have a last_seen value. + self.make_request("GET", "/sync", access_token=tok1) + self.make_request("GET", "/sync", access_token=tok2) + + # Advance half a day and sync again with one of the devices, so that the next + # time the background job runs we don't delete this device (since it will look + # for devices that haven't been used for more than 20 hours). + self.reactor.advance(43200) + self.make_request("GET", "/sync", access_token=tok1) + + # Advance another half a day, and check that the device that has synced still + # exists but the one that hasn't has been removed. + self.reactor.advance(43200) + self.get_success(self.handler.get_device(user_id, "abc")) + self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError) From bda460039941d5f7ae5c66544d5b428c0b33bd22 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 30 May 2022 02:41:13 -0600 Subject: [PATCH 33/74] LockStore: fix acquiring a lock via `LockStore.try_acquire_lock` (#12832) Signed-off-by: Sumner Evans --- changelog.d/12832.bugfix | 1 + synapse/storage/databases/main/lock.py | 19 +++++++- tests/storage/databases/main/test_lock.py | 54 +++++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12832.bugfix diff --git a/changelog.d/12832.bugfix b/changelog.d/12832.bugfix new file mode 100644 index 0000000000..497d5184ea --- /dev/null +++ b/changelog.d/12832.bugfix @@ -0,0 +1 @@ +Fixed a bug which allowed multiple async operations to access database locks concurrently. Contributed by @sumnerevans @ Beeper. diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index bedacaf0d7..2d7633fbd5 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -13,7 +13,7 @@ # limitations under the License.
import logging from types import TracebackType -from typing import TYPE_CHECKING, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Set, Tuple, Type from weakref import WeakValueDictionary from twisted.internet.interfaces import IReactorCore @@ -84,6 +84,8 @@ class LockStore(SQLBaseStore): self._on_shutdown, ) + self._acquiring_locks: Set[Tuple[str, str]] = set() + @wrap_as_background_process("LockStore._on_shutdown") async def _on_shutdown(self) -> None: """Called when the server is shutting down""" @@ -103,6 +105,21 @@ class LockStore(SQLBaseStore): context manager if the lock is successfully acquired, which *must* be used (otherwise the lock will leak). """ + if (lock_name, lock_key) in self._acquiring_locks: + return None + try: + self._acquiring_locks.add((lock_name, lock_key)) + return await self._try_acquire_lock(lock_name, lock_key) + finally: + self._acquiring_locks.discard((lock_name, lock_key)) + + async def _try_acquire_lock( + self, lock_name: str, lock_key: str + ) -> Optional["Lock"]: + """Try to acquire a lock for the given name/key. Will return an async + context manager if the lock is successfully acquired, which *must* be + used (otherwise the lock will leak). + """ # Check if this process has taken out a lock and if it's still valid. lock = self._live_tokens.get((lock_name, lock_key)) diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 74c6224eb6..3cc2a58d8d 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer, reactor +from twisted.internet.base import ReactorBase +from twisted.internet.defer import Deferred + from synapse.server import HomeServer from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS @@ -22,6 +26,56 @@ class LockTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs: HomeServer): self.store = hs.get_datastores().main + def test_acquire_contention(self): + # Track the number of tasks holding the lock. + # Should be at most 1. + in_lock = 0 + max_in_lock = 0 + + release_lock: "Deferred[None]" = Deferred() + + async def task(): + nonlocal in_lock + nonlocal max_in_lock + + lock = await self.store.try_acquire_lock("name", "key") + if not lock: + return + + async with lock: + in_lock += 1 + max_in_lock = max(max_in_lock, in_lock) + + # Block to allow other tasks to attempt to take the lock. + await release_lock + + in_lock -= 1 + + # Start 3 tasks. + task1 = defer.ensureDeferred(task()) + task2 = defer.ensureDeferred(task()) + task3 = defer.ensureDeferred(task()) + + # Give the reactor a kick so that the database transaction returns. + self.pump() + + release_lock.callback(None) + + # Run the tasks to completion. + # To work around `Linearizer`s using a different reactor to sleep when + # contended (#12841), we call `runUntilCurrent` on + # `twisted.internet.reactor`, which is a different reactor to that used + # by the homeserver. + assert isinstance(reactor, ReactorBase) + self.get_success(task1) + reactor.runUntilCurrent() + self.get_success(task2) + reactor.runUntilCurrent() + self.get_success(task3) + + # At most one task should have held the lock at a time. + self.assertEqual(max_in_lock, 1) + def test_simple_lock(self): """Test that we can take out a lock and that while we hold it nobody else can take it out. 
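The heart of the fix is the `_acquiring_locks` guard set: between the moment a task starts acquiring a given `(lock_name, lock_key)` pair and the moment it finishes, any concurrent attempt for the same pair returns `None` immediately instead of racing on the database row. A minimal sketch of the same pattern using asyncio (illustrative only; the real store is Twisted-based and the acquisition step is a SQL round-trip, modelled here by a single yield point):

```python
import asyncio
from typing import Optional, Set, Tuple


class GuardedLockStore:
    """Toy in-memory model of the guard-set pattern added in this patch."""

    def __init__(self) -> None:
        self._acquiring_locks: Set[Tuple[str, str]] = set()
        self._held: Set[Tuple[str, str]] = set()

    async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional[str]:
        if (lock_name, lock_key) in self._acquiring_locks:
            # Another task is mid-acquisition for this pair; bail out rather
            # than race it.
            return None
        self._acquiring_locks.add((lock_name, lock_key))
        try:
            return await self._try_acquire_lock(lock_name, lock_key)
        finally:
            self._acquiring_locks.discard((lock_name, lock_key))

    async def _try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional[str]:
        await asyncio.sleep(0)  # stand-in for the database round-trip
        if (lock_name, lock_key) in self._held:
            return None
        self._held.add((lock_name, lock_key))
        return f"token-{lock_name}-{lock_key}"


async def main() -> None:
    store = GuardedLockStore()
    results = await asyncio.gather(
        *(store.try_acquire_lock("name", "key") for _ in range(3))
    )
    # Exactly one winner; the losers get None instead of a duplicate lock.
    print(results)


asyncio.run(main())
```

This mirrors the structure of the patched `try_acquire_lock`/`_try_acquire_lock` split above, which the `test_acquire_contention` test then exercises with three competing tasks.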
From 6be4953b998e4e4b730192b40642d2ec7bb0d7ad Mon Sep 17 00:00:00 2001
From: Jonathan de Jong
Date: Mon, 30 May 2022 11:05:31 +0200
Subject: [PATCH 34/74] Mutual rooms: Remove dependency on user directory (#12836)

---
 changelog.d/12836.misc                        |  1 +
 synapse/rest/client/mutual_rooms.py           | 15 +------
 synapse/storage/databases/main/roommember.py  | 24 +++++++++++
 .../storage/databases/main/user_directory.py  | 43 -------------------
 tests/rest/client/test_mutual_rooms.py        |  2 -
 5 files changed, 27 insertions(+), 58 deletions(-)
 create mode 100644 changelog.d/12836.misc

diff --git a/changelog.d/12836.misc b/changelog.d/12836.misc
new file mode 100644
index 0000000000..85909c6a2d
--- /dev/null
+++ b/changelog.d/12836.misc
@@ -0,0 +1 @@
+Remove Mutual Rooms ([MSC2666](https://github.com/matrix-org/matrix-spec-proposals/pull/2666)) endpoint dependency on the User Directory.
\ No newline at end of file
diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py
index 27bfaf0b29..38ef4e459f 100644
--- a/synapse/rest/client/mutual_rooms.py
+++ b/synapse/rest/client/mutual_rooms.py
@@ -42,21 +42,10 @@ class UserMutualRoomsServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.user_directory_search_enabled = (
-            hs.config.userdirectory.user_directory_search_enabled
-        )
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
    ) -> Tuple[int, JsonDict]:
-
-        if not self.user_directory_search_enabled:
-            raise SynapseError(
-                code=400,
-                msg="User directory searching is disabled. Cannot determine shared rooms.",
-                errcode=Codes.UNKNOWN,
-            )
-
         UserID.from_string(user_id)
 
         requester = await self.auth.get_user_by_req(request)
@@ -67,8 +56,8 @@ class UserMutualRoomsServlet(RestServlet):
             errcode=Codes.FORBIDDEN,
         )
 
-        rooms = await self.store.get_mutual_rooms_for_users(
-            requester.user.to_string(), user_id
+        rooms = await self.store.get_mutual_rooms_between_users(
+            frozenset((requester.user.to_string(), user_id))
         )
 
         return 200, {"joined": list(rooms)}
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index cc528fcf2d..e222b7bd1f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -670,6 +670,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return user_who_share_room
 
+    @cached(cache_context=True, iterable=True)
+    async def get_mutual_rooms_between_users(
+        self, user_ids: FrozenSet[str], cache_context: _CacheContext
+    ) -> FrozenSet[str]:
+        """
+        Returns the set of rooms that all users in `user_ids` share.
+
+        Args:
+            user_ids: A frozen set of the users to find mutually joined
+                rooms for.
+ cache_context + """ + shared_room_ids: Optional[FrozenSet[str]] = None + for user_id in user_ids: + room_ids = await self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + if shared_room_ids is not None: + shared_room_ids &= room_ids + else: + shared_room_ids = room_ids + + return shared_room_ids or frozenset() + async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 028db69af3..2282242e9d 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -729,49 +729,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_mutual_rooms_for_users( - self, user_id: str, other_user_id: str - ) -> Set[str]: - """ - Returns the rooms that a local user shares with another local or remote user. - - Args: - user_id: The MXID of a local user - other_user_id: The MXID of the other user - - Returns: - A set of room ID's that the users share. - """ - - def _get_mutual_rooms_for_users_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, str]]: - txn.execute( - """ - SELECT p1.room_id - FROM users_in_public_rooms as p1 - INNER JOIN users_in_public_rooms as p2 - ON p1.room_id = p2.room_id - AND p1.user_id = ? - AND p2.user_id = ? - UNION - SELECT room_id - FROM users_who_share_private_rooms - WHERE - user_id = ? - AND other_user_id = ? - """, - (user_id, other_user_id, user_id, other_user_id), - ) - rows = self.db_pool.cursor_to_dict(txn) - return rows - - rows = await self.db_pool.runInteraction( - "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn - ) - - return {row["room_id"] for row in rows} - async def get_user_directory_stream_pos(self) -> Optional[int]: """ Get the stream ID of the user directory stream. diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py index 7b7d283bb6..a4327f7ace 100644 --- a/tests/rest/client/test_mutual_rooms.py +++ b/tests/rest/client/test_mutual_rooms.py @@ -36,12 +36,10 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["update_user_directory"] = True return self.setup_test_homeserver(config=config) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self.handler = hs.get_user_directory_handler() def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel: return self.make_request( From 796a0312e18c0fd93395ced1d910a251f477a820 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 May 2022 10:47:09 +0100 Subject: [PATCH 35/74] Bump jsonschema stubs (#12912) --- changelog.d/12912.misc | 1 + poetry.lock | 6 +++--- synapse/events/validator.py | 9 +++++++-- 3 files changed, 11 insertions(+), 5 deletions(-) create mode 100644 changelog.d/12912.misc diff --git a/changelog.d/12912.misc b/changelog.d/12912.misc new file mode 100644 index 0000000000..6396fd9d36 --- /dev/null +++ b/changelog.d/12912.misc @@ -0,0 +1 @@ +Bump types-jsonschema from 4.4.1 to 4.4.6. 
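Background for the `synapse/events/validator.py` change in this patch: the
newer stubs type `jsonschema.ValidationError.path` as containing both strings
and integers, because jsonschema reports object property names as `str` but
array indices as `int`. A standalone illustration (plain jsonschema, not
Synapse code; the schema here is made up for the example):

    import jsonschema

    schema = {
        "type": "object",
        "properties": {
            "users": {"type": "array", "items": {"type": "integer"}},
        },
    }

    try:
        jsonschema.validate({"users": [0, "not-an-int"]}, schema)
    except jsonschema.ValidationError as e:
        # The offending element is indexed by a property name (str)
        # followed by an array position (int): ['users', 1]
        print(list(e.path))

Since `POWER_LEVELS_SCHEMA` contains no arrays, every path entry it can
produce is a property name, which is why the `cast(str, ...)` in the
validator change below is safe.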
diff --git a/poetry.lock b/poetry.lock index f64d70941e..6b4686545b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1355,7 +1355,7 @@ python-versions = "*" [[package]] name = "types-jsonschema" -version = "4.4.1" +version = "4.4.6" description = "Typing stubs for jsonschema" category = "dev" optional = false @@ -2618,8 +2618,8 @@ types-ipaddress = [ {file = "types_ipaddress-1.0.8-py3-none-any.whl", hash = "sha256:4933b74da157ba877b1a705d64f6fa7742745e9ffd65e51011f370c11ebedb55"}, ] types-jsonschema = [ - {file = "types-jsonschema-4.4.1.tar.gz", hash = "sha256:bd68b75217ebbb33b0242db10047581dad3b061a963a46ee80d4a9044080663e"}, - {file = "types_jsonschema-4.4.1-py3-none-any.whl", hash = "sha256:ab3ecfdc912d6091cc82f4b7556cfbf1a7cbabc26da0ceaa1cbbc232d1d09971"}, + {file = "types-jsonschema-4.4.6.tar.gz", hash = "sha256:7f2a804618756768c7c0616f8c794b61fcfe3077c7ee1ad47dcf01c5e5f692bb"}, + {file = "types_jsonschema-4.4.6-py3-none-any.whl", hash = "sha256:1db9031ca49a8444d01bd2ce8cf2f89318382b04610953b108321e6f8fb03390"}, ] types-opentracing = [ {file = "types-opentracing-2.4.7.tar.gz", hash = "sha256:be60e9618355aa892571ace002e6b353702538b1c0dc4fbc1c921219d6658830"}, diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 360d24274a..29fa9b3880 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections.abc -from typing import Iterable, Type, Union +from typing import Iterable, Type, Union, cast import jsonschema @@ -103,7 +103,12 @@ class EventValidator: except jsonschema.ValidationError as e: if e.path: # example: "users_default": '0' is not of type 'integer' - message = '"' + e.path[-1] + '": ' + e.message # noqa: B306 + # cast safety: path entries can be integers, if we fail to validate + # items in an array. However the POWER_LEVELS_SCHEMA doesn't expect + # to see any arrays. + message = ( + '"' + cast(str, e.path[-1]) + '": ' + e.message # noqa: B306 + ) # jsonschema.ValidationError.message is a valid attribute else: # example: '0' is not of type 'integer' From 72df42078b7926783b27b490f0ebed0411ad7a64 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 May 2022 10:47:25 +0100 Subject: [PATCH 36/74] Remove contrib/scripts/kick_users.py (#12908) --- changelog.d/12908.removal | 1 + contrib/scripts/kick_users.py | 88 ----------------------------------- 2 files changed, 1 insertion(+), 88 deletions(-) create mode 100644 changelog.d/12908.removal delete mode 100755 contrib/scripts/kick_users.py diff --git a/changelog.d/12908.removal b/changelog.d/12908.removal new file mode 100644 index 0000000000..a1d05d69e8 --- /dev/null +++ b/changelog.d/12908.removal @@ -0,0 +1 @@ +Remove contributed `kick_users.py` script. This is broken under Python 3, and is not added to the environment when `pip install`ing Synapse. diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py deleted file mode 100755 index f8e0c732fb..0000000000 --- a/contrib/scripts/kick_users.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python - -import json -import sys -import urllib -from argparse import ArgumentParser - -import requests - - -def _mkurl(template, kws): - for key in kws: - template = template.replace(key, kws[key]) - return template - - -def main(hs, room_id, access_token, user_id_prefix, why): - if not why: - why = "Automated kick." 
- print( - "Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix) - ) - room_state_url = _mkurl( - "$HS/_matrix/client/api/v1/rooms/$ROOM/state?access_token=$TOKEN", - {"$HS": hs, "$ROOM": room_id, "$TOKEN": access_token}, - ) - print("Getting room state => %s" % room_state_url) - res = requests.get(room_state_url) - print("HTTP %s" % res.status_code) - state_events = res.json() - if "error" in state_events: - print("FATAL") - print(state_events) - return - - kick_list = [] - room_name = room_id - for event in state_events: - if not event["type"] == "m.room.member": - if event["type"] == "m.room.name": - room_name = event["content"].get("name") - continue - if not event["content"].get("membership") == "join": - continue - if event["state_key"].startswith(user_id_prefix): - kick_list.append(event["state_key"]) - - if len(kick_list) == 0: - print("No user IDs match the prefix '%s'" % user_id_prefix) - return - - print("The following user IDs will be kicked from %s" % room_name) - for uid in kick_list: - print(uid) - doit = input("Continue? [Y]es\n") - if len(doit) > 0 and doit.lower() == "y": - print("Kicking members...") - # encode them all - kick_list = [urllib.quote(uid) for uid in kick_list] - for uid in kick_list: - kick_url = _mkurl( - "$HS/_matrix/client/api/v1/rooms/$ROOM/state/m.room.member/$UID?access_token=$TOKEN", - {"$HS": hs, "$UID": uid, "$ROOM": room_id, "$TOKEN": access_token}, - ) - kick_body = {"membership": "leave", "reason": why} - print("Kicking %s" % uid) - res = requests.put(kick_url, data=json.dumps(kick_body)) - if res.status_code != 200: - print("ERROR: HTTP %s" % res.status_code) - if res.json().get("error"): - print("ERROR: JSON %s" % res.json()) - - -if __name__ == "__main__": - parser = ArgumentParser("Kick members in a room matching a certain user ID prefix.") - parser.add_argument("-u", "--user-id", help="The user ID prefix e.g. '@irc_'") - parser.add_argument("-t", "--token", help="Your access_token") - parser.add_argument("-r", "--room", help="The room ID to kick members in") - parser.add_argument( - "-s", "--homeserver", help="The base HS url e.g. http://matrix.org" - ) - parser.add_argument("-w", "--why", help="Reason for the kick. 
Optional.") - args = parser.parse_args() - if not args.room or not args.token or not args.user_id or not args.homeserver: - parser.print_help() - sys.exit(1) - else: - main(args.homeserver, args.room, args.token, args.user_id, args.why) From 563ef172ae92a6085f115263375f6238cb35f698 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 May 2022 10:47:40 +0100 Subject: [PATCH 37/74] Remove contrib/jitsimeetbridge (#12909) --- changelog.d/12909.removal | 1 + contrib/jitsimeetbridge/jitsimeetbridge.py | 298 -- .../syweb-jitsi-conference.patch | 188 - .../unjingle/strophe.jingle.sdp.js | 712 ---- .../unjingle/strophe.jingle.sdp.util.js | 408 --- .../unjingle/strophe/XMLHttpRequest.js | 254 -- .../unjingle/strophe/base64.js | 83 - .../jitsimeetbridge/unjingle/strophe/md5.js | 279 -- .../unjingle/strophe/strophe.js | 3256 ----------------- contrib/jitsimeetbridge/unjingle/unjingle.js | 48 - debian/changelog | 7 + debian/copyright | 23 - 12 files changed, 8 insertions(+), 5549 deletions(-) create mode 100644 changelog.d/12909.removal delete mode 100644 contrib/jitsimeetbridge/jitsimeetbridge.py delete mode 100644 contrib/jitsimeetbridge/syweb-jitsi-conference.patch delete mode 100644 contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.js delete mode 100644 contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.util.js delete mode 100644 contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js delete mode 100644 contrib/jitsimeetbridge/unjingle/strophe/base64.js delete mode 100644 contrib/jitsimeetbridge/unjingle/strophe/md5.js delete mode 100644 contrib/jitsimeetbridge/unjingle/strophe/strophe.js delete mode 100644 contrib/jitsimeetbridge/unjingle/unjingle.js diff --git a/changelog.d/12909.removal b/changelog.d/12909.removal new file mode 100644 index 0000000000..0baff46ea9 --- /dev/null +++ b/changelog.d/12909.removal @@ -0,0 +1 @@ +Remove `contrib/jitsimeetbridge`. This was an unused experiment that hasn't been meaningfully changed since 2014. diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py deleted file mode 100644 index b3de468687..0000000000 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ /dev/null @@ -1,298 +0,0 @@ -#!/usr/bin/env python - -""" -This is an attempt at bridging matrix clients into a Jitis meet room via Matrix -video call. It uses hard-coded xml strings overg XMPP BOSH. It can display one -of the streams from the Jitsi bridge until the second lot of SDP comes down and -we set the remote SDP at which point the stream ends. Our video never gets to -the bridge. 
- -Requires: -npm install jquery jsdom -""" -import json -import subprocess -import time - -import gevent -import grequests -from BeautifulSoup import BeautifulSoup - -ACCESS_TOKEN = "" - -MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/" -MYUSERNAME = "@davetest:matrix.org" - -HTTPBIND = "https://meet.jit.si/http-bind" -# HTTPBIND = 'https://jitsi.vuc.me/http-bind' -# ROOMNAME = "matrix" -ROOMNAME = "pibble" - -HOST = "guest.jit.si" -# HOST="jitsi.vuc.me" - -TURNSERVER = "turn.guest.jit.si" -# TURNSERVER="turn.jitsi.vuc.me" - -ROOMDOMAIN = "meet.jit.si" -# ROOMDOMAIN="conference.jitsi.vuc.me" - - -class TrivialMatrixClient: - def __init__(self, access_token): - self.token = None - self.access_token = access_token - - def getEvent(self): - while True: - url = ( - MATRIXBASE - + "events?access_token=" - + self.access_token - + "&timeout=60000" - ) - if self.token: - url += "&from=" + self.token - req = grequests.get(url) - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print("incoming from matrix", obj) - if "end" not in obj: - continue - self.token = obj["end"] - if len(obj["chunk"]): - return obj["chunk"][0] - - def joinRoom(self, roomId): - url = MATRIXBASE + "rooms/" + roomId + "/join?access_token=" + self.access_token - print(url) - headers = {"Content-Type": "application/json"} - req = grequests.post(url, headers=headers, data="{}") - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print("response: ", obj) - - def sendEvent(self, roomId, evType, event): - url = ( - MATRIXBASE - + "rooms/" - + roomId - + "/send/" - + evType - + "?access_token=" - + self.access_token - ) - print(url) - print(json.dumps(event)) - headers = {"Content-Type": "application/json"} - req = grequests.post(url, headers=headers, data=json.dumps(event)) - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print("response: ", obj) - - -xmppClients = {} - - -def matrixLoop(): - while True: - ev = matrixCli.getEvent() - print(ev) - if ev["type"] == "m.room.member": - print("membership event") - if ev["membership"] == "invite" and ev["state_key"] == MYUSERNAME: - roomId = ev["room_id"] - print("joining room %s" % (roomId)) - matrixCli.joinRoom(roomId) - elif ev["type"] == "m.room.message": - if ev["room_id"] in xmppClients: - print("already have a bridge for that user, ignoring") - continue - print("got message, connecting") - xmppClients[ev["room_id"]] = TrivialXmppClient(ev["room_id"], ev["user_id"]) - gevent.spawn(xmppClients[ev["room_id"]].xmppLoop) - elif ev["type"] == "m.call.invite": - print("Incoming call") - # sdp = ev['content']['offer']['sdp'] - # print "sdp: %s" % (sdp) - # xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - # gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev["type"] == "m.call.answer": - print("Call answered") - sdp = ev["content"]["answer"]["sdp"] - if ev["room_id"] not in xmppClients: - print("We didn't have a call for that room") - continue - # should probably check call ID too - xmppCli = xmppClients[ev["room_id"]] - xmppCli.sendAnswer(sdp) - elif ev["type"] == "m.call.hangup": - if ev["room_id"] in xmppClients: - xmppClients[ev["room_id"]].stop() - del xmppClients[ev["room_id"]] - - -class TrivialXmppClient: - def __init__(self, matrixRoom, userId): - self.rid = 0 - self.matrixRoom = matrixRoom - self.userId = userId - self.running = True - - def stop(self): - self.running = False - - def nextRid(self): - self.rid += 1 - return "%d" % (self.rid) - - def sendIq(self, xml): - fullXml = 
( - "%s" - % (self.nextRid(), self.sid, xml) - ) - # print "\t>>>%s" % (fullXml) - return self.xmppPoke(fullXml) - - def xmppPoke(self, xml): - headers = {"Content-Type": "application/xml"} - req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml) - resps = grequests.map([req]) - obj = BeautifulSoup(resps[0].content) - return obj - - def sendAnswer(self, answer): - print("sdp from matrix client", answer) - p = subprocess.Popen( - ["node", "unjingle/unjingle.js", "--sdp"], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - jingle, out_err = p.communicate(answer) - jingle = jingle % { - "tojid": self.callfrom, - "action": "session-accept", - "initiator": self.callfrom, - "responder": self.jid, - "sid": self.callsid, - } - print("answer jingle from sdp", jingle) - res = self.sendIq(jingle) - print("reply from answer: ", res) - - self.ssrcs = {} - jingleSoup = BeautifulSoup(jingle) - for cont in jingleSoup.iq.jingle.findAll("content"): - if cont.description: - self.ssrcs[cont["name"]] = cont.description["ssrc"] - print("my ssrcs:", self.ssrcs) - - gevent.joinall([gevent.spawn(self.advertiseSsrcs)]) - - def advertiseSsrcs(self): - time.sleep(7) - print("SSRC spammer started") - while self.running: - ssrcMsg = ( - "%(nick)s" - % { - "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), - "nick": self.userId, - "assrc": self.ssrcs["audio"], - "vssrc": self.ssrcs["video"], - } - ) - res = self.sendIq(ssrcMsg) - print("reply from ssrc announce: ", res) - time.sleep(10) - - def xmppLoop(self): - self.matrixCallId = time.time() - res = self.xmppPoke( - "" - % (self.nextRid(), HOST) - ) - - print(res) - self.sid = res.body["sid"] - print("sid %s" % (self.sid)) - - res = self.sendIq( - "" - ) - - res = self.xmppPoke( - "" - % (self.nextRid(), self.sid, HOST) - ) - - res = self.sendIq( - "" - ) - print(res) - - self.jid = res.body.iq.bind.jid.string - print("jid: %s" % (self.jid)) - self.shortJid = self.jid.split("-")[0] - - res = self.sendIq( - "" - ) - - # randomthing = res.body.iq['to'] - # whatsitpart = randomthing.split('-')[0] - - # print "other random bind thing: %s" % (randomthing) - - # advertise preence to the jitsi room, with our nick - res = self.sendIq( - "%s" - % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId) - ) - self.muc = {"users": []} - for p in res.body.findAll("presence"): - u = {} - u["shortJid"] = p["from"].split("/")[1] - if p.c and p.c.nick: - u["nick"] = p.c.nick.string - self.muc["users"].append(u) - print("muc: ", self.muc) - - # wait for stuff - while True: - print("waiting...") - res = self.sendIq("") - print("got from stream: ", res) - if res.body.iq: - jingles = res.body.iq.findAll("jingle") - if len(jingles): - self.callfrom = res.body.iq["from"] - self.handleInvite(jingles[0]) - elif "type" in res.body and res.body["type"] == "terminate": - self.running = False - del xmppClients[self.matrixRoom] - return - - def handleInvite(self, jingle): - self.initiator = jingle["initiator"] - self.callsid = jingle["sid"] - p = subprocess.Popen( - ["node", "unjingle/unjingle.js", "--jingle"], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - print("raw jingle invite", str(jingle)) - sdp, out_err = p.communicate(str(jingle)) - print("transformed remote offer sdp", sdp) - inviteEvent = { - "offer": {"type": "offer", "sdp": sdp}, - "call_id": self.matrixCallId, - "version": 0, - "lifetime": 30000, - } - matrixCli.sendEvent(self.matrixRoom, "m.call.invite", inviteEvent) - - -matrixCli = TrivialMatrixClient(ACCESS_TOKEN) # Undefined name - 
-gevent.joinall([gevent.spawn(matrixLoop)]) diff --git a/contrib/jitsimeetbridge/syweb-jitsi-conference.patch b/contrib/jitsimeetbridge/syweb-jitsi-conference.patch deleted file mode 100644 index aed23c78aa..0000000000 --- a/contrib/jitsimeetbridge/syweb-jitsi-conference.patch +++ /dev/null @@ -1,188 +0,0 @@ -diff --git a/syweb/webclient/app/components/matrix/matrix-call.js b/syweb/webclient/app/components/matrix/matrix-call.js -index 9fbfff0..dc68077 100644 ---- a/syweb/webclient/app/components/matrix/matrix-call.js -+++ b/syweb/webclient/app/components/matrix/matrix-call.js -@@ -16,6 +16,45 @@ limitations under the License. - - 'use strict'; - -+ -+function sendKeyframe(pc) { -+ console.log('sendkeyframe', pc.iceConnectionState); -+ if (pc.iceConnectionState !== 'connected') return; // safe... -+ pc.setRemoteDescription( -+ pc.remoteDescription, -+ function () { -+ pc.createAnswer( -+ function (modifiedAnswer) { -+ pc.setLocalDescription( -+ modifiedAnswer, -+ function () { -+ // noop -+ }, -+ function (error) { -+ console.log('triggerKeyframe setLocalDescription failed', error); -+ messageHandler.showError(); -+ } -+ ); -+ }, -+ function (error) { -+ console.log('triggerKeyframe createAnswer failed', error); -+ messageHandler.showError(); -+ } -+ ); -+ }, -+ function (error) { -+ console.log('triggerKeyframe setRemoteDescription failed', error); -+ messageHandler.showError(); -+ } -+ ); -+} -+ -+ -+ -+ -+ -+ -+ - var forAllVideoTracksOnStream = function(s, f) { - var tracks = s.getVideoTracks(); - for (var i = 0; i < tracks.length; i++) { -@@ -83,7 +122,7 @@ angular.module('MatrixCall', []) - } - - // FIXME: we should prevent any calls from being placed or accepted before this has finished -- MatrixCall.getTurnServer(); -+ //MatrixCall.getTurnServer(); - - MatrixCall.CALL_TIMEOUT = 60000; - MatrixCall.FALLBACK_STUN_SERVER = 'stun:stun.l.google.com:19302'; -@@ -132,6 +171,22 @@ angular.module('MatrixCall', []) - pc.onsignalingstatechange = function() { self.onSignallingStateChanged(); }; - pc.onicecandidate = function(c) { self.gotLocalIceCandidate(c); }; - pc.onaddstream = function(s) { self.onAddStream(s); }; -+ -+ var datachan = pc.createDataChannel('RTCDataChannel', { -+ reliable: false -+ }); -+ console.log("data chan: "+datachan); -+ datachan.onopen = function() { -+ console.log("data channel open"); -+ }; -+ datachan.onmessage = function() { -+ console.log("data channel message"); -+ }; -+ pc.ondatachannel = function(event) { -+ console.log("have data channel"); -+ event.channel.binaryType = 'blob'; -+ }; -+ - return pc; - } - -@@ -200,6 +255,12 @@ angular.module('MatrixCall', []) - }, this.msg.lifetime - event.age); - }; - -+ MatrixCall.prototype.receivedInvite = function(event) { -+ console.log("Got second invite for call "+this.call_id); -+ this.peerConn.setRemoteDescription(new RTCSessionDescription(this.msg.offer), this.onSetRemoteDescriptionSuccess, this.onSetRemoteDescriptionError); -+ }; -+ -+ - // perverse as it may seem, sometimes we want to instantiate a call with a hangup message - // (because when getting the state of the room on load, events come in reverse order and - // we want to remember that a call has been hung up) -@@ -349,7 +410,7 @@ angular.module('MatrixCall', []) - 'mandatory': { - 'OfferToReceiveAudio': true, - 'OfferToReceiveVideo': this.type == 'video' -- }, -+ } - }; - this.peerConn.createAnswer(function(d) { self.createdAnswer(d); }, function(e) {}, constraints); - // This can't be in an apply() because it's called by a predecessor call under glare 
conditions :( -@@ -359,8 +420,20 @@ angular.module('MatrixCall', []) - MatrixCall.prototype.gotLocalIceCandidate = function(event) { - if (event.candidate) { - console.log("Got local ICE "+event.candidate.sdpMid+" candidate: "+event.candidate.candidate); -- this.sendCandidate(event.candidate); -- } -+ //this.sendCandidate(event.candidate); -+ } else { -+ console.log("have all candidates, sending answer"); -+ var content = { -+ version: 0, -+ call_id: this.call_id, -+ answer: this.peerConn.localDescription -+ }; -+ this.sendEventWithRetry('m.call.answer', content); -+ var self = this; -+ $rootScope.$apply(function() { -+ self.state = 'connecting'; -+ }); -+ } - } - - MatrixCall.prototype.gotRemoteIceCandidate = function(cand) { -@@ -418,15 +491,6 @@ angular.module('MatrixCall', []) - console.log("Created answer: "+description); - var self = this; - this.peerConn.setLocalDescription(description, function() { -- var content = { -- version: 0, -- call_id: self.call_id, -- answer: self.peerConn.localDescription -- }; -- self.sendEventWithRetry('m.call.answer', content); -- $rootScope.$apply(function() { -- self.state = 'connecting'; -- }); - }, function() { console.log("Error setting local description!"); } ); - }; - -@@ -448,6 +512,9 @@ angular.module('MatrixCall', []) - $rootScope.$apply(function() { - self.state = 'connected'; - self.didConnect = true; -+ /*$timeout(function() { -+ sendKeyframe(self.peerConn); -+ }, 1000);*/ - }); - } else if (this.peerConn.iceConnectionState == 'failed') { - this.hangup('ice_failed'); -@@ -518,6 +585,7 @@ angular.module('MatrixCall', []) - - MatrixCall.prototype.onRemoteStreamEnded = function(event) { - console.log("Remote stream ended"); -+ return; - var self = this; - $rootScope.$apply(function() { - self.state = 'ended'; -diff --git a/syweb/webclient/app/components/matrix/matrix-phone-service.js b/syweb/webclient/app/components/matrix/matrix-phone-service.js -index 55dbbf5..272fa27 100644 ---- a/syweb/webclient/app/components/matrix/matrix-phone-service.js -+++ b/syweb/webclient/app/components/matrix/matrix-phone-service.js -@@ -48,6 +48,13 @@ angular.module('matrixPhoneService', []) - return; - } - -+ // do we already have an entry for this call ID? 
-+ var existingEntry = matrixPhoneService.allCalls[msg.call_id]; -+ if (existingEntry) { -+ existingEntry.receivedInvite(msg); -+ return; -+ } -+ - var call = undefined; - if (!isLive) { - // if this event wasn't live then this call may already be over -@@ -108,7 +115,7 @@ angular.module('matrixPhoneService', []) - call.hangup(); - } - } else { -- $rootScope.$broadcast(matrixPhoneService.INCOMING_CALL_EVENT, call); -+ $rootScope.$broadcast(matrixPhoneService.INCOMING_CALL_EVENT, call); - } - } else if (event.type == 'm.call.answer') { - var call = matrixPhoneService.allCalls[msg.call_id]; diff --git a/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.js b/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.js deleted file mode 100644 index e99dd7bf96..0000000000 --- a/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.js +++ /dev/null @@ -1,712 +0,0 @@ -/* jshint -W117 */ -// SDP STUFF -function SDP(sdp) { - this.media = sdp.split('\r\nm='); - for (var i = 1; i < this.media.length; i++) { - this.media[i] = 'm=' + this.media[i]; - if (i != this.media.length - 1) { - this.media[i] += '\r\n'; - } - } - this.session = this.media.shift() + '\r\n'; - this.raw = this.session + this.media.join(''); -} - -exports.SDP = SDP; - -var jsdom = require("jsdom"); -var window = jsdom.jsdom().parentWindow; -var $ = require('jquery')(window); - -var SDPUtil = require('./strophe.jingle.sdp.util.js').SDPUtil; - -/** - * Returns map of MediaChannel mapped per channel idx. - */ -SDP.prototype.getMediaSsrcMap = function() { - var self = this; - var media_ssrcs = {}; - for (channelNum = 0; channelNum < self.media.length; channelNum++) { - modified = true; - tmp = SDPUtil.find_lines(self.media[channelNum], 'a=ssrc:'); - var type = SDPUtil.parse_mid(SDPUtil.find_line(self.media[channelNum], 'a=mid:')); - var channel = new MediaChannel(channelNum, type); - media_ssrcs[channelNum] = channel; - tmp.forEach(function (line) { - var linessrc = line.substring(7).split(' ')[0]; - // allocate new ChannelSsrc - if(!channel.ssrcs[linessrc]) { - channel.ssrcs[linessrc] = new ChannelSsrc(linessrc, type); - } - channel.ssrcs[linessrc].lines.push(line); - }); - tmp = SDPUtil.find_lines(self.media[channelNum], 'a=ssrc-group:'); - tmp.forEach(function(line){ - var semantics = line.substr(0, idx).substr(13); - var ssrcs = line.substr(14 + semantics.length).split(' '); - if (ssrcs.length != 0) { - var ssrcGroup = new ChannelSsrcGroup(semantics, ssrcs); - channel.ssrcGroups.push(ssrcGroup); - } - }); - } - return media_ssrcs; -}; -/** - * Returns true if this SDP contains given SSRC. - * @param ssrc the ssrc to check. - * @returns {boolean} true if this SDP contains given SSRC. - */ -SDP.prototype.containsSSRC = function(ssrc) { - var channels = this.getMediaSsrcMap(); - var contains = false; - Object.keys(channels).forEach(function(chNumber){ - var channel = channels[chNumber]; - //console.log("Check", channel, ssrc); - if(Object.keys(channel.ssrcs).indexOf(ssrc) != -1){ - contains = true; - } - }); - return contains; -}; - -/** - * Returns map of MediaChannel that contains only media not contained in otherSdp. Mapped by channel idx. - * @param otherSdp the other SDP to check ssrc with. - */ -SDP.prototype.getNewMedia = function(otherSdp) { - - // this could be useful in Array.prototype. 
- function arrayEquals(array) { - // if the other array is a falsy value, return - if (!array) - return false; - - // compare lengths - can save a lot of time - if (this.length != array.length) - return false; - - for (var i = 0, l=this.length; i < l; i++) { - // Check if we have nested arrays - if (this[i] instanceof Array && array[i] instanceof Array) { - // recurse into the nested arrays - if (!this[i].equals(array[i])) - return false; - } - else if (this[i] != array[i]) { - // Warning - two different object instances will never be equal: {x:20} != {x:20} - return false; - } - } - return true; - } - - var myMedia = this.getMediaSsrcMap(); - var othersMedia = otherSdp.getMediaSsrcMap(); - var newMedia = {}; - Object.keys(othersMedia).forEach(function(channelNum) { - var myChannel = myMedia[channelNum]; - var othersChannel = othersMedia[channelNum]; - if(!myChannel && othersChannel) { - // Add whole channel - newMedia[channelNum] = othersChannel; - return; - } - // Look for new ssrcs accross the channel - Object.keys(othersChannel.ssrcs).forEach(function(ssrc) { - if(Object.keys(myChannel.ssrcs).indexOf(ssrc) === -1) { - // Allocate channel if we've found ssrc that doesn't exist in our channel - if(!newMedia[channelNum]){ - newMedia[channelNum] = new MediaChannel(othersChannel.chNumber, othersChannel.mediaType); - } - newMedia[channelNum].ssrcs[ssrc] = othersChannel.ssrcs[ssrc]; - } - }); - - // Look for new ssrc groups across the channels - othersChannel.ssrcGroups.forEach(function(otherSsrcGroup){ - - // try to match the other ssrc-group with an ssrc-group of ours - var matched = false; - for (var i = 0; i < myChannel.ssrcGroups.length; i++) { - var mySsrcGroup = myChannel.ssrcGroups[i]; - if (otherSsrcGroup.semantics == mySsrcGroup.semantics - && arrayEquals.apply(otherSsrcGroup.ssrcs, [mySsrcGroup.ssrcs])) { - - matched = true; - break; - } - } - - if (!matched) { - // Allocate channel if we've found an ssrc-group that doesn't - // exist in our channel - - if(!newMedia[channelNum]){ - newMedia[channelNum] = new MediaChannel(othersChannel.chNumber, othersChannel.mediaType); - } - newMedia[channelNum].ssrcGroups.push(otherSsrcGroup); - } - }); - }); - return newMedia; -}; - -// remove iSAC and CN from SDP -SDP.prototype.mangle = function () { - var i, j, mline, lines, rtpmap, newdesc; - for (i = 0; i < this.media.length; i++) { - lines = this.media[i].split('\r\n'); - lines.pop(); // remove empty last element - mline = SDPUtil.parse_mline(lines.shift()); - if (mline.media != 'audio') - continue; - newdesc = ''; - mline.fmt.length = 0; - for (j = 0; j < lines.length; j++) { - if (lines[j].substr(0, 9) == 'a=rtpmap:') { - rtpmap = SDPUtil.parse_rtpmap(lines[j]); - if (rtpmap.name == 'CN' || rtpmap.name == 'ISAC') - continue; - mline.fmt.push(rtpmap.id); - newdesc += lines[j] + '\r\n'; - } else { - newdesc += lines[j] + '\r\n'; - } - } - this.media[i] = SDPUtil.build_mline(mline) + '\r\n'; - this.media[i] += newdesc; - } - this.raw = this.session + this.media.join(''); -}; - -// remove lines matching prefix from session section -SDP.prototype.removeSessionLines = function(prefix) { - var self = this; - var lines = SDPUtil.find_lines(this.session, prefix); - lines.forEach(function(line) { - self.session = self.session.replace(line + '\r\n', ''); - }); - this.raw = this.session + this.media.join(''); - return lines; -} -// remove lines matching prefix from a media section specified by mediaindex -// TODO: non-numeric mediaindex could match mid -SDP.prototype.removeMediaLines = 
function(mediaindex, prefix) { - var self = this; - var lines = SDPUtil.find_lines(this.media[mediaindex], prefix); - lines.forEach(function(line) { - self.media[mediaindex] = self.media[mediaindex].replace(line + '\r\n', ''); - }); - this.raw = this.session + this.media.join(''); - return lines; -} - -// add content's to a jingle element -SDP.prototype.toJingle = function (elem, thecreator) { - var i, j, k, mline, ssrc, rtpmap, tmp, line, lines; - var self = this; - // new bundle plan - if (SDPUtil.find_line(this.session, 'a=group:')) { - lines = SDPUtil.find_lines(this.session, 'a=group:'); - for (i = 0; i < lines.length; i++) { - tmp = lines[i].split(' '); - var semantics = tmp.shift().substr(8); - elem.c('group', {xmlns: 'urn:xmpp:jingle:apps:grouping:0', semantics:semantics}); - for (j = 0; j < tmp.length; j++) { - elem.c('content', {name: tmp[j]}).up(); - } - elem.up(); - } - } - // old bundle plan, to be removed - var bundle = []; - if (SDPUtil.find_line(this.session, 'a=group:BUNDLE')) { - bundle = SDPUtil.find_line(this.session, 'a=group:BUNDLE ').split(' '); - bundle.shift(); - } - for (i = 0; i < this.media.length; i++) { - mline = SDPUtil.parse_mline(this.media[i].split('\r\n')[0]); - if (!(mline.media === 'audio' || - mline.media === 'video' || - mline.media === 'application')) - { - continue; - } - if (SDPUtil.find_line(this.media[i], 'a=ssrc:')) { - ssrc = SDPUtil.find_line(this.media[i], 'a=ssrc:').substring(7).split(' ')[0]; // take the first - } else { - ssrc = false; - } - - elem.c('content', {creator: thecreator, name: mline.media}); - if (SDPUtil.find_line(this.media[i], 'a=mid:')) { - // prefer identifier from a=mid if present - var mid = SDPUtil.parse_mid(SDPUtil.find_line(this.media[i], 'a=mid:')); - elem.attrs({ name: mid }); - - // old BUNDLE plan, to be removed - if (bundle.indexOf(mid) !== -1) { - elem.c('bundle', {xmlns: 'http://estos.de/ns/bundle'}).up(); - bundle.splice(bundle.indexOf(mid), 1); - } - } - - if (SDPUtil.find_line(this.media[i], 'a=rtpmap:').length) - { - elem.c('description', - {xmlns: 'urn:xmpp:jingle:apps:rtp:1', - media: mline.media }); - if (ssrc) { - elem.attrs({ssrc: ssrc}); - } - for (j = 0; j < mline.fmt.length; j++) { - rtpmap = SDPUtil.find_line(this.media[i], 'a=rtpmap:' + mline.fmt[j]); - elem.c('payload-type', SDPUtil.parse_rtpmap(rtpmap)); - // put any 'a=fmtp:' + mline.fmt[j] lines into - if (SDPUtil.find_line(this.media[i], 'a=fmtp:' + mline.fmt[j])) { - tmp = SDPUtil.parse_fmtp(SDPUtil.find_line(this.media[i], 'a=fmtp:' + mline.fmt[j])); - for (k = 0; k < tmp.length; k++) { - elem.c('parameter', tmp[k]).up(); - } - } - this.RtcpFbToJingle(i, elem, mline.fmt[j]); // XEP-0293 -- map a=rtcp-fb - - elem.up(); - } - if (SDPUtil.find_line(this.media[i], 'a=crypto:', this.session)) { - elem.c('encryption', {required: 1}); - var crypto = SDPUtil.find_lines(this.media[i], 'a=crypto:', this.session); - crypto.forEach(function(line) { - elem.c('crypto', SDPUtil.parse_crypto(line)).up(); - }); - elem.up(); // end of encryption - } - - if (ssrc) { - // new style mapping - elem.c('source', { ssrc: ssrc, xmlns: 'urn:xmpp:jingle:apps:rtp:ssma:0' }); - // FIXME: group by ssrc and support multiple different ssrcs - var ssrclines = SDPUtil.find_lines(this.media[i], 'a=ssrc:'); - ssrclines.forEach(function(line) { - idx = line.indexOf(' '); - var linessrc = line.substr(0, idx).substr(7); - if (linessrc != ssrc) { - elem.up(); - ssrc = linessrc; - elem.c('source', { ssrc: ssrc, xmlns: 'urn:xmpp:jingle:apps:rtp:ssma:0' }); - } - var kv = 
line.substr(idx + 1); - elem.c('parameter'); - if (kv.indexOf(':') == -1) { - elem.attrs({ name: kv }); - } else { - elem.attrs({ name: kv.split(':', 2)[0] }); - elem.attrs({ value: kv.split(':', 2)[1] }); - } - elem.up(); - }); - elem.up(); - - // old proprietary mapping, to be removed at some point - tmp = SDPUtil.parse_ssrc(this.media[i]); - tmp.xmlns = 'http://estos.de/ns/ssrc'; - tmp.ssrc = ssrc; - elem.c('ssrc', tmp).up(); // ssrc is part of description - - // XEP-0339 handle ssrc-group attributes - var ssrc_group_lines = SDPUtil.find_lines(this.media[i], 'a=ssrc-group:'); - ssrc_group_lines.forEach(function(line) { - idx = line.indexOf(' '); - var semantics = line.substr(0, idx).substr(13); - var ssrcs = line.substr(14 + semantics.length).split(' '); - if (ssrcs.length != 0) { - elem.c('ssrc-group', { semantics: semantics, xmlns: 'urn:xmpp:jingle:apps:rtp:ssma:0' }); - ssrcs.forEach(function(ssrc) { - elem.c('source', { ssrc: ssrc }) - .up(); - }); - elem.up(); - } - }); - } - - if (SDPUtil.find_line(this.media[i], 'a=rtcp-mux')) { - elem.c('rtcp-mux').up(); - } - - // XEP-0293 -- map a=rtcp-fb:* - this.RtcpFbToJingle(i, elem, '*'); - - // XEP-0294 - if (SDPUtil.find_line(this.media[i], 'a=extmap:')) { - lines = SDPUtil.find_lines(this.media[i], 'a=extmap:'); - for (j = 0; j < lines.length; j++) { - tmp = SDPUtil.parse_extmap(lines[j]); - elem.c('rtp-hdrext', { xmlns: 'urn:xmpp:jingle:apps:rtp:rtp-hdrext:0', - uri: tmp.uri, - id: tmp.value }); - if (tmp.hasOwnProperty('direction')) { - switch (tmp.direction) { - case 'sendonly': - elem.attrs({senders: 'responder'}); - break; - case 'recvonly': - elem.attrs({senders: 'initiator'}); - break; - case 'sendrecv': - elem.attrs({senders: 'both'}); - break; - case 'inactive': - elem.attrs({senders: 'none'}); - break; - } - } - // TODO: handle params - elem.up(); - } - } - elem.up(); // end of description - } - - // map ice-ufrag/pwd, dtls fingerprint, candidates - this.TransportToJingle(i, elem); - - if (SDPUtil.find_line(this.media[i], 'a=sendrecv', this.session)) { - elem.attrs({senders: 'both'}); - } else if (SDPUtil.find_line(this.media[i], 'a=sendonly', this.session)) { - elem.attrs({senders: 'initiator'}); - } else if (SDPUtil.find_line(this.media[i], 'a=recvonly', this.session)) { - elem.attrs({senders: 'responder'}); - } else if (SDPUtil.find_line(this.media[i], 'a=inactive', this.session)) { - elem.attrs({senders: 'none'}); - } - if (mline.port == '0') { - // estos hack to reject an m-line - elem.attrs({senders: 'rejected'}); - } - elem.up(); // end of content - } - elem.up(); - return elem; -}; - -SDP.prototype.TransportToJingle = function (mediaindex, elem) { - var i = mediaindex; - var tmp; - var self = this; - elem.c('transport'); - - // XEP-0343 DTLS/SCTP - if (SDPUtil.find_line(this.media[mediaindex], 'a=sctpmap:').length) - { - var sctpmap = SDPUtil.find_line( - this.media[i], 'a=sctpmap:', self.session); - if (sctpmap) - { - var sctpAttrs = SDPUtil.parse_sctpmap(sctpmap); - elem.c('sctpmap', - { - xmlns: 'urn:xmpp:jingle:transports:dtls-sctp:1', - number: sctpAttrs[0], /* SCTP port */ - protocol: sctpAttrs[1], /* protocol */ - }); - // Optional stream count attribute - if (sctpAttrs.length > 2) - elem.attrs({ streams: sctpAttrs[2]}); - elem.up(); - } - } - // XEP-0320 - var fingerprints = SDPUtil.find_lines(this.media[mediaindex], 'a=fingerprint:', this.session); - fingerprints.forEach(function(line) { - tmp = SDPUtil.parse_fingerprint(line); - tmp.xmlns = 'urn:xmpp:jingle:apps:dtls:0'; - 
elem.c('fingerprint').t(tmp.fingerprint); - delete tmp.fingerprint; - line = SDPUtil.find_line(self.media[mediaindex], 'a=setup:', self.session); - if (line) { - tmp.setup = line.substr(8); - } - elem.attrs(tmp); - elem.up(); // end of fingerprint - }); - tmp = SDPUtil.iceparams(this.media[mediaindex], this.session); - if (tmp) { - tmp.xmlns = 'urn:xmpp:jingle:transports:ice-udp:1'; - elem.attrs(tmp); - // XEP-0176 - if (SDPUtil.find_line(this.media[mediaindex], 'a=candidate:', this.session)) { // add any a=candidate lines - var lines = SDPUtil.find_lines(this.media[mediaindex], 'a=candidate:', this.session); - lines.forEach(function (line) { - elem.c('candidate', SDPUtil.candidateToJingle(line)).up(); - }); - } - } - elem.up(); // end of transport -} - -SDP.prototype.RtcpFbToJingle = function (mediaindex, elem, payloadtype) { // XEP-0293 - var lines = SDPUtil.find_lines(this.media[mediaindex], 'a=rtcp-fb:' + payloadtype); - lines.forEach(function (line) { - var tmp = SDPUtil.parse_rtcpfb(line); - if (tmp.type == 'trr-int') { - elem.c('rtcp-fb-trr-int', {xmlns: 'urn:xmpp:jingle:apps:rtp:rtcp-fb:0', value: tmp.params[0]}); - elem.up(); - } else { - elem.c('rtcp-fb', {xmlns: 'urn:xmpp:jingle:apps:rtp:rtcp-fb:0', type: tmp.type}); - if (tmp.params.length > 0) { - elem.attrs({'subtype': tmp.params[0]}); - } - elem.up(); - } - }); -}; - -SDP.prototype.RtcpFbFromJingle = function (elem, payloadtype) { // XEP-0293 - var media = ''; - var tmp = elem.find('>rtcp-fb-trr-int[xmlns="urn:xmpp:jingle:apps:rtp:rtcp-fb:0"]'); - if (tmp.length) { - media += 'a=rtcp-fb:' + '*' + ' ' + 'trr-int' + ' '; - if (tmp.attr('value')) { - media += tmp.attr('value'); - } else { - media += '0'; - } - media += '\r\n'; - } - tmp = elem.find('>rtcp-fb[xmlns="urn:xmpp:jingle:apps:rtp:rtcp-fb:0"]'); - tmp.each(function () { - media += 'a=rtcp-fb:' + payloadtype + ' ' + $(this).attr('type'); - if ($(this).attr('subtype')) { - media += ' ' + $(this).attr('subtype'); - } - media += '\r\n'; - }); - return media; -}; - -// construct an SDP from a jingle stanza -SDP.prototype.fromJingle = function (jingle) { - var self = this; - this.raw = 'v=0\r\n' + - 'o=- ' + '1923518516' + ' 2 IN IP4 0.0.0.0\r\n' +// FIXME - 's=-\r\n' + - 't=0 0\r\n'; - // http://tools.ietf.org/html/draft-ietf-mmusic-sdp-bundle-negotiation-04#section-8 - if ($(jingle).find('>group[xmlns="urn:xmpp:jingle:apps:grouping:0"]').length) { - $(jingle).find('>group[xmlns="urn:xmpp:jingle:apps:grouping:0"]').each(function (idx, group) { - var contents = $(group).find('>content').map(function (idx, content) { - return content.getAttribute('name'); - }).get(); - if (contents.length > 0) { - self.raw += 'a=group:' + (group.getAttribute('semantics') || group.getAttribute('type')) + ' ' + contents.join(' ') + '\r\n'; - } - }); - } else if ($(jingle).find('>group[xmlns="urn:ietf:rfc:5888"]').length) { - // temporary namespace, not to be used. to be removed soon. 
- $(jingle).find('>group[xmlns="urn:ietf:rfc:5888"]').each(function (idx, group) { - var contents = $(group).find('>content').map(function (idx, content) { - return content.getAttribute('name'); - }).get(); - if (group.getAttribute('type') !== null && contents.length > 0) { - self.raw += 'a=group:' + group.getAttribute('type') + ' ' + contents.join(' ') + '\r\n'; - } - }); - } else { - // for backward compability, to be removed soon - // assume all contents are in the same bundle group, can be improved upon later - var bundle = $(jingle).find('>content').filter(function (idx, content) { - //elem.c('bundle', {xmlns:'http://estos.de/ns/bundle'}); - return $(content).find('>bundle').length > 0; - }).map(function (idx, content) { - return content.getAttribute('name'); - }).get(); - if (bundle.length) { - this.raw += 'a=group:BUNDLE ' + bundle.join(' ') + '\r\n'; - } - } - - this.session = this.raw; - jingle.find('>content').each(function () { - var m = self.jingle2media($(this)); - self.media.push(m); - }); - - // reconstruct msid-semantic -- apparently not necessary - /* - var msid = SDPUtil.parse_ssrc(this.raw); - if (msid.hasOwnProperty('mslabel')) { - this.session += "a=msid-semantic: WMS " + msid.mslabel + "\r\n"; - } - */ - - this.raw = this.session + this.media.join(''); -}; - -// translate a jingle content element into an an SDP media part -SDP.prototype.jingle2media = function (content) { - var media = '', - desc = content.find('description'), - ssrc = desc.attr('ssrc'), - self = this, - tmp; - var sctp = content.find( - '>transport>sctpmap[xmlns="urn:xmpp:jingle:transports:dtls-sctp:1"]'); - - tmp = { media: desc.attr('media') }; - tmp.port = '1'; - if (content.attr('senders') == 'rejected') { - // estos hack to reject an m-line. - tmp.port = '0'; - } - if (content.find('>transport>fingerprint').length || desc.find('encryption').length) { - if (sctp.length) - tmp.proto = 'DTLS/SCTP'; - else - tmp.proto = 'RTP/SAVPF'; - } else { - tmp.proto = 'RTP/AVPF'; - } - if (!sctp.length) - { - tmp.fmt = desc.find('payload-type').map( - function () { return this.getAttribute('id'); }).get(); - media += SDPUtil.build_mline(tmp) + '\r\n'; - } - else - { - media += 'm=application 1 DTLS/SCTP ' + sctp.attr('number') + '\r\n'; - media += 'a=sctpmap:' + sctp.attr('number') + - ' ' + sctp.attr('protocol'); - - var streamCount = sctp.attr('streams'); - if (streamCount) - media += ' ' + streamCount + '\r\n'; - else - media += '\r\n'; - } - - media += 'c=IN IP4 0.0.0.0\r\n'; - if (!sctp.length) - media += 'a=rtcp:1 IN IP4 0.0.0.0\r\n'; - //tmp = content.find('>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]'); - tmp = content.find('>bundle>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]'); - //console.log('transports: '+content.find('>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]').length); - //console.log('bundle.transports: '+content.find('>bundle>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]').length); - //console.log("tmp fingerprint: "+tmp.find('>fingerprint').innerHTML); - if (tmp.length) { - if (tmp.attr('ufrag')) { - media += SDPUtil.build_iceufrag(tmp.attr('ufrag')) + '\r\n'; - } - if (tmp.attr('pwd')) { - media += SDPUtil.build_icepwd(tmp.attr('pwd')) + '\r\n'; - } - tmp.find('>fingerprint').each(function () { - // FIXME: check namespace at some point - media += 'a=fingerprint:' + this.getAttribute('hash'); - media += ' ' + $(this).text(); - media += '\r\n'; - //console.log("mline "+media); - if (this.getAttribute('setup')) { - media += 'a=setup:' + 
this.getAttribute('setup') + '\r\n'; - } - }); - } - switch (content.attr('senders')) { - case 'initiator': - media += 'a=sendonly\r\n'; - break; - case 'responder': - media += 'a=recvonly\r\n'; - break; - case 'none': - media += 'a=inactive\r\n'; - break; - case 'both': - media += 'a=sendrecv\r\n'; - break; - } - media += 'a=mid:' + content.attr('name') + '\r\n'; - /*if (content.attr('name') == 'video') { - media += 'a=x-google-flag:conference' + '\r\n'; - }*/ - - // - // see http://code.google.com/p/libjingle/issues/detail?id=309 -- no spec though - // and http://mail.jabber.org/pipermail/jingle/2011-December/001761.html - if (desc.find('rtcp-mux').length) { - media += 'a=rtcp-mux\r\n'; - } - - if (desc.find('encryption').length) { - desc.find('encryption>crypto').each(function () { - media += 'a=crypto:' + this.getAttribute('tag'); - media += ' ' + this.getAttribute('crypto-suite'); - media += ' ' + this.getAttribute('key-params'); - if (this.getAttribute('session-params')) { - media += ' ' + this.getAttribute('session-params'); - } - media += '\r\n'; - }); - } - desc.find('payload-type').each(function () { - media += SDPUtil.build_rtpmap(this) + '\r\n'; - if ($(this).find('>parameter').length) { - media += 'a=fmtp:' + this.getAttribute('id') + ' '; - media += $(this).find('parameter').map(function () { return (this.getAttribute('name') ? (this.getAttribute('name') + '=') : '') + this.getAttribute('value'); }).get().join('; '); - media += '\r\n'; - } - // xep-0293 - media += self.RtcpFbFromJingle($(this), this.getAttribute('id')); - }); - - // xep-0293 - media += self.RtcpFbFromJingle(desc, '*'); - - // xep-0294 - tmp = desc.find('>rtp-hdrext[xmlns="urn:xmpp:jingle:apps:rtp:rtp-hdrext:0"]'); - tmp.each(function () { - media += 'a=extmap:' + this.getAttribute('id') + ' ' + this.getAttribute('uri') + '\r\n'; - }); - - content.find('>bundle>transport[xmlns="urn:xmpp:jingle:transports:ice-udp:1"]>candidate').each(function () { - media += SDPUtil.candidateFromJingle(this); - }); - - // XEP-0339 handle ssrc-group attributes - tmp = content.find('description>ssrc-group[xmlns="urn:xmpp:jingle:apps:rtp:ssma:0"]').each(function() { - var semantics = this.getAttribute('semantics'); - var ssrcs = $(this).find('>source').map(function() { - return this.getAttribute('ssrc'); - }).get(); - - if (ssrcs.length != 0) { - media += 'a=ssrc-group:' + semantics + ' ' + ssrcs.join(' ') + '\r\n'; - } - }); - - tmp = content.find('description>source[xmlns="urn:xmpp:jingle:apps:rtp:ssma:0"]'); - tmp.each(function () { - var ssrc = this.getAttribute('ssrc'); - $(this).find('>parameter').each(function () { - media += 'a=ssrc:' + ssrc + ' ' + this.getAttribute('name'); - if (this.getAttribute('value') && this.getAttribute('value').length) - media += ':' + this.getAttribute('value'); - media += '\r\n'; - }); - }); - - if (tmp.length === 0) { - // fallback to proprietary mapping of a=ssrc lines - tmp = content.find('description>ssrc[xmlns="http://estos.de/ns/ssrc"]'); - if (tmp.length) { - media += 'a=ssrc:' + ssrc + ' cname:' + tmp.attr('cname') + '\r\n'; - media += 'a=ssrc:' + ssrc + ' msid:' + tmp.attr('msid') + '\r\n'; - media += 'a=ssrc:' + ssrc + ' mslabel:' + tmp.attr('mslabel') + '\r\n'; - media += 'a=ssrc:' + ssrc + ' label:' + tmp.attr('label') + '\r\n'; - } - } - return media; -}; - diff --git a/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.util.js b/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.util.js deleted file mode 100644 index 042a123c32..0000000000 --- 
a/contrib/jitsimeetbridge/unjingle/strophe.jingle.sdp.util.js +++ /dev/null @@ -1,408 +0,0 @@ -/** - * Contains utility classes used in SDP class. - * - */ - -/** - * Class holds a=ssrc lines and media type a=mid - * @param ssrc synchronization source identifier number(a=ssrc lines from SDP) - * @param type media type eg. "audio" or "video"(a=mid frm SDP) - * @constructor - */ -function ChannelSsrc(ssrc, type) { - this.ssrc = ssrc; - this.type = type; - this.lines = []; -} - -/** - * Class holds a=ssrc-group: lines - * @param semantics - * @param ssrcs - * @constructor - */ -function ChannelSsrcGroup(semantics, ssrcs, line) { - this.semantics = semantics; - this.ssrcs = ssrcs; -} - -/** - * Helper class represents media channel. Is a container for ChannelSsrc, holds channel idx and media type. - * @param channelNumber channel idx in SDP media array. - * @param mediaType media type(a=mid) - * @constructor - */ -function MediaChannel(channelNumber, mediaType) { - /** - * SDP channel number - * @type {*} - */ - this.chNumber = channelNumber; - /** - * Channel media type(a=mid) - * @type {*} - */ - this.mediaType = mediaType; - /** - * The maps of ssrc numbers to ChannelSsrc objects. - */ - this.ssrcs = {}; - - /** - * The array of ChannelSsrcGroup objects. - * @type {Array} - */ - this.ssrcGroups = []; -} - -SDPUtil = { - iceparams: function (mediadesc, sessiondesc) { - var data = null; - if (SDPUtil.find_line(mediadesc, 'a=ice-ufrag:', sessiondesc) && - SDPUtil.find_line(mediadesc, 'a=ice-pwd:', sessiondesc)) { - data = { - ufrag: SDPUtil.parse_iceufrag(SDPUtil.find_line(mediadesc, 'a=ice-ufrag:', sessiondesc)), - pwd: SDPUtil.parse_icepwd(SDPUtil.find_line(mediadesc, 'a=ice-pwd:', sessiondesc)) - }; - } - return data; - }, - parse_iceufrag: function (line) { - return line.substring(12); - }, - build_iceufrag: function (frag) { - return 'a=ice-ufrag:' + frag; - }, - parse_icepwd: function (line) { - return line.substring(10); - }, - build_icepwd: function (pwd) { - return 'a=ice-pwd:' + pwd; - }, - parse_mid: function (line) { - return line.substring(6); - }, - parse_mline: function (line) { - var parts = line.substring(2).split(' '), - data = {}; - data.media = parts.shift(); - data.port = parts.shift(); - data.proto = parts.shift(); - if (parts[parts.length - 1] === '') { // trailing whitespace - parts.pop(); - } - data.fmt = parts; - return data; - }, - build_mline: function (mline) { - return 'm=' + mline.media + ' ' + mline.port + ' ' + mline.proto + ' ' + mline.fmt.join(' '); - }, - parse_rtpmap: function (line) { - var parts = line.substring(9).split(' '), - data = {}; - data.id = parts.shift(); - parts = parts[0].split('/'); - data.name = parts.shift(); - data.clockrate = parts.shift(); - data.channels = parts.length ? parts.shift() : '1'; - return data; - }, - /** - * Parses SDP line "a=sctpmap:..." and extracts SCTP port from it. - * @param line eg. "a=sctpmap:5000 webrtc-datachannel" - * @returns [SCTP port number, protocol, streams] - */ - parse_sctpmap: function (line) - { - var parts = line.substring(10).split(' '); - var sctpPort = parts[0]; - var protocol = parts[1]; - // Stream count is optional - var streamCount = parts.length > 2 ? 
parts[2] : null; - return [sctpPort, protocol, streamCount];// SCTP port - }, - build_rtpmap: function (el) { - var line = 'a=rtpmap:' + el.getAttribute('id') + ' ' + el.getAttribute('name') + '/' + el.getAttribute('clockrate'); - if (el.getAttribute('channels') && el.getAttribute('channels') != '1') { - line += '/' + el.getAttribute('channels'); - } - return line; - }, - parse_crypto: function (line) { - var parts = line.substring(9).split(' '), - data = {}; - data.tag = parts.shift(); - data['crypto-suite'] = parts.shift(); - data['key-params'] = parts.shift(); - if (parts.length) { - data['session-params'] = parts.join(' '); - } - return data; - }, - parse_fingerprint: function (line) { // RFC 4572 - var parts = line.substring(14).split(' '), - data = {}; - data.hash = parts.shift(); - data.fingerprint = parts.shift(); - // TODO assert that fingerprint satisfies 2UHEX *(":" 2UHEX) ? - return data; - }, - parse_fmtp: function (line) { - var parts = line.split(' '), - i, key, value, - data = []; - parts.shift(); - parts = parts.join(' ').split(';'); - for (i = 0; i < parts.length; i++) { - key = parts[i].split('=')[0]; - while (key.length && key[0] == ' ') { - key = key.substring(1); - } - value = parts[i].split('=')[1]; - if (key && value) { - data.push({name: key, value: value}); - } else if (key) { - // rfc 4733 (DTMF) style stuff - data.push({name: '', value: key}); - } - } - return data; - }, - parse_icecandidate: function (line) { - var candidate = {}, - elems = line.split(' '); - candidate.foundation = elems[0].substring(12); - candidate.component = elems[1]; - candidate.protocol = elems[2].toLowerCase(); - candidate.priority = elems[3]; - candidate.ip = elems[4]; - candidate.port = elems[5]; - // elems[6] => "typ" - candidate.type = elems[7]; - candidate.generation = 0; // default value, may be overwritten below - for (var i = 8; i < elems.length; i += 2) { - switch (elems[i]) { - case 'raddr': - candidate['rel-addr'] = elems[i + 1]; - break; - case 'rport': - candidate['rel-port'] = elems[i + 1]; - break; - case 'generation': - candidate.generation = elems[i + 1]; - break; - case 'tcptype': - candidate.tcptype = elems[i + 1]; - break; - default: // TODO - console.log('parse_icecandidate not translating "' + elems[i] + '" = "' + elems[i + 1] + '"'); - } - } - candidate.network = '1'; - candidate.id = Math.random().toString(36).substr(2, 10); // not applicable to SDP -- FIXME: should be unique, not just random - return candidate; - }, - build_icecandidate: function (cand) { - var line = ['a=candidate:' + cand.foundation, cand.component, cand.protocol, cand.priority, cand.ip, cand.port, 'typ', cand.type].join(' '); - line += ' '; - switch (cand.type) { - case 'srflx': - case 'prflx': - case 'relay': - if (cand.hasOwnAttribute('rel-addr') && cand.hasOwnAttribute('rel-port')) { - line += 'raddr'; - line += ' '; - line += cand['rel-addr']; - line += ' '; - line += 'rport'; - line += ' '; - line += cand['rel-port']; - line += ' '; - } - break; - } - if (cand.hasOwnAttribute('tcptype')) { - line += 'tcptype'; - line += ' '; - line += cand.tcptype; - line += ' '; - } - line += 'generation'; - line += ' '; - line += cand.hasOwnAttribute('generation') ? cand.generation : '0'; - return line; - }, - parse_ssrc: function (desc) { - // proprietary mapping of a=ssrc lines - // TODO: see "Jingle RTP Source Description" by Juberti and P. 
Thatcher on google docs - // and parse according to that - var lines = desc.split('\r\n'), - data = {}; - for (var i = 0; i < lines.length; i++) { - if (lines[i].substring(0, 7) == 'a=ssrc:') { - var idx = lines[i].indexOf(' '); - data[lines[i].substr(idx + 1).split(':', 2)[0]] = lines[i].substr(idx + 1).split(':', 2)[1]; - } - } - return data; - }, - parse_rtcpfb: function (line) { - var parts = line.substr(10).split(' '); - var data = {}; - data.pt = parts.shift(); - data.type = parts.shift(); - data.params = parts; - return data; - }, - parse_extmap: function (line) { - var parts = line.substr(9).split(' '); - var data = {}; - data.value = parts.shift(); - if (data.value.indexOf('/') != -1) { - data.direction = data.value.substr(data.value.indexOf('/') + 1); - data.value = data.value.substr(0, data.value.indexOf('/')); - } else { - data.direction = 'both'; - } - data.uri = parts.shift(); - data.params = parts; - return data; - }, - find_line: function (haystack, needle, sessionpart) { - var lines = haystack.split('\r\n'); - for (var i = 0; i < lines.length; i++) { - if (lines[i].substring(0, needle.length) == needle) { - return lines[i]; - } - } - if (!sessionpart) { - return false; - } - // search session part - lines = sessionpart.split('\r\n'); - for (var j = 0; j < lines.length; j++) { - if (lines[j].substring(0, needle.length) == needle) { - return lines[j]; - } - } - return false; - }, - find_lines: function (haystack, needle, sessionpart) { - var lines = haystack.split('\r\n'), - needles = []; - for (var i = 0; i < lines.length; i++) { - if (lines[i].substring(0, needle.length) == needle) - needles.push(lines[i]); - } - if (needles.length || !sessionpart) { - return needles; - } - // search session part - lines = sessionpart.split('\r\n'); - for (var j = 0; j < lines.length; j++) { - if (lines[j].substring(0, needle.length) == needle) { - needles.push(lines[j]); - } - } - return needles; - }, - candidateToJingle: function (line) { - // a=candidate:2979166662 1 udp 2113937151 192.168.2.100 57698 typ host generation 0 - // - if (line.indexOf('candidate:') === 0) { - line = 'a=' + line; - } else if (line.substring(0, 12) != 'a=candidate:') { - console.log('parseCandidate called with a line that is not a candidate line'); - console.log(line); - return null; - } - if (line.substring(line.length - 2) == '\r\n') // chomp it - line = line.substring(0, line.length - 2); - var candidate = {}, - elems = line.split(' '), - i; - if (elems[6] != 'typ') { - console.log('did not find typ in the right place'); - console.log(line); - return null; - } - candidate.foundation = elems[0].substring(12); - candidate.component = elems[1]; - candidate.protocol = elems[2].toLowerCase(); - candidate.priority = elems[3]; - candidate.ip = elems[4]; - candidate.port = elems[5]; - // elems[6] => "typ" - candidate.type = elems[7]; - - candidate.generation = '0'; // default, may be overwritten below - for (i = 8; i < elems.length; i += 2) { - switch (elems[i]) { - case 'raddr': - candidate['rel-addr'] = elems[i + 1]; - break; - case 'rport': - candidate['rel-port'] = elems[i + 1]; - break; - case 'generation': - candidate.generation = elems[i + 1]; - break; - case 'tcptype': - candidate.tcptype = elems[i + 1]; - break; - default: // TODO - console.log('not translating "' + elems[i] + '" = "' + elems[i + 1] + '"'); - } - } - candidate.network = '1'; - candidate.id = Math.random().toString(36).substr(2, 10); // not applicable to SDP -- FIXME: should be unique, not just random - return candidate; - }, - 
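For readers skimming this removal: parse_icecandidate and candidateToJingle above both turn a raw SDP a=candidate: line into a plain object (candidateFromJingle below does the reverse, from a Jingle XML element). A minimal sketch of driving the parser, assuming the removed module is loaded from its contrib path (the require path is illustrative):

    // Illustrative only: exercises the helper deleted in this patch.
    var SDPUtil = require('./strophe.jingle.sdp.util.js').SDPUtil;

    // The sample host candidate quoted in the comment above.
    var line = 'a=candidate:2979166662 1 udp 2113937151 192.168.2.100 57698 typ host generation 0';

    var cand = SDPUtil.candidateToJingle(line);
    // cand => { foundation: '2979166662', component: '1', protocol: 'udp',
    //           priority: '2113937151', ip: '192.168.2.100', port: '57698',
    //           type: 'host', generation: '0', network: '1', id: (random) }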
candidateFromJingle: function (cand) { - var line = 'a=candidate:'; - line += cand.getAttribute('foundation'); - line += ' '; - line += cand.getAttribute('component'); - line += ' '; - line += cand.getAttribute('protocol'); //.toUpperCase(); // chrome M23 doesn't like this - line += ' '; - line += cand.getAttribute('priority'); - line += ' '; - line += cand.getAttribute('ip'); - line += ' '; - line += cand.getAttribute('port'); - line += ' '; - line += 'typ'; - line += ' ' + cand.getAttribute('type'); - line += ' '; - switch (cand.getAttribute('type')) { - case 'srflx': - case 'prflx': - case 'relay': - if (cand.getAttribute('rel-addr') && cand.getAttribute('rel-port')) { - line += 'raddr'; - line += ' '; - line += cand.getAttribute('rel-addr'); - line += ' '; - line += 'rport'; - line += ' '; - line += cand.getAttribute('rel-port'); - line += ' '; - } - break; - } - if (cand.getAttribute('protocol').toLowerCase() == 'tcp') { - line += 'tcptype'; - line += ' '; - line += cand.getAttribute('tcptype'); - line += ' '; - } - line += 'generation'; - line += ' '; - line += cand.getAttribute('generation') || '0'; - return line + '\r\n'; - } -}; - -exports.SDPUtil = SDPUtil; - diff --git a/contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js b/contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js deleted file mode 100644 index 9c45c2df18..0000000000 --- a/contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js +++ /dev/null @@ -1,254 +0,0 @@ -/** - * Wrapper for built-in http.js to emulate the browser XMLHttpRequest object. - * - * This can be used with JS designed for browsers to improve reuse of code and - * allow the use of existing libraries. - * - * Usage: include("XMLHttpRequest.js") and use XMLHttpRequest per W3C specs. - * - * @todo SSL Support - * @author Dan DeFelippi - * @license MIT - */ - -var Url = require("url") - ,sys = require("util"); - -exports.XMLHttpRequest = function() { - /** - * Private variables - */ - var self = this; - var http = require('http'); - var https = require('https'); - - // Holds http.js objects - var client; - var request; - var response; - - // Request settings - var settings = {}; - - // Set some default headers - var defaultHeaders = { - "User-Agent": "node.js", - "Accept": "*/*", - }; - - var headers = defaultHeaders; - - /** - * Constants - */ - this.UNSENT = 0; - this.OPENED = 1; - this.HEADERS_RECEIVED = 2; - this.LOADING = 3; - this.DONE = 4; - - /** - * Public vars - */ - // Current state - this.readyState = this.UNSENT; - - // default ready state change handler in case one is not set or is set late - this.onreadystatechange = function() {}; - - // Result & response - this.responseText = ""; - this.responseXML = ""; - this.status = null; - this.statusText = null; - - /** - * Open the connection. Currently supports local server requests. - * - * @param string method Connection method (eg GET, POST) - * @param string url URL for the connection. - * @param boolean async Asynchronous connection. Default is true. - * @param string user Username for basic authentication (optional) - * @param string password Password for basic authentication (optional) - */ - this.open = function(method, url, async, user, password) { - settings = { - "method": method, - "url": url, - "async": async || null, - "user": user || null, - "password": password || null - }; - - this.abort(); - - setState(this.OPENED); - }; - - /** - * Sets a header for the request. 
- * - * @param string header Header name - * @param string value Header value - */ - this.setRequestHeader = function(header, value) { - headers[header] = value; - }; - - /** - * Gets a header from the server response. - * - * @param string header Name of header to get. - * @return string Text of the header or null if it doesn't exist. - */ - this.getResponseHeader = function(header) { - if (this.readyState > this.OPENED && response.headers[header]) { - return header + ": " + response.headers[header]; - } - - return null; - }; - - /** - * Gets all the response headers. - * - * @return string - */ - this.getAllResponseHeaders = function() { - if (this.readyState < this.HEADERS_RECEIVED) { - throw "INVALID_STATE_ERR: Headers have not been received."; - } - var result = ""; - - for (var i in response.headers) { - result += i + ": " + response.headers[i] + "\r\n"; - } - return result.substr(0, result.length - 2); - }; - - /** - * Sends the request to the server. - * - * @param string data Optional data to send as request body. - */ - this.send = function(data) { - if (this.readyState != this.OPENED) { - throw "INVALID_STATE_ERR: connection must be opened before send() is called"; - } - - var ssl = false; - var url = Url.parse(settings.url); - - // Determine the server - switch (url.protocol) { - case 'https:': - ssl = true; - // SSL & non-SSL both need host, no break here. - case 'http:': - var host = url.hostname; - break; - - case undefined: - case '': - var host = "localhost"; - break; - - default: - throw "Protocol not supported."; - } - - // Default to port 80. If accessing localhost on another port be sure - // to use http://localhost:port/path - var port = url.port || (ssl ? 443 : 80); - // Add query string if one is used - var uri = url.pathname + (url.search ? url.search : ''); - - // Set the Host header or the server may reject the request - this.setRequestHeader("Host", host); - - // Set content length header - if (settings.method == "GET" || settings.method == "HEAD") { - data = null; - } else if (data) { - this.setRequestHeader("Content-Length", Buffer.byteLength(data)); - - if (!headers["Content-Type"]) { - this.setRequestHeader("Content-Type", "text/plain;charset=UTF-8"); - } - } - - // Use the proper protocol - var doRequest = ssl ? https.request : http.request; - - var options = { - host: host, - port: port, - path: uri, - method: settings.method, - headers: headers, - agent: false - }; - - var req = doRequest(options, function(res) { - response = res; - response.setEncoding("utf8"); - - setState(self.HEADERS_RECEIVED); - self.status = response.statusCode; - - response.on('data', function(chunk) { - // Make sure there's some data - if (chunk) { - self.responseText += chunk; - } - setState(self.LOADING); - }); - - response.on('end', function() { - setState(self.DONE); - }); - - response.on('error', function() { - self.handleError(error); - }); - }).on('error', function(error) { - self.handleError(error); - }); - - req.setHeader("Connection", "Close"); - - // Node 0.4 and later won't accept empty data. Make sure it's needed. - if (data) { - req.write(data); - } - - req.end(); - }; - - this.handleError = function(error) { - this.status = 503; - this.statusText = error; - this.responseText = error.stack; - setState(this.DONE); - }; - - /** - * Aborts a request. - */ - this.abort = function() { - headers = defaultHeaders; - this.readyState = this.UNSENT; - this.responseText = ""; - this.responseXML = ""; - }; - - /** - * Changes readyState and calls onreadystatechange. 
- * - * @param int state New state - */ - var setState = function(state) { - self.readyState = state; - self.onreadystatechange(); - } -}; diff --git a/contrib/jitsimeetbridge/unjingle/strophe/base64.js b/contrib/jitsimeetbridge/unjingle/strophe/base64.js deleted file mode 100644 index 418caac050..0000000000 --- a/contrib/jitsimeetbridge/unjingle/strophe/base64.js +++ /dev/null @@ -1,83 +0,0 @@ -// This code was written by Tyler Akins and has been placed in the -// public domain. It would be nice if you left this header intact. -// Base64 code from Tyler Akins -- http://rumkin.com - -var Base64 = (function () { - var keyStr = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; - - var obj = { - /** - * Encodes a string in base64 - * @param {String} input The string to encode in base64. - */ - encode: function (input) { - var output = ""; - var chr1, chr2, chr3; - var enc1, enc2, enc3, enc4; - var i = 0; - - do { - chr1 = input.charCodeAt(i++); - chr2 = input.charCodeAt(i++); - chr3 = input.charCodeAt(i++); - - enc1 = chr1 >> 2; - enc2 = ((chr1 & 3) << 4) | (chr2 >> 4); - enc3 = ((chr2 & 15) << 2) | (chr3 >> 6); - enc4 = chr3 & 63; - - if (isNaN(chr2)) { - enc3 = enc4 = 64; - } else if (isNaN(chr3)) { - enc4 = 64; - } - - output = output + keyStr.charAt(enc1) + keyStr.charAt(enc2) + - keyStr.charAt(enc3) + keyStr.charAt(enc4); - } while (i < input.length); - - return output; - }, - - /** - * Decodes a base64 string. - * @param {String} input The string to decode. - */ - decode: function (input) { - var output = ""; - var chr1, chr2, chr3; - var enc1, enc2, enc3, enc4; - var i = 0; - - // remove all characters that are not A-Z, a-z, 0-9, +, /, or = - input = input.replace(/[^A-Za-z0-9\+\/\=]/g, ''); - - do { - enc1 = keyStr.indexOf(input.charAt(i++)); - enc2 = keyStr.indexOf(input.charAt(i++)); - enc3 = keyStr.indexOf(input.charAt(i++)); - enc4 = keyStr.indexOf(input.charAt(i++)); - - chr1 = (enc1 << 2) | (enc2 >> 4); - chr2 = ((enc2 & 15) << 4) | (enc3 >> 2); - chr3 = ((enc3 & 3) << 6) | enc4; - - output = output + String.fromCharCode(chr1); - - if (enc3 != 64) { - output = output + String.fromCharCode(chr2); - } - if (enc4 != 64) { - output = output + String.fromCharCode(chr3); - } - } while (i < input.length); - - return output; - } - }; - - return obj; -})(); - -// Nodify -exports.Base64 = Base64; diff --git a/contrib/jitsimeetbridge/unjingle/strophe/md5.js b/contrib/jitsimeetbridge/unjingle/strophe/md5.js deleted file mode 100644 index 5334325e2f..0000000000 --- a/contrib/jitsimeetbridge/unjingle/strophe/md5.js +++ /dev/null @@ -1,279 +0,0 @@ -/* - * A JavaScript implementation of the RSA Data Security, Inc. MD5 Message - * Digest Algorithm, as defined in RFC 1321. - * Version 2.1 Copyright (C) Paul Johnston 1999 - 2002. - * Other contributors: Greg Holt, Andrew Kepert, Ydnar, Lostinet - * Distributed under the BSD License - * See http://pajhome.org.uk/crypt/md5 for more info. - */ - -var MD5 = (function () { - /* - * Configurable variables. You may need to tweak these to be compatible with - * the server-side, but the defaults work in most cases. - */ - var hexcase = 0; /* hex output format. 0 - lowercase; 1 - uppercase */ - var b64pad = ""; /* base-64 pad character. "=" for strict RFC compliance */ - var chrsz = 8; /* bits per input character. 8 - ASCII; 16 - Unicode */ - - /* - * Add integers, wrapping at 2^32. This uses 16-bit operations internally - * to work around bugs in some JS interpreters. 
- */ - var safe_add = function (x, y) { - var lsw = (x & 0xFFFF) + (y & 0xFFFF); - var msw = (x >> 16) + (y >> 16) + (lsw >> 16); - return (msw << 16) | (lsw & 0xFFFF); - }; - - /* - * Bitwise rotate a 32-bit number to the left. - */ - var bit_rol = function (num, cnt) { - return (num << cnt) | (num >>> (32 - cnt)); - }; - - /* - * Convert a string to an array of little-endian words - * If chrsz is ASCII, characters >255 have their hi-byte silently ignored. - */ - var str2binl = function (str) { - var bin = []; - var mask = (1 << chrsz) - 1; - for(var i = 0; i < str.length * chrsz; i += chrsz) - { - bin[i>>5] |= (str.charCodeAt(i / chrsz) & mask) << (i%32); - } - return bin; - }; - - /* - * Convert an array of little-endian words to a string - */ - var binl2str = function (bin) { - var str = ""; - var mask = (1 << chrsz) - 1; - for(var i = 0; i < bin.length * 32; i += chrsz) - { - str += String.fromCharCode((bin[i>>5] >>> (i % 32)) & mask); - } - return str; - }; - - /* - * Convert an array of little-endian words to a hex string. - */ - var binl2hex = function (binarray) { - var hex_tab = hexcase ? "0123456789ABCDEF" : "0123456789abcdef"; - var str = ""; - for(var i = 0; i < binarray.length * 4; i++) - { - str += hex_tab.charAt((binarray[i>>2] >> ((i%4)*8+4)) & 0xF) + - hex_tab.charAt((binarray[i>>2] >> ((i%4)*8 )) & 0xF); - } - return str; - }; - - /* - * Convert an array of little-endian words to a base-64 string - */ - var binl2b64 = function (binarray) { - var tab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - var str = ""; - var triplet, j; - for(var i = 0; i < binarray.length * 4; i += 3) - { - triplet = (((binarray[i >> 2] >> 8 * ( i %4)) & 0xFF) << 16) | - (((binarray[i+1 >> 2] >> 8 * ((i+1)%4)) & 0xFF) << 8 ) | - ((binarray[i+2 >> 2] >> 8 * ((i+2)%4)) & 0xFF); - for(j = 0; j < 4; j++) - { - if(i * 8 + j * 6 > binarray.length * 32) { str += b64pad; } - else { str += tab.charAt((triplet >> 6*(3-j)) & 0x3F); } - } - } - return str; - }; - - /* - * These functions implement the four basic operations the algorithm uses. 
- */ - var md5_cmn = function (q, a, b, x, s, t) { - return safe_add(bit_rol(safe_add(safe_add(a, q),safe_add(x, t)), s),b); - }; - - var md5_ff = function (a, b, c, d, x, s, t) { - return md5_cmn((b & c) | ((~b) & d), a, b, x, s, t); - }; - - var md5_gg = function (a, b, c, d, x, s, t) { - return md5_cmn((b & d) | (c & (~d)), a, b, x, s, t); - }; - - var md5_hh = function (a, b, c, d, x, s, t) { - return md5_cmn(b ^ c ^ d, a, b, x, s, t); - }; - - var md5_ii = function (a, b, c, d, x, s, t) { - return md5_cmn(c ^ (b | (~d)), a, b, x, s, t); - }; - - /* - * Calculate the MD5 of an array of little-endian words, and a bit length - */ - var core_md5 = function (x, len) { - /* append padding */ - x[len >> 5] |= 0x80 << ((len) % 32); - x[(((len + 64) >>> 9) << 4) + 14] = len; - - var a = 1732584193; - var b = -271733879; - var c = -1732584194; - var d = 271733878; - - var olda, oldb, oldc, oldd; - for (var i = 0; i < x.length; i += 16) - { - olda = a; - oldb = b; - oldc = c; - oldd = d; - - a = md5_ff(a, b, c, d, x[i+ 0], 7 , -680876936); - d = md5_ff(d, a, b, c, x[i+ 1], 12, -389564586); - c = md5_ff(c, d, a, b, x[i+ 2], 17, 606105819); - b = md5_ff(b, c, d, a, x[i+ 3], 22, -1044525330); - a = md5_ff(a, b, c, d, x[i+ 4], 7 , -176418897); - d = md5_ff(d, a, b, c, x[i+ 5], 12, 1200080426); - c = md5_ff(c, d, a, b, x[i+ 6], 17, -1473231341); - b = md5_ff(b, c, d, a, x[i+ 7], 22, -45705983); - a = md5_ff(a, b, c, d, x[i+ 8], 7 , 1770035416); - d = md5_ff(d, a, b, c, x[i+ 9], 12, -1958414417); - c = md5_ff(c, d, a, b, x[i+10], 17, -42063); - b = md5_ff(b, c, d, a, x[i+11], 22, -1990404162); - a = md5_ff(a, b, c, d, x[i+12], 7 , 1804603682); - d = md5_ff(d, a, b, c, x[i+13], 12, -40341101); - c = md5_ff(c, d, a, b, x[i+14], 17, -1502002290); - b = md5_ff(b, c, d, a, x[i+15], 22, 1236535329); - - a = md5_gg(a, b, c, d, x[i+ 1], 5 , -165796510); - d = md5_gg(d, a, b, c, x[i+ 6], 9 , -1069501632); - c = md5_gg(c, d, a, b, x[i+11], 14, 643717713); - b = md5_gg(b, c, d, a, x[i+ 0], 20, -373897302); - a = md5_gg(a, b, c, d, x[i+ 5], 5 , -701558691); - d = md5_gg(d, a, b, c, x[i+10], 9 , 38016083); - c = md5_gg(c, d, a, b, x[i+15], 14, -660478335); - b = md5_gg(b, c, d, a, x[i+ 4], 20, -405537848); - a = md5_gg(a, b, c, d, x[i+ 9], 5 , 568446438); - d = md5_gg(d, a, b, c, x[i+14], 9 , -1019803690); - c = md5_gg(c, d, a, b, x[i+ 3], 14, -187363961); - b = md5_gg(b, c, d, a, x[i+ 8], 20, 1163531501); - a = md5_gg(a, b, c, d, x[i+13], 5 , -1444681467); - d = md5_gg(d, a, b, c, x[i+ 2], 9 , -51403784); - c = md5_gg(c, d, a, b, x[i+ 7], 14, 1735328473); - b = md5_gg(b, c, d, a, x[i+12], 20, -1926607734); - - a = md5_hh(a, b, c, d, x[i+ 5], 4 , -378558); - d = md5_hh(d, a, b, c, x[i+ 8], 11, -2022574463); - c = md5_hh(c, d, a, b, x[i+11], 16, 1839030562); - b = md5_hh(b, c, d, a, x[i+14], 23, -35309556); - a = md5_hh(a, b, c, d, x[i+ 1], 4 , -1530992060); - d = md5_hh(d, a, b, c, x[i+ 4], 11, 1272893353); - c = md5_hh(c, d, a, b, x[i+ 7], 16, -155497632); - b = md5_hh(b, c, d, a, x[i+10], 23, -1094730640); - a = md5_hh(a, b, c, d, x[i+13], 4 , 681279174); - d = md5_hh(d, a, b, c, x[i+ 0], 11, -358537222); - c = md5_hh(c, d, a, b, x[i+ 3], 16, -722521979); - b = md5_hh(b, c, d, a, x[i+ 6], 23, 76029189); - a = md5_hh(a, b, c, d, x[i+ 9], 4 , -640364487); - d = md5_hh(d, a, b, c, x[i+12], 11, -421815835); - c = md5_hh(c, d, a, b, x[i+15], 16, 530742520); - b = md5_hh(b, c, d, a, x[i+ 2], 23, -995338651); - - a = md5_ii(a, b, c, d, x[i+ 0], 6 , -198630844); - d = md5_ii(d, a, b, c, x[i+ 7], 10, 1126891415); - c = 
md5_ii(c, d, a, b, x[i+14], 15, -1416354905); - b = md5_ii(b, c, d, a, x[i+ 5], 21, -57434055); - a = md5_ii(a, b, c, d, x[i+12], 6 , 1700485571); - d = md5_ii(d, a, b, c, x[i+ 3], 10, -1894986606); - c = md5_ii(c, d, a, b, x[i+10], 15, -1051523); - b = md5_ii(b, c, d, a, x[i+ 1], 21, -2054922799); - a = md5_ii(a, b, c, d, x[i+ 8], 6 , 1873313359); - d = md5_ii(d, a, b, c, x[i+15], 10, -30611744); - c = md5_ii(c, d, a, b, x[i+ 6], 15, -1560198380); - b = md5_ii(b, c, d, a, x[i+13], 21, 1309151649); - a = md5_ii(a, b, c, d, x[i+ 4], 6 , -145523070); - d = md5_ii(d, a, b, c, x[i+11], 10, -1120210379); - c = md5_ii(c, d, a, b, x[i+ 2], 15, 718787259); - b = md5_ii(b, c, d, a, x[i+ 9], 21, -343485551); - - a = safe_add(a, olda); - b = safe_add(b, oldb); - c = safe_add(c, oldc); - d = safe_add(d, oldd); - } - return [a, b, c, d]; - }; - - - /* - * Calculate the HMAC-MD5, of a key and some data - */ - var core_hmac_md5 = function (key, data) { - var bkey = str2binl(key); - if(bkey.length > 16) { bkey = core_md5(bkey, key.length * chrsz); } - - var ipad = new Array(16), opad = new Array(16); - for(var i = 0; i < 16; i++) - { - ipad[i] = bkey[i] ^ 0x36363636; - opad[i] = bkey[i] ^ 0x5C5C5C5C; - } - - var hash = core_md5(ipad.concat(str2binl(data)), 512 + data.length * chrsz); - return core_md5(opad.concat(hash), 512 + 128); - }; - - var obj = { - /* - * These are the functions you'll usually want to call. - * They take string arguments and return either hex or base-64 encoded - * strings. - */ - hexdigest: function (s) { - return binl2hex(core_md5(str2binl(s), s.length * chrsz)); - }, - - b64digest: function (s) { - return binl2b64(core_md5(str2binl(s), s.length * chrsz)); - }, - - hash: function (s) { - return binl2str(core_md5(str2binl(s), s.length * chrsz)); - }, - - hmac_hexdigest: function (key, data) { - return binl2hex(core_hmac_md5(key, data)); - }, - - hmac_b64digest: function (key, data) { - return binl2b64(core_hmac_md5(key, data)); - }, - - hmac_hash: function (key, data) { - return binl2str(core_hmac_md5(key, data)); - }, - - /* - * Perform a simple self-test to see if the VM is working - */ - test: function () { - return MD5.hexdigest("abc") === "900150983cd24fb0d6963f7d28e17f72"; - } - }; - - return obj; -})(); - -// Nodify -exports.MD5 = MD5; diff --git a/contrib/jitsimeetbridge/unjingle/strophe/strophe.js b/contrib/jitsimeetbridge/unjingle/strophe/strophe.js deleted file mode 100644 index 06d426cdec..0000000000 --- a/contrib/jitsimeetbridge/unjingle/strophe/strophe.js +++ /dev/null @@ -1,3256 +0,0 @@ -/* - This program is distributed under the terms of the MIT license. - Please see the LICENSE file for details. - - Copyright 2006-2008, OGG, LLC -*/ - -/* jslint configuration: */ -/*global document, window, setTimeout, clearTimeout, console, - XMLHttpRequest, ActiveXObject, - Base64, MD5, - Strophe, $build, $msg, $iq, $pres */ - -/** File: strophe.js - * A JavaScript library for XMPP BOSH. - * - * This is the JavaScript version of the Strophe library. Since JavaScript - * has no facilities for persistent TCP connections, this library uses - * Bidirectional-streams Over Synchronous HTTP (BOSH) to emulate - * a persistent, stateful, two-way connection to an XMPP server. More - * information on BOSH can be found in XEP 124. - */ - -/** PrivateFunction: Function.prototype.bind - * Bind a function to an instance. - * - * This Function object extension method creates a bound method similar - * to those in Python. This means that the 'this' object will point - * to the instance you want. 
See
- * MDC's bind() documentation and
- * Bound Functions and Function Imports in JavaScript
- * for a complete explanation.
- *
- * This extension already exists in some browsers (namely, Firefox 3), but
- * we provide it to support those that don't.
- *
- * Parameters:
- *   (Object) obj - The object that will become 'this' in the bound function.
- *   (Object) argN - An optional argument that will be prepended to the
- *     arguments given for the function call.
- *
- * Returns:
- *   The bound function.
- */
-
-/* Make it work on node.js: Nodify
- *
- * Steps:
- * 1. Create the global objects: window, document, Base64, MD5 and XMLHttpRequest
- * 2. Use the node-XMLHttpRequest module.
- * 3. Use jsdom for the document object - since it supports DOM functions.
- * 4. Replace all calls to childNodes with _childNodes (since the former doesn't
- *    seem to work on jsdom).
- * 5. While getting the response from XMLHttpRequest, manually convert the text
- *    data to XML.
- * 6. All calls to nodeName should be replaced by nodeName.toLowerCase() since
- *    jsdom seems to always convert node names to upper case.
- *
- */
-var XMLHttpRequest = require('./XMLHttpRequest.js').XMLHttpRequest;
-var Base64 = require('./base64.js').Base64;
-var MD5 = require('./md5.js').MD5;
-var jsdom = require("jsdom").jsdom;
-
-document = jsdom(""),
-
-window = {
-    XMLHttpRequest: XMLHttpRequest,
-    Base64: Base64,
-    MD5: MD5
-};
-
-exports.Strophe = window;
-
-if (!Function.prototype.bind) {
-    Function.prototype.bind = function (obj /*, arg1, arg2, ... */)
-    {
-        var func = this;
-        var _slice = Array.prototype.slice;
-        var _concat = Array.prototype.concat;
-        var _args = _slice.call(arguments, 1);
-
-        return function () {
-            return func.apply(obj ? obj : this,
-                              _concat.call(_args,
-                                           _slice.call(arguments, 0)));
-        };
-    };
-}
-
-/** PrivateFunction: Array.prototype.indexOf
- * Return the index of an object in an array.
- *
- * This function is not supplied by some JavaScript implementations, so
- * we provide it if it is missing.  This code is from:
- * http://developer.mozilla.org/En/Core_JavaScript_1.5_Reference:Objects:Array:indexOf
- *
- * Parameters:
- *   (Object) elt - The object to look for.
- *   (Integer) from - The index from which to start looking. (optional).
- *
- * Returns:
- *   The index of elt in the array or -1 if not found.
- */
-if (!Array.prototype.indexOf)
-{
-    Array.prototype.indexOf = function(elt /*, from*/)
-    {
-        var len = this.length;
-
-        var from = Number(arguments[1]) || 0;
-        from = (from < 0) ? Math.ceil(from) : Math.floor(from);
-        if (from < 0) {
-            from += len;
-        }
-
-        for (; from < len; from++) {
-            if (from in this && this[from] === elt) {
-                return from;
-            }
-        }
-
-        return -1;
-    };
-}
-
-/* All of the Strophe globals are defined in this special function below so
- * that references to the globals become closures.  This will ensure that
- * on page reload, these references will still be available to callbacks
- * that are still executing.
- */
-
-(function (callback) {
-var Strophe;
-
-/** Function: $build
- * Create a Strophe.Builder.
- * This is an alias for 'new Strophe.Builder(name, attrs)'.
- *
- * Parameters:
- *   (String) name - The root element name.
- *   (Object) attrs - The attributes for the root element in object notation.
- *
- * Returns:
- *   A new Strophe.Builder object.
- */
-function $build(name, attrs) { return new Strophe.Builder(name, attrs); }
-/** Function: $msg
- * Create a Strophe.Builder with a <message/> element as the root.
- *
- * Parameters:
- *   (Object) attrs - The <message/> element attributes in object notation.
- *
- * Returns:
- *   A new Strophe.Builder object.
- */
-function $msg(attrs) { return new Strophe.Builder("message", attrs); }
-/** Function: $iq
- * Create a Strophe.Builder with an <iq/> element as the root.
- *
- * Parameters:
- *   (Object) attrs - The <iq/> element attributes in object notation.
- *
- * Returns:
- *   A new Strophe.Builder object.
- */
-function $iq(attrs) { return new Strophe.Builder("iq", attrs); }
-/** Function: $pres
- * Create a Strophe.Builder with a <presence/> element as the root.
- *
- * Parameters:
- *   (Object) attrs - The <presence/> element attributes in object notation.
- *
- * Returns:
- *   A new Strophe.Builder object.
- */
-function $pres(attrs) { return new Strophe.Builder("presence", attrs); }
-
-/** Class: Strophe
- * An object container for all Strophe library functions.
- *
- * This class is just a container for all the objects and constants
- * used in the library.  It is not meant to be instantiated, but to
- * provide a namespace for library objects, constants, and functions.
- */
-Strophe = {
-    /** Constant: VERSION
-     * The version of the Strophe library. Unreleased builds will have
-     * a version of head-HASH where HASH is a partial revision.
-     */
-    VERSION: "@VERSION@",
-
-    /** Constants: XMPP Namespace Constants
-     * Common namespace constants from the XMPP RFCs and XEPs.
-     *
-     * NS.HTTPBIND - HTTP BIND namespace from XEP 124.
-     * NS.BOSH - BOSH namespace from XEP 206.
-     * NS.CLIENT - Main XMPP client namespace.
-     * NS.AUTH - Legacy authentication namespace.
-     * NS.ROSTER - Roster operations namespace.
-     * NS.PROFILE - Profile namespace.
-     * NS.DISCO_INFO - Service discovery info namespace from XEP 30.
-     * NS.DISCO_ITEMS - Service discovery items namespace from XEP 30.
-     * NS.MUC - Multi-User Chat namespace from XEP 45.
-     * NS.SASL - XMPP SASL namespace from RFC 3920.
-     * NS.STREAM - XMPP Streams namespace from RFC 3920.
-     * NS.BIND - XMPP Binding namespace from RFC 3920.
-     * NS.SESSION - XMPP Session namespace from RFC 3920.
-     */
-    NS: {
-        HTTPBIND: "http://jabber.org/protocol/httpbind",
-        BOSH: "urn:xmpp:xbosh",
-        CLIENT: "jabber:client",
-        AUTH: "jabber:iq:auth",
-        ROSTER: "jabber:iq:roster",
-        PROFILE: "jabber:iq:profile",
-        DISCO_INFO: "http://jabber.org/protocol/disco#info",
-        DISCO_ITEMS: "http://jabber.org/protocol/disco#items",
-        MUC: "http://jabber.org/protocol/muc",
-        SASL: "urn:ietf:params:xml:ns:xmpp-sasl",
-        STREAM: "http://etherx.jabber.org/streams",
-        BIND: "urn:ietf:params:xml:ns:xmpp-bind",
-        SESSION: "urn:ietf:params:xml:ns:xmpp-session",
-        VERSION: "jabber:iq:version",
-        STANZAS: "urn:ietf:params:xml:ns:xmpp-stanzas"
-    },
-
-    /** Function: addNamespace
-     * This function is used to extend the current namespaces in
-     * Strophe.NS.  It takes a key and a value with the key being the
-     * name of the new namespace, with its actual value.
-     * For example:
-     * Strophe.addNamespace('PUBSUB', "http://jabber.org/protocol/pubsub");
-     *
-     * Parameters:
-     *   (String) name - The name under which the namespace will be
-     *     referenced under Strophe.NS
-     *   (String) value - The actual namespace.
-     */
-    addNamespace: function (name, value)
-    {
-        Strophe.NS[name] = value;
-    },
-
-    /** Constants: Connection Status Constants
-     * Connection status constants for use by the connection handler
-     * callback.
- * - * Status.ERROR - An error has occurred - * Status.CONNECTING - The connection is currently being made - * Status.CONNFAIL - The connection attempt failed - * Status.AUTHENTICATING - The connection is authenticating - * Status.AUTHFAIL - The authentication attempt failed - * Status.CONNECTED - The connection has succeeded - * Status.DISCONNECTED - The connection has been terminated - * Status.DISCONNECTING - The connection is currently being terminated - * Status.ATTACHED - The connection has been attached - */ - Status: { - ERROR: 0, - CONNECTING: 1, - CONNFAIL: 2, - AUTHENTICATING: 3, - AUTHFAIL: 4, - CONNECTED: 5, - DISCONNECTED: 6, - DISCONNECTING: 7, - ATTACHED: 8 - }, - - /** Constants: Log Level Constants - * Logging level indicators. - * - * LogLevel.DEBUG - Debug output - * LogLevel.INFO - Informational output - * LogLevel.WARN - Warnings - * LogLevel.ERROR - Errors - * LogLevel.FATAL - Fatal errors - */ - LogLevel: { - DEBUG: 0, - INFO: 1, - WARN: 2, - ERROR: 3, - FATAL: 4 - }, - - /** PrivateConstants: DOM Element Type Constants - * DOM element types. - * - * ElementType.NORMAL - Normal element. - * ElementType.TEXT - Text data element. - */ - ElementType: { - NORMAL: 1, - TEXT: 3 - }, - - /** PrivateConstants: Timeout Values - * Timeout values for error states. These values are in seconds. - * These should not be changed unless you know exactly what you are - * doing. - * - * TIMEOUT - Timeout multiplier. A waiting request will be considered - * failed after Math.floor(TIMEOUT * wait) seconds have elapsed. - * This defaults to 1.1, and with default wait, 66 seconds. - * SECONDARY_TIMEOUT - Secondary timeout multiplier. In cases where - * Strophe can detect early failure, it will consider the request - * failed if it doesn't return after - * Math.floor(SECONDARY_TIMEOUT * wait) seconds have elapsed. - * This defaults to 0.1, and with default wait, 6 seconds. - */ - TIMEOUT: 1.1, - SECONDARY_TIMEOUT: 0.1, - - /** Function: forEachChild - * Map a function over some or all child elements of a given element. - * - * This is a small convenience function for mapping a function over - * some or all of the children of an element. If elemName is null, all - * children will be passed to the function, otherwise only children - * whose tag names match elemName will be passed. - * - * Parameters: - * (XMLElement) elem - The element to operate on. - * (String) elemName - The child element tag name filter. - * (Function) func - The function to apply to each child. This - * function should take a single argument, a DOM element. - */ - forEachChild: function (elem, elemName, func) - { - var i, childNode; - - for (i = 0; i < elem._childNodes.length; i++) { - childNode = elem._childNodes[i]; - if (childNode.nodeType == Strophe.ElementType.NORMAL && - (!elemName || this.isTagEqual(childNode, elemName))) { - func(childNode); - } - } - }, - - /** Function: isTagEqual - * Compare an element's tag name with a string. - * - * This function is case insensitive. - * - * Parameters: - * (XMLElement) el - A DOM element. - * (String) name - The element name. - * - * Returns: - * true if the element's tag name matches _el_, and false - * otherwise. - */ - isTagEqual: function (el, name) - { - return el.tagName.toLowerCase() == name.toLowerCase(); - }, - - /** PrivateVariable: _xmlGenerator - * _Private_ variable that caches a DOM document to - * generate elements. 
- */
-    _xmlGenerator: null,
-
-    /** PrivateFunction: _makeGenerator
-     * _Private_ function that creates a dummy XML DOM document to serve as
-     * an element and text node generator.
-     */
-    _makeGenerator: function () {
-        var doc;
-
-        if (window.ActiveXObject) {
-            doc = this._getIEXmlDom();
-            doc.appendChild(doc.createElement('strophe'));
-        } else {
-            doc = document.implementation
-                .createDocument('jabber:client', 'strophe', null);
-        }
-
-        return doc;
-    },
-
-    /** Function: xmlGenerator
-     * Get the DOM document to generate elements.
-     *
-     * Returns:
-     *   The currently used DOM document.
-     */
-    xmlGenerator: function () {
-        if (!Strophe._xmlGenerator) {
-            Strophe._xmlGenerator = Strophe._makeGenerator();
-        }
-        return Strophe._xmlGenerator;
-    },
-
-    /** PrivateFunction: _getIEXmlDom
-     * Gets IE xml doc object
-     *
-     * Returns:
-     *   A Microsoft XML DOM Object
-     * See Also:
-     *   http://msdn.microsoft.com/en-us/library/ms757837%28VS.85%29.aspx
-     */
-    _getIEXmlDom : function() {
-        var doc = null;
-        var docStrings = [
-            "Msxml2.DOMDocument.6.0",
-            "Msxml2.DOMDocument.5.0",
-            "Msxml2.DOMDocument.4.0",
-            "MSXML2.DOMDocument.3.0",
-            "MSXML2.DOMDocument",
-            "MSXML.DOMDocument",
-            "Microsoft.XMLDOM"
-        ];
-
-        for (var d = 0; d < docStrings.length; d++) {
-            if (doc === null) {
-                try {
-                    doc = new ActiveXObject(docStrings[d]);
-                } catch (e) {
-                    doc = null;
-                }
-            } else {
-                break;
-            }
-        }
-
-        return doc;
-    },
-
-    /** Function: xmlElement
-     * Create an XML DOM element.
-     *
-     * This function creates an XML DOM element correctly across all
-     * implementations. Note that these are not HTML DOM elements, which
-     * aren't appropriate for XMPP stanzas.
-     *
-     * Parameters:
-     *   (String) name - The name for the element.
-     *   (Array|Object) attrs - An optional array or object containing
-     *     key/value pairs to use as element attributes. The object should
-     *     be in the format {'key': 'value'} or {key: 'value'}. The array
-     *     should have the format [['key1', 'value1'], ['key2', 'value2']].
-     *   (String) text - The text child data for the element.
-     *
-     * Returns:
-     *   A new XML DOM element.
-     */
-    xmlElement: function (name)
-    {
-        if (!name) { return null; }
-
-        var node = Strophe.xmlGenerator().createElement(name);
-
-        // FIXME: this should throw errors if args are the wrong type or
-        // there are more than two optional args
-        var a, i, k;
-        for (a = 1; a < arguments.length; a++) {
-            if (!arguments[a]) { continue; }
-            if (typeof(arguments[a]) == "string" ||
-                typeof(arguments[a]) == "number") {
-                node.appendChild(Strophe.xmlTextNode(arguments[a]));
-            } else if (typeof(arguments[a]) == "object" &&
-                       typeof(arguments[a].sort) == "function") {
-                for (i = 0; i < arguments[a].length; i++) {
-                    if (typeof(arguments[a][i]) == "object" &&
-                        typeof(arguments[a][i].sort) == "function") {
-                        node.setAttribute(arguments[a][i][0],
-                                          arguments[a][i][1]);
-                    }
-                }
-            } else if (typeof(arguments[a]) == "object") {
-                for (k in arguments[a]) {
-                    if (arguments[a].hasOwnProperty(k)) {
-                        node.setAttribute(k, arguments[a][k]);
-                    }
-                }
-            }
-        }
-
-        return node;
-    },
-
-    /* Function: xmlescape
-     * Escapes invalid xml characters.
-     *
-     * Parameters:
-     *   (String) text - text to escape.
-     *
-     * Returns:
-     *   Escaped text.
-     */
-    xmlescape: function(text)
-    {
-        text = text.replace(/\&/g, "&amp;");
-        text = text.replace(/</g, "&lt;");
-        text = text.replace(/>/g, "&gt;");
-        return text;
-    },
-
-    /** Function: xmlTextNode
-     * Creates an XML DOM text node.
-     *
-     * Provides a cross implementation version of document.createTextNode.
-     *
-     * Parameters:
-     *   (String) text - The content of the text node.
-     *
-     * Returns:
-     *   A new XML DOM text node.
-     */
-    xmlTextNode: function (text)
-    {
-        //ensure text is escaped
-        text = Strophe.xmlescape(text);
-
-        return Strophe.xmlGenerator().createTextNode(text);
-    },
-
-    /** Function: getText
-     * Get the concatenation of all text children of an element.
-     *
-     * Parameters:
-     *   (XMLElement) elem - A DOM element.
-     *
-     * Returns:
-     *   A String with the concatenated text of all text element children.
-     */
-    getText: function (elem)
-    {
-        if (!elem) { return null; }
-
-        var str = "";
-        if (elem._childNodes.length === 0 && elem.nodeType ==
-            Strophe.ElementType.TEXT) {
-            str += elem.nodeValue;
-        }
-
-        for (var i = 0; i < elem._childNodes.length; i++) {
-            if (elem._childNodes[i].nodeType == Strophe.ElementType.TEXT) {
-                str += elem._childNodes[i].nodeValue;
-            }
-        }
-
-        return str;
-    },
-
-    /** Function: copyElement
-     * Copy an XML DOM element.
-     *
-     * This function copies a DOM element and all its descendants and returns
-     * the new copy.
-     *
-     * Parameters:
-     *   (XMLElement) elem - A DOM element.
-     *
-     * Returns:
-     *   A new, copied DOM element tree.
-     */
-    copyElement: function (elem)
-    {
-        var i, el;
-        if (elem.nodeType == Strophe.ElementType.NORMAL) {
-            el = Strophe.xmlElement(elem.tagName);
-
-            for (i = 0; i < elem.attributes.length; i++) {
-                el.setAttribute(elem.attributes[i].nodeName.toLowerCase(),
-                                elem.attributes[i].value);
-            }
-
-            for (i = 0; i < elem._childNodes.length; i++) {
-                el.appendChild(Strophe.copyElement(elem._childNodes[i]));
-            }
-        } else if (elem.nodeType == Strophe.ElementType.TEXT) {
-            el = Strophe.xmlTextNode(elem.nodeValue);
-        }
-
-        return el;
-    },
-
-    /** Function: escapeNode
-     * Escape the node part (also called local part) of a JID.
-     *
-     * Parameters:
-     *   (String) node - A node (or local part).
-     *
-     * Returns:
-     *   An escaped node (or local part).
-     */
-    escapeNode: function (node)
-    {
-        return node.replace(/^\s+|\s+$/g, '')
-            .replace(/\\/g,  "\\5c")
-            .replace(/ /g,   "\\20")
-            .replace(/\"/g,  "\\22")
-            .replace(/\&/g,  "\\26")
-            .replace(/\'/g,  "\\27")
-            .replace(/\//g,  "\\2f")
-            .replace(/:/g,   "\\3a")
-            .replace(/</g,   "\\3c")
-            .replace(/>/g,   "\\3e")
-            .replace(/@/g,   "\\40");
-    },
-
-    /** Function: unescapeNode
-     * Unescape a node part (also called local part) of a JID.
-     *
-     * Parameters:
-     *   (String) node - A node (or local part).
-     *
-     * Returns:
-     *   An unescaped node (or local part).
-     */
-    unescapeNode: function (node)
-    {
-        return node.replace(/\\20/g, " ")
-            .replace(/\\22/g, '"')
-            .replace(/\\26/g, "&")
-            .replace(/\\27/g, "'")
-            .replace(/\\2f/g, "/")
-            .replace(/\\3a/g, ":")
-            .replace(/\\3c/g, "<")
-            .replace(/\\3e/g, ">")
-            .replace(/\\40/g, "@")
-            .replace(/\\5c/g, "\\");
-    },
-
-    /** Function: getNodeFromJid
-     * Get the node portion of a JID String.
-     *
-     * Parameters:
-     *   (String) jid - A JID.
-     *
-     * Returns:
-     *   A String containing the node.
-     */
-    getNodeFromJid: function (jid)
-    {
-        if (jid.indexOf("@") < 0) { return null; }
-        return jid.split("@")[0];
-    },
-
-    /** Function: getDomainFromJid
-     * Get the domain portion of a JID String.
-     *
-     * Parameters:
-     *   (String) jid - A JID.
-     *
-     * Returns:
-     *   A String containing the domain.
-     */
-    getDomainFromJid: function (jid)
-    {
-        var bare = Strophe.getBareJidFromJid(jid);
-        if (bare.indexOf("@") < 0) {
-            return bare;
-        } else {
-            var parts = bare.split("@");
-            parts.splice(0, 1);
-            return parts.join('@');
-        }
-    },
-
-    /** Function: getResourceFromJid
-     * Get the resource portion of a JID String.
-     *
-     * Parameters:
-     *   (String) jid - A JID.
- * - * Returns: - * A String containing the resource. - */ - getResourceFromJid: function (jid) - { - var s = jid.split("/"); - if (s.length < 2) { return null; } - s.splice(0, 1); - return s.join('/'); - }, - - /** Function: getBareJidFromJid - * Get the bare JID from a JID String. - * - * Parameters: - * (String) jid - A JID. - * - * Returns: - * A String containing the bare JID. - */ - getBareJidFromJid: function (jid) - { - return jid ? jid.split("/")[0] : null; - }, - - /** Function: log - * User overrideable logging function. - * - * This function is called whenever the Strophe library calls any - * of the logging functions. The default implementation of this - * function does nothing. If client code wishes to handle the logging - * messages, it should override this with - * > Strophe.log = function (level, msg) { - * > (user code here) - * > }; - * - * Please note that data sent and received over the wire is logged - * via Strophe.Connection.rawInput() and Strophe.Connection.rawOutput(). - * - * The different levels and their meanings are - * - * DEBUG - Messages useful for debugging purposes. - * INFO - Informational messages. This is mostly information like - * 'disconnect was called' or 'SASL auth succeeded'. - * WARN - Warnings about potential problems. This is mostly used - * to report transient connection errors like request timeouts. - * ERROR - Some error occurred. - * FATAL - A non-recoverable fatal error occurred. - * - * Parameters: - * (Integer) level - The log level of the log message. This will - * be one of the values in Strophe.LogLevel. - * (String) msg - The log message. - */ - log: function (level, msg) - { - return; - }, - - /** Function: debug - * Log a message at the Strophe.LogLevel.DEBUG level. - * - * Parameters: - * (String) msg - The log message. - */ - debug: function(msg) - { - this.log(this.LogLevel.DEBUG, msg); - }, - - /** Function: info - * Log a message at the Strophe.LogLevel.INFO level. - * - * Parameters: - * (String) msg - The log message. - */ - info: function (msg) - { - this.log(this.LogLevel.INFO, msg); - }, - - /** Function: warn - * Log a message at the Strophe.LogLevel.WARN level. - * - * Parameters: - * (String) msg - The log message. - */ - warn: function (msg) - { - this.log(this.LogLevel.WARN, msg); - }, - - /** Function: error - * Log a message at the Strophe.LogLevel.ERROR level. - * - * Parameters: - * (String) msg - The log message. - */ - error: function (msg) - { - this.log(this.LogLevel.ERROR, msg); - }, - - /** Function: fatal - * Log a message at the Strophe.LogLevel.FATAL level. - * - * Parameters: - * (String) msg - The log message. - */ - fatal: function (msg) - { - this.log(this.LogLevel.FATAL, msg); - }, - - /** Function: serialize - * Render a DOM element and all descendants to a String. - * - * Parameters: - * (XMLElement) elem - A DOM element. - * - * Returns: - * The serialized element tree as a String. 
- */
-    serialize: function (elem)
-    {
-        var result;
-
-        if (!elem) { return null; }
-
-        if (typeof(elem.tree) === "function") {
-            elem = elem.tree();
-        }
-
-        var nodeName = elem.nodeName.toLowerCase();
-        var i, child;
-
-        if (elem.getAttribute("_realname")) {
-            nodeName = elem.getAttribute("_realname").toLowerCase();
-        }
-
-        result = "<" + nodeName.toLowerCase();
-        for (i = 0; i < elem.attributes.length; i++) {
-            if(elem.attributes[i].nodeName.toLowerCase() != "_realname") {
-                result += " " + elem.attributes[i].nodeName.toLowerCase() +
-                    "='" + elem.attributes[i].value
-                        .replace(/&/g, "&amp;")
-                        .replace(/\'/g, "&apos;")
-                        .replace(/</g, "&lt;") + "'";
-            }
-        }
-
-        if (elem._childNodes.length > 0) {
-            result += ">";
-            for (i = 0; i < elem._childNodes.length; i++) {
-                child = elem._childNodes[i];
-                if (child.nodeType == Strophe.ElementType.NORMAL) {
-                    // normal element, so recurse
-                    result += Strophe.serialize(child);
-                } else if (child.nodeType == Strophe.ElementType.TEXT) {
-                    // text element
-                    result += child.nodeValue;
-                }
-            }
-            result += "</" + nodeName + ">";
-        } else {
-            result += "/>";
-        }
-
-        return result;
-    },
-
-    /** PrivateVariable: _requestId
-     * _Private_ variable that keeps track of the request ids for
-     * connections.
-     */
-    _requestId: 0,
-
-    /** PrivateVariable: Strophe.connectionPlugins
-     * _Private_ variable used to store plugin names that need
-     * initialization on Strophe.Connection construction.
-     */
-    _connectionPlugins: {},
-
-    /** Function: addConnectionPlugin
-     * Extends the Strophe.Connection object with the given plugin.
-     *
-     * Parameters:
-     *   (String) name - The name of the extension.
-     *   (Object) ptype - The plugin's prototype.
-     */
-    addConnectionPlugin: function (name, ptype)
-    {
-        Strophe._connectionPlugins[name] = ptype;
-    }
-};
-
-/** Class: Strophe.Builder
- * XML DOM builder.
- *
- * This object provides an interface similar to JQuery but for building
- * DOM elements easily and rapidly.  All the functions except for toString()
- * and tree() return the object, so calls can be chained.  Here's an
- * example using the $iq() builder helper.
- * > $iq({to: 'you', from: 'me', type: 'get', id: '1'})
- * >     .c('query', {xmlns: 'strophe:example'})
- * >     .c('example')
- * >     .toString()
- * The above generates this XML fragment
- * > <iq to='you' from='me' type='get' id='1'>
- * >   <query xmlns='strophe:example'>
- * >     <example/>
- * >   </query>
- * > </iq>
- * The corresponding DOM manipulations to get a similar fragment would be
- * a lot more tedious and probably involve several helper variables.
- *
- * Since adding children makes new operations operate on the child, up()
- * is provided to traverse up the tree.  To add two children, do
- * > builder.c('child1', ...).up().c('child2', ...)
- * The next operation on the Builder will be relative to the second child.
- */
-
-/** Constructor: Strophe.Builder
- * Create a Strophe.Builder object.
- *
- * The attributes should be passed in object notation.  For example
- * > var b = new Builder('message', {to: 'you', from: 'me'});
- * or
- * > var b = new Builder('message', {'xml:lang': 'en'});
- *
- * Parameters:
- *   (String) name - The name of the root element.
- *   (Object) attrs - The attributes for the root element in object notation.
- *
- * Returns:
- *   A new Strophe.Builder.
- */
-Strophe.Builder = function (name, attrs)
-{
-    // Set correct namespace for jabber:client elements
-    if (name == "presence" || name == "message" || name == "iq") {
-        if (attrs && !attrs.xmlns) {
-            attrs.xmlns = Strophe.NS.CLIENT;
-        } else if (!attrs) {
-            attrs = {xmlns: Strophe.NS.CLIENT};
-        }
-    }
-
-    // Holds the tree being built.
- this.nodeTree = Strophe.xmlElement(name, attrs); - - // Points to the current operation node. - this.node = this.nodeTree; -}; - -Strophe.Builder.prototype = { - /** Function: tree - * Return the DOM tree. - * - * This function returns the current DOM tree as an element object. This - * is suitable for passing to functions like Strophe.Connection.send(). - * - * Returns: - * The DOM tree as a element object. - */ - tree: function () - { - return this.nodeTree; - }, - - /** Function: toString - * Serialize the DOM tree to a String. - * - * This function returns a string serialization of the current DOM - * tree. It is often used internally to pass data to a - * Strophe.Request object. - * - * Returns: - * The serialized DOM tree in a String. - */ - toString: function () - { - return Strophe.serialize(this.nodeTree); - }, - - /** Function: up - * Make the current parent element the new current element. - * - * This function is often used after c() to traverse back up the tree. - * For example, to add two children to the same element - * > builder.c('child1', {}).up().c('child2', {}); - * - * Returns: - * The Stophe.Builder object. - */ - up: function () - { - this.node = this.node.parentNode; - return this; - }, - - /** Function: attrs - * Add or modify attributes of the current element. - * - * The attributes should be passed in object notation. This function - * does not move the current element pointer. - * - * Parameters: - * (Object) moreattrs - The attributes to add/modify in object notation. - * - * Returns: - * The Strophe.Builder object. - */ - attrs: function (moreattrs) - { - for (var k in moreattrs) { - if (moreattrs.hasOwnProperty(k)) { - this.node.setAttribute(k, moreattrs[k]); - } - } - return this; - }, - - /** Function: c - * Add a child to the current element and make it the new current - * element. - * - * This function moves the current element pointer to the child. If you - * need to add another child, it is necessary to use up() to go back - * to the parent in the tree. - * - * Parameters: - * (String) name - The name of the child. - * (Object) attrs - The attributes of the child in object notation. - * - * Returns: - * The Strophe.Builder object. - */ - c: function (name, attrs) - { - var child = Strophe.xmlElement(name, attrs); - this.node.appendChild(child); - this.node = child; - return this; - }, - - /** Function: cnode - * Add a child to the current element and make it the new current - * element. - * - * This function is the same as c() except that instead of using a - * name and an attributes object to create the child it uses an - * existing DOM element object. - * - * Parameters: - * (XMLElement) elem - A DOM element. - * - * Returns: - * The Strophe.Builder object. - */ - cnode: function (elem) - { - var xmlGen = Strophe.xmlGenerator(); - var newElem = xmlGen.importNode ? xmlGen.importNode(elem, true) : Strophe.copyElement(elem); - this.node.appendChild(newElem); - this.node = newElem; - return this; - }, - - /** Function: t - * Add a child text element. - * - * This *does not* make the child the new current element since there - * are no children of text elements. - * - * Parameters: - * (String) text - The text data to append to the current element. - * - * Returns: - * The Strophe.Builder object. - */ - t: function (text) - { - var child = Strophe.xmlTextNode(text); - this.node.appendChild(child); - return this; - } -}; - - -/** PrivateClass: Strophe.Handler - * _Private_ helper class for managing stanza handlers. 
- * - * A Strophe.Handler encapsulates a user provided callback function to be - * executed when matching stanzas are received by the connection. - * Handlers can be either one-off or persistant depending on their - * return value. Returning true will cause a Handler to remain active, and - * returning false will remove the Handler. - * - * Users will not use Strophe.Handler objects directly, but instead they - * will use Strophe.Connection.addHandler() and - * Strophe.Connection.deleteHandler(). - */ - -/** PrivateConstructor: Strophe.Handler - * Create and initialize a new Strophe.Handler. - * - * Parameters: - * (Function) handler - A function to be executed when the handler is run. - * (String) ns - The namespace to match. - * (String) name - The element name to match. - * (String) type - The element type to match. - * (String) id - The element id attribute to match. - * (String) from - The element from attribute to match. - * (Object) options - Handler options - * - * Returns: - * A new Strophe.Handler object. - */ -Strophe.Handler = function (handler, ns, name, type, id, from, options) -{ - this.handler = handler; - this.ns = ns; - this.name = name; - this.type = type; - this.id = id; - this.options = options || {matchbare: false}; - - // default matchBare to false if undefined - if (!this.options.matchBare) { - this.options.matchBare = false; - } - - if (this.options.matchBare) { - this.from = from ? Strophe.getBareJidFromJid(from) : null; - } else { - this.from = from; - } - - // whether the handler is a user handler or a system handler - this.user = true; -}; - -Strophe.Handler.prototype = { - /** PrivateFunction: isMatch - * Tests if a stanza matches the Strophe.Handler. - * - * Parameters: - * (XMLElement) elem - The XML element to test. - * - * Returns: - * true if the stanza matches and false otherwise. - */ - isMatch: function (elem) - { - var nsMatch; - var from = null; - - if (this.options.matchBare) { - from = Strophe.getBareJidFromJid(elem.getAttribute('from')); - } else { - from = elem.getAttribute('from'); - } - - nsMatch = false; - if (!this.ns) { - nsMatch = true; - } else { - var that = this; - Strophe.forEachChild(elem, null, function (elem) { - if (elem.getAttribute("xmlns") == that.ns) { - nsMatch = true; - } - }); - - nsMatch = nsMatch || elem.getAttribute("xmlns") == this.ns; - } - - if (nsMatch && - (!this.name || Strophe.isTagEqual(elem, this.name)) && - (!this.type || elem.getAttribute("type") == this.type) && - (!this.id || elem.getAttribute("id") == this.id) && - (!this.from || from == this.from)) { - return true; - } - - return false; - }, - - /** PrivateFunction: run - * Run the callback on a matching stanza. - * - * Parameters: - * (XMLElement) elem - The DOM element that triggered the - * Strophe.Handler. - * - * Returns: - * A boolean indicating if the handler should remain active. 
- */ - run: function (elem) - { - var result = null; - try { - result = this.handler(elem); - } catch (e) { - if (e.sourceURL) { - Strophe.fatal("error: " + this.handler + - " " + e.sourceURL + ":" + - e.line + " - " + e.name + ": " + e.message); - } else if (e.fileName) { - if (typeof(console) != "undefined") { - console.trace(); - console.error(this.handler, " - error - ", e, e.message); - } - Strophe.fatal("error: " + this.handler + " " + - e.fileName + ":" + e.lineNumber + " - " + - e.name + ": " + e.message); - } else { - Strophe.fatal("error: " + this.handler); - } - - throw e; - } - - return result; - }, - - /** PrivateFunction: toString - * Get a String representation of the Strophe.Handler object. - * - * Returns: - * A String. - */ - toString: function () - { - return "{Handler: " + this.handler + "(" + this.name + "," + - this.id + "," + this.ns + ")}"; - } -}; - -/** PrivateClass: Strophe.TimedHandler - * _Private_ helper class for managing timed handlers. - * - * A Strophe.TimedHandler encapsulates a user provided callback that - * should be called after a certain period of time or at regular - * intervals. The return value of the callback determines whether the - * Strophe.TimedHandler will continue to fire. - * - * Users will not use Strophe.TimedHandler objects directly, but instead - * they will use Strophe.Connection.addTimedHandler() and - * Strophe.Connection.deleteTimedHandler(). - */ - -/** PrivateConstructor: Strophe.TimedHandler - * Create and initialize a new Strophe.TimedHandler object. - * - * Parameters: - * (Integer) period - The number of milliseconds to wait before the - * handler is called. - * (Function) handler - The callback to run when the handler fires. This - * function should take no arguments. - * - * Returns: - * A new Strophe.TimedHandler object. - */ -Strophe.TimedHandler = function (period, handler) -{ - this.period = period; - this.handler = handler; - - this.lastCalled = new Date().getTime(); - this.user = true; -}; - -Strophe.TimedHandler.prototype = { - /** PrivateFunction: run - * Run the callback for the Strophe.TimedHandler. - * - * Returns: - * true if the Strophe.TimedHandler should be called again, and false - * otherwise. - */ - run: function () - { - this.lastCalled = new Date().getTime(); - return this.handler(); - }, - - /** PrivateFunction: reset - * Reset the last called time for the Strophe.TimedHandler. - */ - reset: function () - { - this.lastCalled = new Date().getTime(); - }, - - /** PrivateFunction: toString - * Get a string representation of the Strophe.TimedHandler object. - * - * Returns: - * The string representation. - */ - toString: function () - { - return "{TimedHandler: " + this.handler + "(" + this.period +")}"; - } -}; - -/** PrivateClass: Strophe.Request - * _Private_ helper class that provides a cross implementation abstraction - * for a BOSH related XMLHttpRequest. - * - * The Strophe.Request class is used internally to encapsulate BOSH request - * information. It is not meant to be used from user's code. - */ - -/** PrivateConstructor: Strophe.Request - * Create and initialize a new Strophe.Request object. - * - * Parameters: - * (XMLElement) elem - The XML data to be sent in the request. - * (Function) func - The function that will be called when the - * XMLHttpRequest readyState changes. - * (Integer) rid - The BOSH rid attribute associated with this request. - * (Integer) sends - The number of times this same request has been - * sent. 
- */ -Strophe.Request = function (elem, func, rid, sends) -{ - this.id = ++Strophe._requestId; - this.xmlData = elem; - this.data = Strophe.serialize(elem); - // save original function in case we need to make a new request - // from this one. - this.origFunc = func; - this.func = func; - this.rid = rid; - this.date = NaN; - this.sends = sends || 0; - this.abort = false; - this.dead = null; - this.age = function () { - if (!this.date) { return 0; } - var now = new Date(); - return (now - this.date) / 1000; - }; - this.timeDead = function () { - if (!this.dead) { return 0; } - var now = new Date(); - return (now - this.dead) / 1000; - }; - this.xhr = this._newXHR(); -}; - -Strophe.Request.prototype = { - /** PrivateFunction: getResponse - * Get a response from the underlying XMLHttpRequest. - * - * This function attempts to get a response from the request and checks - * for errors. - * - * Throws: - * "parsererror" - A parser error occured. - * - * Returns: - * The DOM element tree of the response. - */ - getResponse: function () - { - // console.log("getResponse:", this.xhr.responseXML, ":", this.xhr.responseText); - - var node = null; - if (this.xhr.responseXML && this.xhr.responseXML.documentElement) { - node = this.xhr.responseXML.documentElement; - if (node.tagName == "parsererror") { - Strophe.error("invalid response received"); - Strophe.error("responseText: " + this.xhr.responseText); - Strophe.error("responseXML: " + - Strophe.serialize(this.xhr.responseXML)); - throw "parsererror"; - } - } else if (this.xhr.responseText) { - // Hack for node. - var _div = document.createElement("div"); - _div.innerHTML = this.xhr.responseText; - node = _div._childNodes[0]; - - Strophe.error("invalid response received"); - Strophe.error("responseText: " + this.xhr.responseText); - Strophe.error("responseXML: " + - Strophe.serialize(this.xhr.responseXML)); - } - - return node; - }, - - /** PrivateFunction: _newXHR - * _Private_ helper function to create XMLHttpRequests. - * - * This function creates XMLHttpRequests across all implementations. - * - * Returns: - * A new XMLHttpRequest. - */ - _newXHR: function () - { - var xhr = null; - if (window.XMLHttpRequest) { - xhr = new XMLHttpRequest(); - if (xhr.overrideMimeType) { - xhr.overrideMimeType("text/xml"); - } - } else if (window.ActiveXObject) { - xhr = new ActiveXObject("Microsoft.XMLHTTP"); - } - - // use Function.bind() to prepend ourselves as an argument - xhr.onreadystatechange = this.func.bind(null, this); - - return xhr; - } -}; - -/** Class: Strophe.Connection - * XMPP Connection manager. - * - * Thie class is the main part of Strophe. It manages a BOSH connection - * to an XMPP server and dispatches events to the user callbacks as - * data arrives. It supports SASL PLAIN, SASL DIGEST-MD5, and legacy - * authentication. - * - * After creating a Strophe.Connection object, the user will typically - * call connect() with a user supplied callback to handle connection level - * events like authentication failure, disconnection, or connection - * complete. - * - * The user will also have several event handlers defined by using - * addHandler() and addTimedHandler(). These will allow the user code to - * respond to interesting stanzas or do something periodically with the - * connection. These handlers will be active once authentication is - * finished. - * - * To send data to the connection, use send(). - */ - -/** Constructor: Strophe.Connection - * Create and initialize a Strophe.Connection object. 
- *
- * Parameters:
- *   (String) service - The BOSH service URL.
- *
- * Returns:
- *   A new Strophe.Connection object.
- */
-Strophe.Connection = function (service)
-{
-    /* The path to the httpbind service. */
-    this.service = service;
-    /* The connected JID. */
-    this.jid = "";
-    /* request id for body tags */
-    this.rid = Math.floor(Math.random() * 4294967295);
-    /* The current session ID. */
-    this.sid = null;
-    this.streamId = null;
-    /* stream:features */
-    this.features = null;
-
-    // SASL
-    this.do_session = false;
-    this.do_bind = false;
-
-    // handler lists
-    this.timedHandlers = [];
-    this.handlers = [];
-    this.removeTimeds = [];
-    this.removeHandlers = [];
-    this.addTimeds = [];
-    this.addHandlers = [];
-
-    this._idleTimeout = null;
-    this._disconnectTimeout = null;
-
-    this.authenticated = false;
-    this.disconnecting = false;
-    this.connected = false;
-
-    this.errors = 0;
-
-    this.paused = false;
-
-    // default BOSH values
-    this.hold = 1;
-    this.wait = 60;
-    this.window = 5;
-
-    this._data = [];
-    this._requests = [];
-    this._uniqueId = Math.round(Math.random() * 10000);
-
-    this._sasl_success_handler = null;
-    this._sasl_failure_handler = null;
-    this._sasl_challenge_handler = null;
-
-    // set up the onIdle callback every 1/10th of a second
-    this._idleTimeout = setTimeout(this._onIdle.bind(this), 100);
-
-    // initialize plugins
-    for (var k in Strophe._connectionPlugins) {
-        if (Strophe._connectionPlugins.hasOwnProperty(k)) {
-            var ptype = Strophe._connectionPlugins[k];
-            // jslint complains about the below line, but this is fine
-            var F = function () {};
-            F.prototype = ptype;
-            this[k] = new F();
-            this[k].init(this);
-        }
-    }
-};
-
-Strophe.Connection.prototype = {
-    /** Function: reset
-     * Reset the connection.
-     *
-     * This function should be called after a connection is disconnected
-     * before that connection is reused.
-     */
-    reset: function ()
-    {
-        this.rid = Math.floor(Math.random() * 4294967295);
-
-        this.sid = null;
-        this.streamId = null;
-
-        // SASL
-        this.do_session = false;
-        this.do_bind = false;
-
-        // handler lists
-        this.timedHandlers = [];
-        this.handlers = [];
-        this.removeTimeds = [];
-        this.removeHandlers = [];
-        this.addTimeds = [];
-        this.addHandlers = [];
-
-        this.authenticated = false;
-        this.disconnecting = false;
-        this.connected = false;
-
-        this.errors = 0;
-
-        this._requests = [];
-        this._uniqueId = Math.round(Math.random()*10000);
-    },
-
-    /** Function: pause
-     * Pause the request manager.
-     *
-     * This will prevent Strophe from sending any more requests to the
-     * server. This is very useful for temporarily pausing while a lot
-     * of send() calls are happening quickly. This causes Strophe to
-     * send the data in a single request, saving many request trips.
-     */
-    pause: function ()
-    {
-        this.paused = true;
-    },
-
-    /** Function: resume
-     * Resume the request manager.
-     *
-     * This resumes after pause() has been called.
-     */
-    resume: function ()
-    {
-        this.paused = false;
-    },
-
-    /** Function: getUniqueId
-     * Generate a unique ID for use in <iq/> elements.
-     *
-     * All stanzas are required to have unique id attributes. This
-     * function makes creating these easy. Each connection instance has
-     * a counter which starts from zero, and the value of this counter
-     * plus a colon followed by the suffix becomes the unique id. If no
-     * suffix is supplied, the counter is used as the unique id.
-     *
-     * Suffixes are used to make debugging easier when reading the stream
-     * data, and their use is recommended.
-     * The counter resets to 0 for every new connection for the same
-     * reason. For connections to the same server that authenticate the
-     * same way, all the ids should be the same, which makes it easy to
-     * see changes. This is useful for automated testing as well.
-     *
-     * Parameters:
-     *   (String) suffix - An optional suffix to append to the id.
-     *
-     * Returns:
-     *   A unique string to be used for the id attribute.
-     */
-    getUniqueId: function (suffix)
-    {
-        if (typeof(suffix) == "string" || typeof(suffix) == "number") {
-            return ++this._uniqueId + ":" + suffix;
-        } else {
-            return ++this._uniqueId + "";
-        }
-    },
-
-    /** Function: connect
-     * Starts the connection process.
-     *
-     * As the connection process proceeds, the user supplied callback will
-     * be triggered multiple times with status updates. The callback
-     * should take two arguments - the status code and the error condition.
-     *
-     * The status code will be one of the values in the Strophe.Status
-     * constants. The error condition will be one of the conditions
-     * defined in RFC 3920 or the condition 'strophe-parsererror'.
-     *
-     * Please see XEP 124 for a more detailed explanation of the optional
-     * parameters below.
-     *
-     * Parameters:
-     *   (String) jid - The user's JID. This may be a bare JID,
-     *     or a full JID. If a node is not supplied, SASL ANONYMOUS
-     *     authentication will be attempted.
-     *   (String) pass - The user's password.
-     *   (Function) callback - The connect callback function.
-     *   (Integer) wait - The optional HTTPBIND wait value. This is the
-     *     time the server will wait before returning an empty result for
-     *     a request. The default setting of 60 seconds is recommended.
-     *     Other settings will require tweaks to the Strophe.TIMEOUT value.
-     *   (Integer) hold - The optional HTTPBIND hold value. This is the
-     *     number of connections the server will hold at one time. This
-     *     should almost always be set to 1 (the default).
-     *   (String) route - The optional connection route, passed through
-     *     as the BOSH 'route' attribute.
-     */
-    connect: function (jid, pass, callback, wait, hold, route)
-    {
-        this.jid = jid;
-        this.pass = pass;
-        this.connect_callback = callback;
-        this.disconnecting = false;
-        this.connected = false;
-        this.authenticated = false;
-        this.errors = 0;
-
-        this.wait = wait || this.wait;
-        this.hold = hold || this.hold;
-
-        // parse jid for domain and resource
-        this.domain = Strophe.getDomainFromJid(this.jid);
-
-        // build the body tag
-        var body_attrs = {
-            to: this.domain,
-            "xml:lang": "en",
-            wait: this.wait,
-            hold: this.hold,
-            content: "text/xml; charset=utf-8",
-            ver: "1.6",
-            "xmpp:version": "1.0",
-            "xmlns:xmpp": Strophe.NS.BOSH
-        };
-        if (route) {
-            body_attrs.route = route;
-        }
-
-        var body = this._buildBody().attrs(body_attrs);
-
-        this._changeConnectStatus(Strophe.Status.CONNECTING, null);
-
-        this._requests.push(
-            new Strophe.Request(body.tree(),
-                                this._onRequestStateChange.bind(
-                                    this, this._connect_cb.bind(this)),
-                                body.tree().getAttribute("rid")));
-        this._throttledRequestHandler();
-    },
-
-    /** Function: attach
-     * Attach to an already created and authenticated BOSH session.
-     *
-     * This function is provided to allow Strophe to attach to BOSH
-     * sessions which have been created externally, perhaps by a Web
-     * application. This is often used to support auto-login type features
-     * without putting user credentials into the page.
-     *
-     * Parameters:
-     *   (String) jid - The full JID that is bound by the session.
-     *   (String) sid - The SID of the BOSH session.
-     *   (String) rid - The current RID of the BOSH session. This RID
-     *     will be used by the next request.
-     *   (Function) callback - The connect callback function.
-     *   (Integer) wait - The optional HTTPBIND wait value. This is the
-     *     time the server will wait before returning an empty result for
-     *     a request. The default setting of 60 seconds is recommended.
-     *     Other settings will require tweaks to the Strophe.TIMEOUT value.
-     *   (Integer) hold - The optional HTTPBIND hold value. This is the
-     *     number of connections the server will hold at one time. This
-     *     should almost always be set to 1 (the default).
-     *   (Integer) wind - The optional HTTPBIND window value. This is the
-     *     allowed range of request ids that are valid. The default is 5.
-     */
-    attach: function (jid, sid, rid, callback, wait, hold, wind)
-    {
-        this.jid = jid;
-        this.sid = sid;
-        this.rid = rid;
-        this.connect_callback = callback;
-
-        this.domain = Strophe.getDomainFromJid(this.jid);
-
-        this.authenticated = true;
-        this.connected = true;
-
-        this.wait = wait || this.wait;
-        this.hold = hold || this.hold;
-        this.window = wind || this.window;
-
-        this._changeConnectStatus(Strophe.Status.ATTACHED, null);
-    },
-
-    /** Function: xmlInput
-     * User overrideable function that receives XML data coming into the
-     * connection.
-     *
-     * The default function does nothing. User code can override this with
-     * > Strophe.Connection.xmlInput = function (elem) {
-     * >   (user code)
-     * > };
-     *
-     * Parameters:
-     *   (XMLElement) elem - The XML data received by the connection.
-     */
-    xmlInput: function (elem)
-    {
-        return;
-    },
-
-    /** Function: xmlOutput
-     * User overrideable function that receives XML data sent to the
-     * connection.
-     *
-     * The default function does nothing. User code can override this with
-     * > Strophe.Connection.xmlOutput = function (elem) {
-     * >   (user code)
-     * > };
-     *
-     * Parameters:
-     *   (XMLElement) elem - The XML data sent by the connection.
-     */
-    xmlOutput: function (elem)
-    {
-        return;
-    },
-
-    /** Function: rawInput
-     * User overrideable function that receives raw data coming into the
-     * connection.
-     *
-     * The default function does nothing. User code can override this with
-     * > Strophe.Connection.rawInput = function (data) {
-     * >   (user code)
-     * > };
-     *
-     * Parameters:
-     *   (String) data - The data received by the connection.
-     */
-    rawInput: function (data)
-    {
-        return;
-    },
-
-    /** Function: rawOutput
-     * User overrideable function that receives raw data sent to the
-     * connection.
-     *
-     * The default function does nothing. User code can override this with
-     * > Strophe.Connection.rawOutput = function (data) {
-     * >   (user code)
-     * > };
-     *
-     * Parameters:
-     *   (String) data - The data sent by the connection.
-     */
-    rawOutput: function (data)
-    {
-        return;
-    },
-
-    /** Function: send
-     * Send a stanza.
-     *
-     * This function is called to push data onto the send queue to
-     * go out over the wire. Whenever a request is sent to the BOSH
-     * server, all pending data is sent and the queue is flushed.
-     *
-     * Parameters:
-     *   (XMLElement |
-     *    [XMLElement] |
-     *    Strophe.Builder) elem - The stanza to send.
-     */
-    send: function (elem)
-    {
-        if (elem === null) { return ; }
-        if (typeof(elem.sort) === "function") {
-            for (var i = 0; i < elem.length; i++) {
-                this._queueData(elem[i]);
-            }
-        } else if (typeof(elem.tree) === "function") {
-            this._queueData(elem.tree());
-        } else {
-            this._queueData(elem);
-        }
-
-        this._throttledRequestHandler();
-        clearTimeout(this._idleTimeout);
-        this._idleTimeout = setTimeout(this._onIdle.bind(this), 100);
-    },
-
-    /** Function: flush
-     * Immediately send any pending outgoing data.
- * - * Normally send() queues outgoing data until the next idle period - * (100ms), which optimizes network use in the common cases when - * several send()s are called in succession. flush() can be used to - * immediately send all pending data. - */ - flush: function () - { - // cancel the pending idle period and run the idle function - // immediately - clearTimeout(this._idleTimeout); - this._onIdle(); - }, - - /** Function: sendIQ - * Helper function to send IQ stanzas. - * - * Parameters: - * (XMLElement) elem - The stanza to send. - * (Function) callback - The callback function for a successful request. - * (Function) errback - The callback function for a failed or timed - * out request. On timeout, the stanza will be null. - * (Integer) timeout - The time specified in milliseconds for a - * timeout to occur. - * - * Returns: - * The id used to send the IQ. - */ - sendIQ: function(elem, callback, errback, timeout) { - var timeoutHandler = null; - var that = this; - - if (typeof(elem.tree) === "function") { - elem = elem.tree(); - } - var id = elem.getAttribute('id'); - - // inject id if not found - if (!id) { - id = this.getUniqueId("sendIQ"); - elem.setAttribute("id", id); - } - - var handler = this.addHandler(function (stanza) { - // remove timeout handler if there is one - if (timeoutHandler) { - that.deleteTimedHandler(timeoutHandler); - } - - var iqtype = stanza.getAttribute('type'); - if (iqtype == 'result') { - if (callback) { - callback(stanza); - } - } else if (iqtype == 'error') { - if (errback) { - errback(stanza); - } - } else { - throw { - name: "StropheError", - message: "Got bad IQ type of " + iqtype - }; - } - }, null, 'iq', null, id); - - // if timeout specified, setup timeout handler. - if (timeout) { - timeoutHandler = this.addTimedHandler(timeout, function () { - // get rid of normal handler - that.deleteHandler(handler); - - // call errback on timeout with null stanza - if (errback) { - errback(null); - } - return false; - }); - } - - this.send(elem); - - return id; - }, - - /** PrivateFunction: _queueData - * Queue outgoing data for later sending. Also ensures that the data - * is a DOMElement. - */ - _queueData: function (element) { - if (element === null || - !element.tagName || - !element._childNodes) { - throw { - name: "StropheError", - message: "Cannot queue non-DOMElement." - }; - } - - this._data.push(element); - }, - - /** PrivateFunction: _sendRestart - * Send an xmpp:restart stanza. - */ - _sendRestart: function () - { - this._data.push("restart"); - - this._throttledRequestHandler(); - clearTimeout(this._idleTimeout); - this._idleTimeout = setTimeout(this._onIdle.bind(this), 100); - }, - - /** Function: addTimedHandler - * Add a timed handler to the connection. - * - * This function adds a timed handler. The provided handler will - * be called every period milliseconds until it returns false, - * the connection is terminated, or the handler is removed. Handlers - * that wish to continue being invoked should return true. - * - * Because of method binding it is necessary to save the result of - * this function if you wish to remove a handler with - * deleteTimedHandler(). - * - * Note that user handlers are not active until authentication is - * successful. - * - * Parameters: - * (Integer) period - The period of the handler. - * (Function) handler - The callback function. - * - * Returns: - * A reference to the handler that can be used to remove it. 
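-     *
-     * For illustration only (this example is not part of the original
-     * documentation), a periodic keep-alive might look like:
-     * > var ref = conn.addTimedHandler(10000, function () {
-     * >     conn.send($pres().tree());
-     * >     return true; // keep firing
-     * > });
-     * > // later: conn.deleteTimedHandler(ref);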
-     */
-    addTimedHandler: function (period, handler)
-    {
-        var thand = new Strophe.TimedHandler(period, handler);
-        this.addTimeds.push(thand);
-        return thand;
-    },
-
-    /** Function: deleteTimedHandler
-     * Delete a timed handler for a connection.
-     *
-     * This function removes a timed handler from the connection. The
-     * handRef parameter is *not* the function passed to addTimedHandler(),
-     * but is the reference returned from addTimedHandler().
-     *
-     * Parameters:
-     *   (Strophe.TimedHandler) handRef - The handler reference.
-     */
-    deleteTimedHandler: function (handRef)
-    {
-        // this must be done in the Idle loop so that we don't change
-        // the handlers during iteration
-        this.removeTimeds.push(handRef);
-    },
-
-    /** Function: addHandler
-     * Add a stanza handler for the connection.
-     *
-     * This function adds a stanza handler to the connection. The
-     * handler callback will be called for any stanza that matches
-     * the parameters. Note that if multiple parameters are supplied,
-     * they must all match for the handler to be invoked.
-     *
-     * The handler will receive the stanza that triggered it as its argument.
-     * The handler should return true if it is to be invoked again;
-     * returning false will remove the handler after it returns.
-     *
-     * As a convenience, the ns parameter applies to the top level element
-     * and also any of its immediate children. This is primarily to make
-     * matching /iq/query elements easy.
-     *
-     * The options argument contains handler matching flags that affect how
-     * matches are determined. Currently the only flag is matchBare (a
-     * boolean). When matchBare is true, the from parameter and the from
-     * attribute on the stanza will be matched as bare JIDs instead of
-     * full JIDs. To use this, pass {matchBare: true} as the value of
-     * options. The default value for matchBare is false.
-     *
-     * The return value should be saved if you wish to remove the handler
-     * with deleteHandler().
-     *
-     * Parameters:
-     *   (Function) handler - The user callback.
-     *   (String) ns - The namespace to match.
-     *   (String) name - The stanza name to match.
-     *   (String) type - The stanza type attribute to match.
-     *   (String) id - The stanza id attribute to match.
-     *   (String) from - The stanza from attribute to match.
-     *   (Object) options - The handler options.
-     *
-     * Returns:
-     *   A reference to the handler that can be used to remove it.
-     */
-    addHandler: function (handler, ns, name, type, id, from, options)
-    {
-        var hand = new Strophe.Handler(handler, ns, name, type, id, from, options);
-        this.addHandlers.push(hand);
-        return hand;
-    },
-
-    /** Function: deleteHandler
-     * Delete a stanza handler for a connection.
-     *
-     * This function removes a stanza handler from the connection. The
-     * handRef parameter is *not* the function passed to addHandler(),
-     * but is the reference returned from addHandler().
-     *
-     * Parameters:
-     *   (Strophe.Handler) handRef - The handler reference.
-     */
-    deleteHandler: function (handRef)
-    {
-        // this must be done in the Idle loop so that we don't change
-        // the handlers during iteration
-        this.removeHandlers.push(handRef);
-    },
-
-    /** Function: disconnect
-     * Start the graceful disconnection process.
-     *
-     * This function starts the disconnection process. This process starts
-     * by sending unavailable presence and sending a BOSH body of type
-     * terminate. A timeout handler makes sure that disconnection happens
-     * even if the BOSH server does not respond.
-     *
-     * The user supplied connection callback will be notified of the
-     * progress as this process happens.
-     *
-     * Parameters:
-     *   (String) reason - The reason the disconnect is occurring.
-     */
-    disconnect: function (reason)
-    {
-        this._changeConnectStatus(Strophe.Status.DISCONNECTING, reason);
-
-        Strophe.info("Disconnect was called because: " + reason);
-        if (this.connected) {
-            // set up timeout handler
-            this._disconnectTimeout = this._addSysTimedHandler(
-                3000, this._onDisconnectTimeout.bind(this));
-            this._sendTerminate();
-        }
-    },
-
-    /** PrivateFunction: _changeConnectStatus
-     * _Private_ helper function that makes sure plugins and the user's
-     * callback are notified of connection status changes.
-     *
-     * Parameters:
-     *   (Integer) status - the new connection status, one of the values
-     *     in Strophe.Status
-     *   (String) condition - the error condition or null
-     */
-    _changeConnectStatus: function (status, condition)
-    {
-        // notify all plugins listening for status changes
-        for (var k in Strophe._connectionPlugins) {
-            if (Strophe._connectionPlugins.hasOwnProperty(k)) {
-                var plugin = this[k];
-                if (plugin.statusChanged) {
-                    try {
-                        plugin.statusChanged(status, condition);
-                    } catch (err) {
-                        Strophe.error("" + k + " plugin caused an exception " +
-                                      "changing status: " + err);
-                    }
-                }
-            }
-        }
-
-        // notify the user's callback
-        if (this.connect_callback) {
-            try {
-                this.connect_callback(status, condition);
-            } catch (e) {
-                Strophe.error("User connection callback caused an " +
-                              "exception: " + e);
-            }
-        }
-    },
-
-    /** PrivateFunction: _buildBody
-     * _Private_ helper function to generate the <body/> wrapper for BOSH.
-     *
-     * Returns:
-     *   A Strophe.Builder with a <body/> element.
-     */
-    _buildBody: function ()
-    {
-        var bodyWrap = $build('body', {
-            rid: this.rid++,
-            xmlns: Strophe.NS.HTTPBIND
-        });
-
-        if (this.sid !== null) {
-            bodyWrap.attrs({sid: this.sid});
-        }
-
-        return bodyWrap;
-    },
-
-    /** PrivateFunction: _removeRequest
-     * _Private_ function to remove a request from the queue.
-     *
-     * Parameters:
-     *   (Strophe.Request) req - The request to remove.
-     */
-    _removeRequest: function (req)
-    {
-        Strophe.debug("removing request");
-
-        var i;
-        for (i = this._requests.length - 1; i >= 0; i--) {
-            if (req == this._requests[i]) {
-                this._requests.splice(i, 1);
-            }
-        }
-
-        // IE6 fails on setting to null, so set to empty function
-        req.xhr.onreadystatechange = function () {};
-
-        this._throttledRequestHandler();
-    },
-
-    /** PrivateFunction: _restartRequest
-     * _Private_ function to restart a request that is presumed dead.
-     *
-     * Parameters:
-     *   (Integer) i - The index of the request in the queue.
-     */
-    _restartRequest: function (i)
-    {
-        var req = this._requests[i];
-        if (req.dead === null) {
-            req.dead = new Date();
-        }
-
-        this._processRequest(i);
-    },
-
-    /** PrivateFunction: _processRequest
-     * _Private_ function to process a request in the queue.
-     *
-     * This function takes requests off the queue and sends them and
-     * restarts dead requests.
-     *
-     * Parameters:
-     *   (Integer) i - The index of the request in the queue.
-     */
-    _processRequest: function (i)
-    {
-        var req = this._requests[i];
-        var reqStatus = -1;
-
-        try {
-            if (req.xhr.readyState == 4) {
-                reqStatus = req.xhr.status;
-            }
-        } catch (e) {
-            Strophe.error("caught an error in _requests[" + i +
-                          "], reqStatus: " + reqStatus);
-        }
-
-        if (typeof(reqStatus) == "undefined") {
-            reqStatus = -1;
-        }
-
-        // make sure we limit the number of retries
-        if (req.sends > 5) {
-            this._onDisconnectTimeout();
-            return;
-        }
-
-        var time_elapsed = req.age();
-        var primaryTimeout = (!isNaN(time_elapsed) &&
-                              time_elapsed > Math.floor(Strophe.TIMEOUT * this.wait));
-        var secondaryTimeout = (req.dead !== null &&
-                                req.timeDead() > Math.floor(Strophe.SECONDARY_TIMEOUT * this.wait));
-        var requestCompletedWithServerError = (req.xhr.readyState == 4 &&
-                                               (reqStatus < 1 ||
-                                                reqStatus >= 500));
-        if (primaryTimeout || secondaryTimeout ||
-            requestCompletedWithServerError) {
-            if (secondaryTimeout) {
-                Strophe.error("Request " +
-                              this._requests[i].id +
-                              " timed out (secondary), restarting");
-            }
-            req.abort = true;
-            req.xhr.abort();
-            // setting to null fails on IE6, so set to empty function
-            req.xhr.onreadystatechange = function () {};
-            this._requests[i] = new Strophe.Request(req.xmlData,
-                                                    req.origFunc,
-                                                    req.rid,
-                                                    req.sends);
-            req = this._requests[i];
-        }
-
-        if (req.xhr.readyState === 0) {
-            Strophe.debug("request id " + req.id +
-                          "." + req.sends + " posting");
-
-            req.date = new Date();
-            try {
-                req.xhr.open("POST", this.service, true);
-            } catch (e2) {
-                Strophe.error("XHR open failed.");
-                if (!this.connected) {
-                    this._changeConnectStatus(Strophe.Status.CONNFAIL,
-                                              "bad-service");
-                }
-                this.disconnect();
-                return;
-            }
-
-            // Fires the XHR request -- may be invoked immediately
-            // or on a gradually expanding retry window for reconnects
-            var sendFunc = function () {
-                req.xhr.send(req.data);
-            };
-
-            // Implement progressive backoff for reconnects --
-            // First retry (send == 1) should also be instantaneous
-            if (req.sends > 1) {
-                // Using a cube of the retry number creates a nicely
-                // expanding retry window
-                var backoff = Math.pow(req.sends, 3) * 1000;
-                setTimeout(sendFunc, backoff);
-            } else {
-                sendFunc();
-            }
-
-            req.sends++;
-
-            this.xmlOutput(req.xmlData);
-            this.rawOutput(req.data);
-        } else {
-            Strophe.debug("_processRequest: " +
-                          (i === 0 ? "first" : "second") +
-                          " request has readyState of " +
-                          req.xhr.readyState);
-        }
-    },
-
-    /** PrivateFunction: _throttledRequestHandler
-     * _Private_ function to throttle requests to the connection window.
-     *
-     * This function makes sure we don't send requests so fast that the
-     * request ids overflow the connection window in the case that one
-     * request died.
-     */
-    _throttledRequestHandler: function ()
-    {
-        if (!this._requests) {
-            Strophe.debug("_throttledRequestHandler called with " +
-                          "undefined requests");
-        } else {
-            Strophe.debug("_throttledRequestHandler called with " +
-                          this._requests.length + " requests");
-        }
-
-        if (!this._requests || this._requests.length === 0) {
-            return;
-        }
-
-        if (this._requests.length > 0) {
-            this._processRequest(0);
-        }
-
-        if (this._requests.length > 1 &&
-            Math.abs(this._requests[0].rid -
-                     this._requests[1].rid) < this.window) {
-            this._processRequest(1);
-        }
-    },
-
-    /** PrivateFunction: _onRequestStateChange
-     * _Private_ handler for Strophe.Request state changes.
-     *
-     * This function is called when the XMLHttpRequest readyState changes.
- * It contains a lot of error handling logic for the many ways that - * requests can fail, and calls the request callback when requests - * succeed. - * - * Parameters: - * (Function) func - The handler for the request. - * (Strophe.Request) req - The request that is changing readyState. - */ - _onRequestStateChange: function (func, req) - { - Strophe.debug("request id " + req.id + - "." + req.sends + " state changed to " + - req.xhr.readyState); - - if (req.abort) { - req.abort = false; - return; - } - - // request complete - var reqStatus; - if (req.xhr.readyState == 4) { - reqStatus = 0; - try { - reqStatus = req.xhr.status; - } catch (e) { - // ignore errors from undefined status attribute. works - // around a browser bug - } - - if (typeof(reqStatus) == "undefined") { - reqStatus = 0; - } - - if (this.disconnecting) { - if (reqStatus >= 400) { - this._hitError(reqStatus); - return; - } - } - - var reqIs0 = (this._requests[0] == req); - var reqIs1 = (this._requests[1] == req); - - if ((reqStatus > 0 && reqStatus < 500) || req.sends > 5) { - // remove from internal queue - this._removeRequest(req); - Strophe.debug("request id " + - req.id + - " should now be removed"); - } - - // request succeeded - if (reqStatus == 200) { - // if request 1 finished, or request 0 finished and request - // 1 is over Strophe.SECONDARY_TIMEOUT seconds old, we need to - // restart the other - both will be in the first spot, as the - // completed request has been removed from the queue already - if (reqIs1 || - (reqIs0 && this._requests.length > 0 && - this._requests[0].age() > Math.floor(Strophe.SECONDARY_TIMEOUT * this.wait))) { - this._restartRequest(0); - } - // call handler - Strophe.debug("request id " + - req.id + "." + - req.sends + " got 200"); - func(req); - this.errors = 0; - } else { - Strophe.error("request id " + - req.id + "." + - req.sends + " error " + reqStatus + - " happened"); - if (reqStatus === 0 || - (reqStatus >= 400 && reqStatus < 600) || - reqStatus >= 12000) { - this._hitError(reqStatus); - if (reqStatus >= 400 && reqStatus < 500) { - this._changeConnectStatus(Strophe.Status.DISCONNECTING, - null); - this._doDisconnect(); - } - } - } - - if (!((reqStatus > 0 && reqStatus < 500) || - req.sends > 5)) { - this._throttledRequestHandler(); - } - } - }, - - /** PrivateFunction: _hitError - * _Private_ function to handle the error count. - * - * Requests are resent automatically until their error count reaches - * 5. Each time an error is encountered, this function is called to - * increment the count and disconnect if the count is too high. - * - * Parameters: - * (Integer) reqStatus - The request status. - */ - _hitError: function (reqStatus) - { - this.errors++; - Strophe.warn("request errored, status: " + reqStatus + - ", number of errors: " + this.errors); - if (this.errors > 4) { - this._onDisconnectTimeout(); - } - }, - - /** PrivateFunction: _doDisconnect - * _Private_ function to disconnect. - * - * This is the last piece of the disconnection logic. This resets the - * connection and alerts the user's connection callback. 
-     */
-    _doDisconnect: function ()
-    {
-        Strophe.info("_doDisconnect was called");
-        this.authenticated = false;
-        this.disconnecting = false;
-        this.sid = null;
-        this.streamId = null;
-        this.rid = Math.floor(Math.random() * 4294967295);
-
-        // tell the parent we disconnected
-        if (this.connected) {
-            this._changeConnectStatus(Strophe.Status.DISCONNECTED, null);
-            this.connected = false;
-        }
-
-        // delete handlers
-        this.handlers = [];
-        this.timedHandlers = [];
-        this.removeTimeds = [];
-        this.removeHandlers = [];
-        this.addTimeds = [];
-        this.addHandlers = [];
-    },
-
-    /** PrivateFunction: _dataRecv
-     * _Private_ handler to process incoming data from the connection.
-     *
-     * Except for _connect_cb handling the initial connection request,
-     * this function handles the incoming data for all requests. This
-     * function also fires stanza handlers that match each incoming
-     * stanza.
-     *
-     * Parameters:
-     *   (Strophe.Request) req - The request that has data ready.
-     */
-    _dataRecv: function (req)
-    {
-        try {
-            var elem = req.getResponse();
-        } catch (e) {
-            if (e != "parsererror") { throw e; }
-            this.disconnect("strophe-parsererror");
-        }
-        if (elem === null) { return; }
-
-        this.xmlInput(elem);
-        this.rawInput(Strophe.serialize(elem));
-
-        // remove handlers scheduled for deletion
-        var i, hand;
-        while (this.removeHandlers.length > 0) {
-            hand = this.removeHandlers.pop();
-            i = this.handlers.indexOf(hand);
-            if (i >= 0) {
-                this.handlers.splice(i, 1);
-            }
-        }
-
-        // add handlers scheduled for addition
-        while (this.addHandlers.length > 0) {
-            this.handlers.push(this.addHandlers.pop());
-        }
-
-        // handle graceful disconnect
-        if (this.disconnecting && this._requests.length === 0) {
-            this.deleteTimedHandler(this._disconnectTimeout);
-            this._disconnectTimeout = null;
-            this._doDisconnect();
-            return;
-        }
-
-        var typ = elem.getAttribute("type");
-        var cond, conflict;
-        if (typ !== null && typ == "terminate") {
-            // Don't process stanzas that come in after disconnect
-            if (this.disconnecting) {
-                return;
-            }
-
-            // an error occurred
-            cond = elem.getAttribute("condition");
-            conflict = elem.getElementsByTagName("conflict");
-            if (cond !== null) {
-                if (cond == "remote-stream-error" && conflict.length > 0) {
-                    cond = "conflict";
-                }
-                this._changeConnectStatus(Strophe.Status.CONNFAIL, cond);
-            } else {
-                this._changeConnectStatus(Strophe.Status.CONNFAIL, "unknown");
-            }
-            this.disconnect();
-            return;
-        }
-
-        // send each incoming stanza through the handler chain
-        var that = this;
-        Strophe.forEachChild(elem, null, function (child) {
-            var i, newList;
-            // process handlers
-            newList = that.handlers;
-            that.handlers = [];
-            for (i = 0; i < newList.length; i++) {
-                var hand = newList[i];
-                if (hand.isMatch(child) &&
-                    (that.authenticated || !hand.user)) {
-                    if (hand.run(child)) {
-                        that.handlers.push(hand);
-                    }
-                } else {
-                    that.handlers.push(hand);
-                }
-            }
-        });
-    },
-
-    /** PrivateFunction: _sendTerminate
-     * _Private_ function to send initial disconnect sequence.
-     *
-     * This is the first step in a graceful disconnect. It sends
-     * the BOSH server a terminate body and includes an unavailable
-     * presence if authentication has completed.
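-     *
-     * For illustration (an added sketch, not a verbatim trace), the
-     * terminate body produced here looks roughly like:
-     * > <body type='terminate' rid='...' sid='...'
-     * >       xmlns='http://jabber.org/protocol/httpbind'>
-     * >   <presence type='unavailable' xmlns='jabber:client'/>
-     * > </body>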
-     */
-    _sendTerminate: function ()
-    {
-        Strophe.info("_sendTerminate was called");
-        var body = this._buildBody().attrs({type: "terminate"});
-
-        if (this.authenticated) {
-            body.c('presence', {
-                xmlns: Strophe.NS.CLIENT,
-                type: 'unavailable'
-            });
-        }
-
-        this.disconnecting = true;
-
-        var req = new Strophe.Request(body.tree(),
-                                      this._onRequestStateChange.bind(
-                                          this, this._dataRecv.bind(this)),
-                                      body.tree().getAttribute("rid"));
-
-        this._requests.push(req);
-        this._throttledRequestHandler();
-    },
-
-    /** PrivateFunction: _connect_cb
-     * _Private_ handler for initial connection request.
-     *
-     * This handler is used to process the initial connection request
-     * response from the BOSH server. It is used to set up authentication
-     * handlers and start the authentication process.
-     *
-     * SASL authentication will be attempted if available, otherwise
-     * the code will fall back to legacy authentication.
-     *
-     * Parameters:
-     *   (Strophe.Request) req - The current request.
-     */
-    _connect_cb: function (req)
-    {
-        Strophe.info("_connect_cb was called");
-
-        this.connected = true;
-        var bodyWrap = req.getResponse();
-        if (!bodyWrap) { return; }
-
-        this.xmlInput(bodyWrap);
-        this.rawInput(Strophe.serialize(bodyWrap));
-
-        var typ = bodyWrap.getAttribute("type");
-        var cond, conflict;
-        if (typ !== null && typ == "terminate") {
-            // an error occurred
-            cond = bodyWrap.getAttribute("condition");
-            conflict = bodyWrap.getElementsByTagName("conflict");
-            if (cond !== null) {
-                if (cond == "remote-stream-error" && conflict.length > 0) {
-                    cond = "conflict";
-                }
-                this._changeConnectStatus(Strophe.Status.CONNFAIL, cond);
-            } else {
-                this._changeConnectStatus(Strophe.Status.CONNFAIL, "unknown");
-            }
-            return;
-        }
-
-        // check to make sure we don't overwrite these if _connect_cb is
-        // called multiple times in the case of missing stream:features
-        if (!this.sid) {
-            this.sid = bodyWrap.getAttribute("sid");
-        }
-        if (!this.streamId) {
-            this.streamId = bodyWrap.getAttribute("authid");
-        }
-        var wind = bodyWrap.getAttribute('requests');
-        if (wind) { this.window = parseInt(wind, 10); }
-        var hold = bodyWrap.getAttribute('hold');
-        if (hold) { this.hold = parseInt(hold, 10); }
-        var wait = bodyWrap.getAttribute('wait');
-        if (wait) { this.wait = parseInt(wait, 10); }
-
-
-        var do_sasl_plain = false;
-        var do_sasl_digest_md5 = false;
-        var do_sasl_anonymous = false;
-
-        var mechanisms = bodyWrap.getElementsByTagName("mechanism");
-        var i, mech, auth_str, hashed_auth_str;
-        if (mechanisms.length > 0) {
-            for (i = 0; i < mechanisms.length; i++) {
-                mech = Strophe.getText(mechanisms[i]);
-                if (mech == 'DIGEST-MD5') {
-                    do_sasl_digest_md5 = true;
-                } else if (mech == 'PLAIN') {
-                    do_sasl_plain = true;
-                } else if (mech == 'ANONYMOUS') {
-                    do_sasl_anonymous = true;
-                }
-            }
-        } else {
-            // we didn't get stream:features yet, so we need to wait for it
-            // by sending a blank poll request
-            var body = this._buildBody();
-            this._requests.push(
-                new Strophe.Request(body.tree(),
-                                    this._onRequestStateChange.bind(
-                                        this, this._connect_cb.bind(this)),
-                                    body.tree().getAttribute("rid")));
-            this._throttledRequestHandler();
-            return;
-        }
-
-        if (Strophe.getNodeFromJid(this.jid) === null &&
-            do_sasl_anonymous) {
-            this._changeConnectStatus(Strophe.Status.AUTHENTICATING, null);
-            this._sasl_success_handler = this._addSysHandler(
-                this._sasl_success_cb.bind(this), null,
-                "success", null, null);
-            this._sasl_failure_handler = this._addSysHandler(
-                this._sasl_failure_cb.bind(this), null,
-                "failure", null, null);
-
this.send($build("auth", { - xmlns: Strophe.NS.SASL, - mechanism: "ANONYMOUS" - }).tree()); - } else if (Strophe.getNodeFromJid(this.jid) === null) { - // we don't have a node, which is required for non-anonymous - // client connections - this._changeConnectStatus(Strophe.Status.CONNFAIL, - 'x-strophe-bad-non-anon-jid'); - this.disconnect(); - } else if (do_sasl_digest_md5) { - this._changeConnectStatus(Strophe.Status.AUTHENTICATING, null); - this._sasl_challenge_handler = this._addSysHandler( - this._sasl_challenge1_cb.bind(this), null, - "challenge", null, null); - this._sasl_failure_handler = this._addSysHandler( - this._sasl_failure_cb.bind(this), null, - "failure", null, null); - - this.send($build("auth", { - xmlns: Strophe.NS.SASL, - mechanism: "DIGEST-MD5" - }).tree()); - } else if (do_sasl_plain) { - // Build the plain auth string (barejid null - // username null password) and base 64 encoded. - auth_str = Strophe.getBareJidFromJid(this.jid); - auth_str = auth_str + "\u0000"; - auth_str = auth_str + Strophe.getNodeFromJid(this.jid); - auth_str = auth_str + "\u0000"; - auth_str = auth_str + this.pass; - - this._changeConnectStatus(Strophe.Status.AUTHENTICATING, null); - this._sasl_success_handler = this._addSysHandler( - this._sasl_success_cb.bind(this), null, - "success", null, null); - this._sasl_failure_handler = this._addSysHandler( - this._sasl_failure_cb.bind(this), null, - "failure", null, null); - - hashed_auth_str = Base64.encode(auth_str); - this.send($build("auth", { - xmlns: Strophe.NS.SASL, - mechanism: "PLAIN" - }).t(hashed_auth_str).tree()); - } else { - this._changeConnectStatus(Strophe.Status.AUTHENTICATING, null); - this._addSysHandler(this._auth1_cb.bind(this), null, null, - null, "_auth_1"); - - this.send($iq({ - type: "get", - to: this.domain, - id: "_auth_1" - }).c("query", { - xmlns: Strophe.NS.AUTH - }).c("username", {}).t(Strophe.getNodeFromJid(this.jid)).tree()); - } - }, - - /** PrivateFunction: _sasl_challenge1_cb - * _Private_ handler for DIGEST-MD5 SASL authentication. - * - * Parameters: - * (XMLElement) elem - The challenge stanza. - * - * Returns: - * false to remove the handler. 
-     */
-    _sasl_challenge1_cb: function (elem)
-    {
-        var attribMatch = /([a-z]+)=("[^"]+"|[^,"]+)(?:,|$)/;
-
-        var challenge = Base64.decode(Strophe.getText(elem));
-        var cnonce = MD5.hexdigest(Math.random() * 1234567890);
-        var realm = "";
-        var host = null;
-        var nonce = "";
-        var qop = "";
-        var matches;
-
-        // remove unneeded handlers
-        this.deleteHandler(this._sasl_failure_handler);
-
-        while (challenge.match(attribMatch)) {
-            matches = challenge.match(attribMatch);
-            challenge = challenge.replace(matches[0], "");
-            matches[2] = matches[2].replace(/^"(.+)"$/, "$1");
-            switch (matches[1]) {
-            case "realm":
-                realm = matches[2];
-                break;
-            case "nonce":
-                nonce = matches[2];
-                break;
-            case "qop":
-                qop = matches[2];
-                break;
-            case "host":
-                host = matches[2];
-                break;
-            }
-        }
-
-        var digest_uri = "xmpp/" + this.domain;
-        if (host !== null) {
-            digest_uri = digest_uri + "/" + host;
-        }
-
-        var A1 = MD5.hash(Strophe.getNodeFromJid(this.jid) +
-                          ":" + realm + ":" + this.pass) +
-            ":" + nonce + ":" + cnonce;
-        var A2 = 'AUTHENTICATE:' + digest_uri;
-
-        var responseText = "";
-        responseText += 'username=' +
-            this._quote(Strophe.getNodeFromJid(this.jid)) + ',';
-        responseText += 'realm=' + this._quote(realm) + ',';
-        responseText += 'nonce=' + this._quote(nonce) + ',';
-        responseText += 'cnonce=' + this._quote(cnonce) + ',';
-        responseText += 'nc="00000001",';
-        responseText += 'qop="auth",';
-        responseText += 'digest-uri=' + this._quote(digest_uri) + ',';
-        responseText += 'response=' + this._quote(
-            MD5.hexdigest(MD5.hexdigest(A1) + ":" +
-                          nonce + ":00000001:" +
-                          cnonce + ":auth:" +
-                          MD5.hexdigest(A2))) + ',';
-        responseText += 'charset="utf-8"';
-
-        this._sasl_challenge_handler = this._addSysHandler(
-            this._sasl_challenge2_cb.bind(this), null,
-            "challenge", null, null);
-        this._sasl_success_handler = this._addSysHandler(
-            this._sasl_success_cb.bind(this), null,
-            "success", null, null);
-        this._sasl_failure_handler = this._addSysHandler(
-            this._sasl_failure_cb.bind(this), null,
-            "failure", null, null);
-
-        this.send($build('response', {
-            xmlns: Strophe.NS.SASL
-        }).t(Base64.encode(responseText)).tree());
-
-        return false;
-    },
-
-    /** PrivateFunction: _quote
-     * _Private_ utility function to backslash escape and quote strings.
-     *
-     * Parameters:
-     *   (String) str - The string to be quoted.
-     *
-     * Returns:
-     *   The quoted string.
-     */
-    _quote: function (str)
-    {
-        return '"' + str.replace(/\\/g, "\\\\").replace(/"/g, '\\"') + '"';
-        //" end string workaround for emacs
-    },
-
-
-    /** PrivateFunction: _sasl_challenge2_cb
-     * _Private_ handler for second step of DIGEST-MD5 SASL authentication.
-     *
-     * Parameters:
-     *   (XMLElement) elem - The challenge stanza.
-     *
-     * Returns:
-     *   false to remove the handler.
-     */
-    _sasl_challenge2_cb: function (elem)
-    {
-        // remove unneeded handlers
-        this.deleteHandler(this._sasl_success_handler);
-        this.deleteHandler(this._sasl_failure_handler);
-
-        this._sasl_success_handler = this._addSysHandler(
-            this._sasl_success_cb.bind(this), null,
-            "success", null, null);
-        this._sasl_failure_handler = this._addSysHandler(
-            this._sasl_failure_cb.bind(this), null,
-            "failure", null, null);
-        this.send($build('response', {xmlns: Strophe.NS.SASL}).tree());
-        return false;
-    },
-
-    /** PrivateFunction: _auth1_cb
-     * _Private_ handler for legacy authentication.
-     *
-     * This handler is called in response to the initial <iq type='get'/>
-     * for legacy authentication.
-     * It builds an authentication <iq/> and sends it, creating a handler
-     * (calling back to _auth2_cb()) to handle the result.
-     *
-     * Parameters:
-     *   (XMLElement) elem - The stanza that triggered the callback.
-     *
-     * Returns:
-     *   false to remove the handler.
-     */
-    _auth1_cb: function (elem)
-    {
-        // build plaintext auth iq
-        var iq = $iq({type: "set", id: "_auth_2"})
-            .c('query', {xmlns: Strophe.NS.AUTH})
-            .c('username', {}).t(Strophe.getNodeFromJid(this.jid))
-            .up()
-            .c('password').t(this.pass);
-
-        if (!Strophe.getResourceFromJid(this.jid)) {
-            // since the user has not supplied a resource, we pick
-            // a default one here. unlike other auth methods, the server
-            // cannot do this for us.
-            this.jid = Strophe.getBareJidFromJid(this.jid) + '/strophe';
-        }
-        iq.up().c('resource', {}).t(Strophe.getResourceFromJid(this.jid));
-
-        this._addSysHandler(this._auth2_cb.bind(this), null,
-                            null, null, "_auth_2");
-
-        this.send(iq.tree());
-
-        return false;
-    },
-
-    /** PrivateFunction: _sasl_success_cb
-     * _Private_ handler for successful SASL authentication.
-     *
-     * Parameters:
-     *   (XMLElement) elem - The matching stanza.
-     *
-     * Returns:
-     *   false to remove the handler.
-     */
-    _sasl_success_cb: function (elem)
-    {
-        Strophe.info("SASL authentication succeeded.");
-
-        // remove old handlers
-        this.deleteHandler(this._sasl_failure_handler);
-        this._sasl_failure_handler = null;
-        if (this._sasl_challenge_handler) {
-            this.deleteHandler(this._sasl_challenge_handler);
-            this._sasl_challenge_handler = null;
-        }
-
-        this._addSysHandler(this._sasl_auth1_cb.bind(this), null,
-                            "stream:features", null, null);
-
-        // we must send an xmpp:restart now
-        this._sendRestart();
-
-        return false;
-    },
-
-    /** PrivateFunction: _sasl_auth1_cb
-     * _Private_ handler to start stream binding.
-     *
-     * Parameters:
-     *   (XMLElement) elem - The matching stanza.
-     *
-     * Returns:
-     *   false to remove the handler.
-     */
-    _sasl_auth1_cb: function (elem)
-    {
-        // save stream:features for future usage
-        this.features = elem;
-
-        var i, child;
-
-        for (i = 0; i < elem._childNodes.length; i++) {
-            child = elem._childNodes[i];
-            if (child.nodeName.toLowerCase() == 'bind') {
-                this.do_bind = true;
-            }
-
-            if (child.nodeName.toLowerCase() == 'session') {
-                this.do_session = true;
-            }
-        }
-
-        if (!this.do_bind) {
-            this._changeConnectStatus(Strophe.Status.AUTHFAIL, null);
-            return false;
-        } else {
-            this._addSysHandler(this._sasl_bind_cb.bind(this), null, null,
-                                null, "_bind_auth_2");
-
-            var resource = Strophe.getResourceFromJid(this.jid);
-            if (resource) {
-                this.send($iq({type: "set", id: "_bind_auth_2"})
-                          .c('bind', {xmlns: Strophe.NS.BIND})
-                          .c('resource', {}).t(resource).tree());
-            } else {
-                this.send($iq({type: "set", id: "_bind_auth_2"})
-                          .c('bind', {xmlns: Strophe.NS.BIND})
-                          .tree());
-            }
-        }
-
-        return false;
-    },
-
-    /** PrivateFunction: _sasl_bind_cb
-     * _Private_ handler for binding result and session start.
-     *
-     * Parameters:
-     *   (XMLElement) elem - The matching stanza.
-     *
-     * Returns:
-     *   false to remove the handler.
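-     *
-     * For illustration (an added sketch, not a verbatim trace), a
-     * successful bind result takes roughly this form:
-     * > <iq type='result' id='_bind_auth_2'>
-     * >   <bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'>
-     * >     <jid>user@example.com/resource</jid>
-     * >   </bind>
-     * > </iq>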
- */ - _sasl_bind_cb: function (elem) - { - if (elem.getAttribute("type") == "error") { - Strophe.info("SASL binding failed."); - this._changeConnectStatus(Strophe.Status.AUTHFAIL, null); - return false; - } - - // TODO - need to grab errors - var bind = elem.getElementsByTagName("bind"); - var jidNode; - if (bind.length > 0) { - // Grab jid - jidNode = bind[0].getElementsByTagName("jid"); - if (jidNode.length > 0) { - this.jid = Strophe.getText(jidNode[0]); - - if (this.do_session) { - this._addSysHandler(this._sasl_session_cb.bind(this), - null, null, null, "_session_auth_2"); - - this.send($iq({type: "set", id: "_session_auth_2"}) - .c('session', {xmlns: Strophe.NS.SESSION}) - .tree()); - } else { - this.authenticated = true; - this._changeConnectStatus(Strophe.Status.CONNECTED, null); - } - } - } else { - Strophe.info("SASL binding failed."); - this._changeConnectStatus(Strophe.Status.AUTHFAIL, null); - return false; - } - }, - - /** PrivateFunction: _sasl_session_cb - * _Private_ handler to finish successful SASL connection. - * - * This sets Connection.authenticated to true on success, which - * starts the processing of user handlers. - * - * Parameters: - * (XMLElement) elem - The matching stanza. - * - * Returns: - * false to remove the handler. - */ - _sasl_session_cb: function (elem) - { - if (elem.getAttribute("type") == "result") { - this.authenticated = true; - this._changeConnectStatus(Strophe.Status.CONNECTED, null); - } else if (elem.getAttribute("type") == "error") { - Strophe.info("Session creation failed."); - this._changeConnectStatus(Strophe.Status.AUTHFAIL, null); - return false; - } - - return false; - }, - - /** PrivateFunction: _sasl_failure_cb - * _Private_ handler for SASL authentication failure. - * - * Parameters: - * (XMLElement) elem - The matching stanza. - * - * Returns: - * false to remove the handler. - */ - _sasl_failure_cb: function (elem) - { - // delete unneeded handlers - if (this._sasl_success_handler) { - this.deleteHandler(this._sasl_success_handler); - this._sasl_success_handler = null; - } - if (this._sasl_challenge_handler) { - this.deleteHandler(this._sasl_challenge_handler); - this._sasl_challenge_handler = null; - } - - this._changeConnectStatus(Strophe.Status.AUTHFAIL, null); - return false; - }, - - /** PrivateFunction: _auth2_cb - * _Private_ handler to finish legacy authentication. - * - * This handler is called when the result from the jabber:iq:auth - * stanza is returned. - * - * Parameters: - * (XMLElement) elem - The stanza that triggered the callback. - * - * Returns: - * false to remove the handler. - */ - _auth2_cb: function (elem) - { - if (elem.getAttribute("type") == "result") { - this.authenticated = true; - this._changeConnectStatus(Strophe.Status.CONNECTED, null); - } else if (elem.getAttribute("type") == "error") { - this._changeConnectStatus(Strophe.Status.AUTHFAIL, null); - this.disconnect(); - } - - return false; - }, - - /** PrivateFunction: _addSysTimedHandler - * _Private_ function to add a system level timed handler. - * - * This function is used to add a Strophe.TimedHandler for the - * library code. System timed handlers are allowed to run before - * authentication is complete. - * - * Parameters: - * (Integer) period - The period of the handler. - * (Function) handler - The callback function. 
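- *
- * Returns:
- *   A reference to the timed handler that can be used to remove it.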
- */ - _addSysTimedHandler: function (period, handler) - { - var thand = new Strophe.TimedHandler(period, handler); - thand.user = false; - this.addTimeds.push(thand); - return thand; - }, - - /** PrivateFunction: _addSysHandler - * _Private_ function to add a system level stanza handler. - * - * This function is used to add a Strophe.Handler for the - * library code. System stanza handlers are allowed to run before - * authentication is complete. - * - * Parameters: - * (Function) handler - The callback function. - * (String) ns - The namespace to match. - * (String) name - The stanza name to match. - * (String) type - The stanza type attribute to match. - * (String) id - The stanza id attribute to match. - */ - _addSysHandler: function (handler, ns, name, type, id) - { - var hand = new Strophe.Handler(handler, ns, name, type, id); - hand.user = false; - this.addHandlers.push(hand); - return hand; - }, - - /** PrivateFunction: _onDisconnectTimeout - * _Private_ timeout handler for handling non-graceful disconnection. - * - * If the graceful disconnect process does not complete within the - * time allotted, this handler finishes the disconnect anyway. - * - * Returns: - * false to remove the handler. - */ - _onDisconnectTimeout: function () - { - Strophe.info("_onDisconnectTimeout was called"); - - // cancel all remaining requests and clear the queue - var req; - while (this._requests.length > 0) { - req = this._requests.pop(); - req.abort = true; - req.xhr.abort(); - // jslint complains, but this is fine. setting to empty func - // is necessary for IE6 - req.xhr.onreadystatechange = function () {}; - } - - // actually disconnect - this._doDisconnect(); - - return false; - }, - - /** PrivateFunction: _onIdle - * _Private_ handler to process events during idle cycle. - * - * This handler is called every 100ms to fire timed handlers that - * are ready and keep poll requests going. - */ - _onIdle: function () - { - var i, thand, since, newList; - - // add timed handlers scheduled for addition - // NOTE: we add before remove in the case a timed handler is - // added and then deleted before the next _onIdle() call. 
- while (this.addTimeds.length > 0) { - this.timedHandlers.push(this.addTimeds.pop()); - } - - // remove timed handlers that have been scheduled for deletion - while (this.removeTimeds.length > 0) { - thand = this.removeTimeds.pop(); - i = this.timedHandlers.indexOf(thand); - if (i >= 0) { - this.timedHandlers.splice(i, 1); - } - } - - // call ready timed handlers - var now = new Date().getTime(); - newList = []; - for (i = 0; i < this.timedHandlers.length; i++) { - thand = this.timedHandlers[i]; - if (this.authenticated || !thand.user) { - since = thand.lastCalled + thand.period; - if (since - now <= 0) { - if (thand.run()) { - newList.push(thand); - } - } else { - newList.push(thand); - } - } - } - this.timedHandlers = newList; - - var body, time_elapsed; - - // if no requests are in progress, poll - if (this.authenticated && this._requests.length === 0 && - this._data.length === 0 && !this.disconnecting) { - Strophe.info("no requests during idle cycle, sending " + - "blank request"); - this._data.push(null); - } - - if (this._requests.length < 2 && this._data.length > 0 && - !this.paused) { - body = this._buildBody(); - for (i = 0; i < this._data.length; i++) { - if (this._data[i] !== null) { - if (this._data[i] === "restart") { - body.attrs({ - to: this.domain, - "xml:lang": "en", - "xmpp:restart": "true", - "xmlns:xmpp": Strophe.NS.BOSH - }); - } else { - body.cnode(this._data[i]).up(); - } - } - } - delete this._data; - this._data = []; - this._requests.push( - new Strophe.Request(body.tree(), - this._onRequestStateChange.bind( - this, this._dataRecv.bind(this)), - body.tree().getAttribute("rid"))); - this._processRequest(this._requests.length - 1); - } - - if (this._requests.length > 0) { - time_elapsed = this._requests[0].age(); - if (this._requests[0].dead !== null) { - if (this._requests[0].timeDead() > - Math.floor(Strophe.SECONDARY_TIMEOUT * this.wait)) { - this._throttledRequestHandler(); - } - } - - if (time_elapsed > Math.floor(Strophe.TIMEOUT * this.wait)) { - Strophe.warn("Request " + - this._requests[0].id + - " timed out, over " + Math.floor(Strophe.TIMEOUT * this.wait) + - " seconds since last activity"); - this._throttledRequestHandler(); - } - } - - // reactivate the timer - clearTimeout(this._idleTimeout); - this._idleTimeout = setTimeout(this._onIdle.bind(this), 100); - } -}; - -if (callback) { - callback(Strophe, $build, $msg, $iq, $pres); -} - -})(function () { - window.Strophe = arguments[0]; - window.$build = arguments[1]; - window.$msg = arguments[2]; - window.$iq = arguments[3]; - window.$pres = arguments[4]; -}); diff --git a/contrib/jitsimeetbridge/unjingle/unjingle.js b/contrib/jitsimeetbridge/unjingle/unjingle.js deleted file mode 100644 index 3dfe759914..0000000000 --- a/contrib/jitsimeetbridge/unjingle/unjingle.js +++ /dev/null @@ -1,48 +0,0 @@ -var strophe = require("./strophe/strophe.js").Strophe; - -var Strophe = strophe.Strophe; -var $iq = strophe.$iq; -var $msg = strophe.$msg; -var $build = strophe.$build; -var $pres = strophe.$pres; - -var jsdom = require("jsdom"); -var window = jsdom.jsdom().parentWindow; -var $ = require('jquery')(window); - -var stropheJingle = require("./strophe.jingle.sdp.js"); - - -var input = ''; - -process.stdin.on('readable', function() { - var chunk = process.stdin.read(); - if (chunk !== null) { - input += chunk; - } -}); - -process.stdin.on('end', function() { - if (process.argv[2] == '--jingle') { - var elem = $(input); - // app does: - // sess.setRemoteDescription($(iq).find('>jingle'), 'offer'); - 
-        //console.log(elem.find('>content'));
-        var sdp = new stropheJingle.SDP('');
-        sdp.fromJingle(elem);
-        console.log(sdp.raw);
-    } else if (process.argv[2] == '--sdp') {
-        var sdp = new stropheJingle.SDP(input);
-        var accept = $iq({to: '%(tojid)s',
-                          type: 'set'})
-            .c('jingle', {xmlns: 'urn:xmpp:jingle:1',
-                          //action: 'session-accept',
-                          action: '%(action)s',
-                          initiator: '%(initiator)s',
-                          responder: '%(responder)s',
-                          sid: '%(sid)s' });
-        sdp.toJingle(accept, 'responder');
-        console.log(Strophe.serialize(accept));
-    }
-});
-
diff --git a/debian/changelog b/debian/changelog
index b6a51d6903..b1e61e7c8a 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,10 @@
+matrix-synapse-py3 (1.60.0~rc2+nmu1) UNRELEASED; urgency=medium
+
+  * Non-maintainer upload.
+  * Remove unused `jitsimeetbridge` experiment from `contrib` directory.
+
+ -- Synapse Packaging team <packages@matrix.org>  Sun, 29 May 2022 14:44:45 +0100
+
 matrix-synapse-py3 (1.60.0~rc2) stable; urgency=medium
 
   * New Synapse release 1.60.0rc2.
diff --git a/debian/copyright b/debian/copyright
index 95c21ea12a..902b18fa41 100644
--- a/debian/copyright
+++ b/debian/copyright
@@ -22,29 +22,6 @@ Files: synapse/config/repository.py
 Copyright: 2014-2015, matrix.org
 License: Apache-2.0
 
-Files: contrib/jitsimeetbridge/unjingle/strophe/base64.js
-Copyright: Public Domain (Tyler Akins http://rumkin.com)
-License: public-domain
- This code was written by Tyler Akins and has been placed in the
- public domain.  It would be nice if you left this header intact.
- Base64 code from Tyler Akins -- http://rumkin.com
-
-Files: contrib/jitsimeetbridge/unjingle/strophe/md5.js
-Copyright: 1999-2002, Paul Johnston & Contributors
-License: BSD-3-clause
-
-Files: contrib/jitsimeetbridge/unjingle/strophe/strophe.js
-Copyright: 2006-2008, OGG, LLC
-License: Expat
-
-Files: contrib/jitsimeetbridge/unjingle/strophe/XMLHttpRequest.js
-Copyright: 2010 passive.ly LLC
-License: Expat
-
-Files: contrib/jitsimeetbridge/unjingle/*.js
-Copyright: 2014 Jitsi
-License: Apache-2.0
-
 Files: debian/*
 Copyright: 2016-2017, Erik Johnston
            2017, Rahul De

From 80bd614dac1d303e2527d98de97f01d7b0e2daef Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Mon, 30 May 2022 10:47:47 +0100
Subject: [PATCH 38/74] Remove `contrib/experiments/test_messaging.py` (#12911)

---
 changelog.d/12911.removal             |   1 +
 contrib/experiments/test_messaging.py | 367 --------------------------
 2 files changed, 1 insertion(+), 367 deletions(-)
 create mode 100644 changelog.d/12911.removal
 delete mode 100644 contrib/experiments/test_messaging.py

diff --git a/changelog.d/12911.removal b/changelog.d/12911.removal
new file mode 100644
index 0000000000..5178cd6532
--- /dev/null
+++ b/changelog.d/12911.removal
@@ -0,0 +1 @@
+Remove unused `contrib/experiments/test_messaging.py` script. This fails to run on Python 3.
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
deleted file mode 100644
index 31b8a68225..0000000000
--- a/contrib/experiments/test_messaging.py
+++ /dev/null
@@ -1,367 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-""" This is an example of using the server to server implementation to do a
-basic chat style thing. It accepts commands from stdin and outputs to stdout.
-
-It assumes that ucids are of the form <user>@<domain>, and uses <domain> as
-the address of the remote home server to hit.
-
-Usage:
-    python test_messaging.py <port>
-
-Currently assumes the local address is localhost:<port>
-
-"""
-
-
-import argparse
-import curses.wrapper
-import json
-import logging
-import os
-import re
-
-import cursesio
-
-from twisted.internet import defer, reactor
-from twisted.python import log
-
-from synapse.app.homeserver import SynapseHomeServer
-from synapse.federation import ReplicationHandler
-from synapse.federation.units import Pdu
-from synapse.util import origin_from_ucid
-
-# from synapse.logging.utils import log_function
-
-
-logger = logging.getLogger("example")
-
-
-def exception_errback(failure):
-    logging.exception(failure)
-
-
-class InputOutput:
-    """This is responsible for basic I/O so that a user can interact with
-    the example app.
-    """
-
-    def __init__(self, screen, user):
-        self.screen = screen
-        self.user = user
-
-    def set_home_server(self, server):
-        self.server = server
-
-    def on_line(self, line):
-        """This is where we process commands."""
-
-        try:
-            m = re.match(r"^join (\S+)$", line)
-            if m:
-                # The `sender` wants to join a room.
-                (room_name,) = m.groups()
-                self.print_line("%s joining %s" % (self.user, room_name))
-                self.server.join_room(room_name, self.user, self.user)
-                # self.print_line("OK.")
-                return
-
-            m = re.match(r"^invite (\S+) (\S+)$", line)
-            if m:
-                # `sender` wants to invite someone to a room
-                room_name, invitee = m.groups()
-                self.print_line("%s invited to %s" % (invitee, room_name))
-                self.server.invite_to_room(room_name, self.user, invitee)
-                # self.print_line("OK.")
-                return
-
-            m = re.match(r"^send (\S+) (.*)$", line)
-            if m:
-                # `sender` wants to message a room
-                room_name, body = m.groups()
-                self.print_line("%s send to %s" % (self.user, room_name))
-                self.server.send_message(room_name, self.user, body)
-                # self.print_line("OK.")
-                return
-
-            m = re.match(r"^backfill (\S+)$", line)
-            if m:
-                # we want to backfill a room
-                (room_name,) = m.groups()
-                self.print_line("backfill %s" % room_name)
-                self.server.backfill(room_name)
-                return
-
-            self.print_line("Unrecognized command")
-
-        except Exception as e:
-            logger.exception(e)
-
-    def print_line(self, text):
-        self.screen.print_line(text)
-
-    def print_log(self, text):
-        self.screen.print_log(text)
-
-
-class IOLoggerHandler(logging.Handler):
-    def __init__(self, io):
-        logging.Handler.__init__(self)
-        self.io = io
-
-    def emit(self, record):
-        if record.levelno < logging.WARN:
-            return
-
-        msg = self.format(record)
-        self.io.print_log(msg)
-
-
-class Room:
-    """Used to store (in memory) the current membership state of a room, and
-    which home servers we should send PDUs associated with the room to.
- """ - - def __init__(self, room_name): - self.room_name = room_name - self.invited = set() - self.participants = set() - self.servers = set() - - self.oldest_server = None - - self.have_got_metadata = False - - def add_participant(self, participant): - """Someone has joined the room""" - self.participants.add(participant) - self.invited.discard(participant) - - server = origin_from_ucid(participant) - self.servers.add(server) - - if not self.oldest_server: - self.oldest_server = server - - def add_invited(self, invitee): - """Someone has been invited to the room""" - self.invited.add(invitee) - self.servers.add(origin_from_ucid(invitee)) - - -class HomeServer(ReplicationHandler): - """A very basic home server implentation that allows people to join a - room and then invite other people. - """ - - def __init__(self, server_name, replication_layer, output): - self.server_name = server_name - self.replication_layer = replication_layer - self.replication_layer.set_handler(self) - - self.joined_rooms = {} - - self.output = output - - def on_receive_pdu(self, pdu): - """We just received a PDU""" - pdu_type = pdu.pdu_type - - if pdu_type == "sy.room.message": - self._on_message(pdu) - elif pdu_type == "sy.room.member" and "membership" in pdu.content: - if pdu.content["membership"] == "join": - self._on_join(pdu.context, pdu.state_key) - elif pdu.content["membership"] == "invite": - self._on_invite(pdu.origin, pdu.context, pdu.state_key) - else: - self.output.print_line( - "#%s (unrec) %s = %s" - % (pdu.context, pdu.pdu_type, json.dumps(pdu.content)) - ) - - def _on_message(self, pdu): - """We received a message""" - self.output.print_line( - "#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"]) - ) - - def _on_join(self, context, joinee): - """Someone has joined a room, either a remote user or a local user""" - room = self._get_or_create_room(context) - room.add_participant(joinee) - - self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED")) - - def _on_invite(self, origin, context, invitee): - """Someone has been invited""" - room = self._get_or_create_room(context) - room.add_invited(invitee) - - self.output.print_line("#%s %s %s" % (context, invitee, "*** INVITED")) - - if not room.have_got_metadata and origin is not self.server_name: - logger.debug("Get room state") - self.replication_layer.get_state_for_context(origin, context) - room.have_got_metadata = True - - @defer.inlineCallbacks - def send_message(self, room_name, sender, body): - """Send a message to a room!""" - destinations = yield self.get_servers_for_context(room_name) - - try: - yield self.replication_layer.send_pdu( - Pdu.create_new( - context=room_name, - pdu_type="sy.room.message", - content={"sender": sender, "body": body}, - origin=self.server_name, - destinations=destinations, - ) - ) - except Exception as e: - logger.exception(e) - - @defer.inlineCallbacks - def join_room(self, room_name, sender, joinee): - """Join a room!""" - self._on_join(room_name, joinee) - - destinations = yield self.get_servers_for_context(room_name) - - try: - pdu = Pdu.create_new( - context=room_name, - pdu_type="sy.room.member", - is_state=True, - state_key=joinee, - content={"membership": "join"}, - origin=self.server_name, - destinations=destinations, - ) - yield self.replication_layer.send_pdu(pdu) - except Exception as e: - logger.exception(e) - - @defer.inlineCallbacks - def invite_to_room(self, room_name, sender, invitee): - """Invite someone to a room!""" - self._on_invite(self.server_name, room_name, 
invitee) - - destinations = yield self.get_servers_for_context(room_name) - - try: - yield self.replication_layer.send_pdu( - Pdu.create_new( - context=room_name, - is_state=True, - pdu_type="sy.room.member", - state_key=invitee, - content={"membership": "invite"}, - origin=self.server_name, - destinations=destinations, - ) - ) - except Exception as e: - logger.exception(e) - - def backfill(self, room_name, limit=5): - room = self.joined_rooms.get(room_name) - - if not room: - return - - dest = room.oldest_server - - return self.replication_layer.backfill(dest, room_name, limit) - - def _get_room_remote_servers(self, room_name): - return list(self.joined_rooms.setdefault(room_name).servers) - - def _get_or_create_room(self, room_name): - return self.joined_rooms.setdefault(room_name, Room(room_name)) - - def get_servers_for_context(self, context): - return defer.succeed( - self.joined_rooms.setdefault(context, Room(context)).servers - ) - - -def main(stdscr): - parser = argparse.ArgumentParser() - parser.add_argument("user", type=str) - parser.add_argument("-v", "--verbose", action="count") - args = parser.parse_args() - - user = args.user - server_name = origin_from_ucid(user) - - # Set up logging - - root_logger = logging.getLogger() - - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" - ) - if not os.path.exists("logs"): - os.makedirs("logs") - fh = logging.FileHandler("logs/%s" % user) - fh.setFormatter(formatter) - - root_logger.addHandler(fh) - root_logger.setLevel(logging.DEBUG) - - # Hack: The only way to get it to stop logging to sys.stderr :( - log.theLogPublisher.observers = [] - observer = log.PythonLoggingObserver() - observer.start() - - # Set up synapse server - - curses_stdio = cursesio.CursesStdIO(stdscr) - input_output = InputOutput(curses_stdio, user) - - curses_stdio.set_callback(input_output) - - app_hs = SynapseHomeServer(server_name, db_name="dbs/%s" % user) - replication = app_hs.get_replication_layer() - - hs = HomeServer(server_name, replication, curses_stdio) - - input_output.set_home_server(hs) - - # Add input_output logger - io_logger = IOLoggerHandler(input_output) - io_logger.setFormatter(formatter) - root_logger.addHandler(io_logger) - - # Start! - - try: - port = int(server_name.split(":")[1]) - except Exception: - port = 12345 - - app_hs.get_http_server().start_listening(port) - - reactor.addReader(curses_stdio) - - reactor.run() - - -if __name__ == "__main__": - curses.wrapper(main) From 119938792bdeed01f94f8173849333c9ef499df6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 May 2022 10:47:54 +0100 Subject: [PATCH 39/74] Remove unused `contrib/experiments/cursesio.py` (#12910) --- changelog.d/12910.removal | 1 + contrib/experiments/cursesio.py | 165 -------------------------------- 2 files changed, 1 insertion(+), 165 deletions(-) create mode 100644 changelog.d/12910.removal delete mode 100644 contrib/experiments/cursesio.py diff --git a/changelog.d/12910.removal b/changelog.d/12910.removal new file mode 100644 index 0000000000..4bd4f877f6 --- /dev/null +++ b/changelog.d/12910.removal @@ -0,0 +1 @@ +Remove unused `contrib/experiements/cursesio.py` script, which fails to run under Python 3. 
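Both removed scripts begin with `import curses.wrapper`, which only works with the Python 2 layout of the standard library; on Python 3, `curses.wrapper` is a plain function inside the `curses` package, so the import itself fails before anything else runs. A minimal sketch of the failure mode (assuming a CPython build with curses support):

    import curses

    print(callable(curses.wrapper))  # True on Python 3: wrapper is a function

    try:
        import curses.wrapper  # Python 2-era submodule, gone in Python 3
    except ImportError as exc:
        print("fails on Python 3:", exc)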
diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py deleted file mode 100644 index 7695cc77ca..0000000000 --- a/contrib/experiments/cursesio.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import curses -import curses.wrapper -from curses.ascii import isprint - -from twisted.internet import reactor - - -class CursesStdIO: - def __init__(self, stdscr, callback=None): - self.statusText = "Synapse test app -" - self.searchText = "" - self.stdscr = stdscr - - self.logLine = "" - - self.callback = callback - - self._setup() - - def _setup(self): - self.stdscr.nodelay(1) # Make non blocking - - self.rows, self.cols = self.stdscr.getmaxyx() - self.lines = [] - - curses.use_default_colors() - - self.paintStatus(self.statusText) - self.stdscr.refresh() - - def set_callback(self, callback): - self.callback = callback - - def fileno(self): - """We want to select on FD 0""" - return 0 - - def connectionLost(self, reason): - self.close() - - def print_line(self, text): - """add a line to the internal list of lines""" - - self.lines.append(text) - self.redraw() - - def print_log(self, text): - self.logLine = text - self.redraw() - - def redraw(self): - """method for redisplaying lines based on internal list of lines""" - - self.stdscr.clear() - self.paintStatus(self.statusText) - i = 0 - index = len(self.lines) - 1 - while i < (self.rows - 3) and index >= 0: - self.stdscr.addstr(self.rows - 3 - i, 0, self.lines[index], curses.A_NORMAL) - i = i + 1 - index = index - 1 - - self.printLogLine(self.logLine) - - self.stdscr.refresh() - - def paintStatus(self, text): - if len(text) > self.cols: - raise RuntimeError("TextTooLongError") - - self.stdscr.addstr( - self.rows - 2, 0, text + " " * (self.cols - len(text)), curses.A_STANDOUT - ) - - def printLogLine(self, text): - self.stdscr.addstr( - 0, 0, text + " " * (self.cols - len(text)), curses.A_STANDOUT - ) - - def doRead(self): - """Input is ready!""" - curses.noecho() - c = self.stdscr.getch() # read a character - - if c == curses.KEY_BACKSPACE: - self.searchText = self.searchText[:-1] - - elif c == curses.KEY_ENTER or c == 10: - text = self.searchText - self.searchText = "" - - self.print_line(">> %s" % text) - - try: - if self.callback: - self.callback.on_line(text) - except Exception as e: - self.print_line(str(e)) - - self.stdscr.refresh() - - elif isprint(c): - if len(self.searchText) == self.cols - 2: - return - self.searchText = self.searchText + chr(c) - - self.stdscr.addstr( - self.rows - 1, - 0, - self.searchText + (" " * (self.cols - len(self.searchText) - 2)), - ) - - self.paintStatus(self.statusText + " %d" % len(self.searchText)) - self.stdscr.move(self.rows - 1, len(self.searchText)) - self.stdscr.refresh() - - def logPrefix(self): - return "CursesStdIO" - - def close(self): - """clean up""" - - curses.nocbreak() - self.stdscr.keypad(0) - curses.echo() - curses.endwin() - - -class Callback: - def __init__(self, stdio): - self.stdio = stdio - - 
def on_line(self, text): - self.stdio.print_line(text) - - -def main(stdscr): - screen = CursesStdIO(stdscr) # create Screen object - - callback = Callback(screen) - - screen.set_callback(callback) - - stdscr.refresh() - reactor.addReader(screen) - reactor.run() - screen.close() - - -if __name__ == "__main__": - curses.wrapper(main) From 248046187940372c23466cb395b46ed97ebda1ed Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 30 May 2022 10:51:09 +0100 Subject: [PATCH 40/74] Fix `get_metadata_for_events` (#12904) This method was introduced in #12852. It is using the `state_key` column from the `events` table, which is not (yet) reliable (see #11496). --- changelog.d/12904.misc | 1 + synapse/storage/databases/main/state.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12904.misc diff --git a/changelog.d/12904.misc b/changelog.d/12904.misc new file mode 100644 index 0000000000..afca32471f --- /dev/null +++ b/changelog.d/12904.misc @@ -0,0 +1 @@ +Pull out less state when handling gaps in room DAG. diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index ea5cbdac08..a07ad85582 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -167,8 +167,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) sql = f""" - SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e - LEFT JOIN state_events USING (event_id) + SELECT e.event_id, e.room_id, e.type, se.state_key FROM events AS e + LEFT JOIN state_events se USING (event_id) WHERE {clause} """ From b10211871fd19013631cf5d798a90f74a86c6c56 Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Mon, 30 May 2022 11:14:43 +0000 Subject: [PATCH 41/74] Fix invite notifications for users without pushers (#12840) Signed-off-by: Nicolas Werner Co-authored-by: Brendan Abolivier --- changelog.d/12840.bugfix | 1 + synapse/push/bulk_push_rule_evaluator.py | 10 ++- synapse/storage/databases/main/pusher.py | 6 -- tests/rest/client/test_notifications.py | 91 ++++++++++++++++++++++++ 4 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 changelog.d/12840.bugfix create mode 100644 tests/rest/client/test_notifications.py diff --git a/changelog.d/12840.bugfix b/changelog.d/12840.bugfix new file mode 100644 index 0000000000..b15cedf896 --- /dev/null +++ b/changelog.d/12840.bugfix @@ -0,0 +1 @@ +Fix an issue introduced in Synapse 0.34 where the `/notifications` endpoint would only return notifications if a user registered at least one pusher. Contributed by Famedly. 
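The fix below drops the `user_has_pusher` gate when collecting the invited user's push rules, so `GET /notifications` reports invite notifications even for accounts that never registered a pusher. A quick manual check against a local homeserver, as a sketch (the homeserver address and access token are hypothetical placeholders, and the stable v3 client-server prefix is assumed):

    import json
    import urllib.request

    HOMESERVER = "http://localhost:8008"       # hypothetical local homeserver
    ACCESS_TOKEN = "hypothetical_access_token"  # replace with a real token

    req = urllib.request.Request(
        HOMESERVER + "/_matrix/client/v3/notifications",
        headers={"Authorization": "Bearer " + ACCESS_TOKEN},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)

    # Invite notifications carry an m.room.member event with membership "invite".
    print([n["event"]["content"].get("membership") for n in body["notifications"]])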
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 1a8e7ef3dc..7791b289e2 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -153,12 +153,10 @@ class BulkPushRuleEvaluator: if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): - has_pusher = await self.store.user_has_pusher(invited) - if has_pusher: - rules_by_user = dict(rules_by_user) - rules_by_user[invited] = await self.store.get_push_rules_for_user( - invited - ) + rules_by_user = dict(rules_by_user) + rules_by_user[invited] = await self.store.get_push_rules_for_user( + invited + ) return rules_by_user diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 91286c9b65..bd0cfa7f32 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -91,12 +91,6 @@ class PusherWorkerStore(SQLBaseStore): yield PusherConfig(**r) - async def user_has_pusher(self, user_id: str) -> bool: - ret = await self.db_pool.simple_select_one_onecol( - "pushers", {"user_name": user_id}, "id", allow_none=True - ) - return ret is not None - async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str ) -> Iterator[PusherConfig]: diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py new file mode 100644 index 0000000000..700f6587a0 --- /dev/null +++ b/tests/rest/client/test_notifications.py @@ -0,0 +1,91 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, notifications, receipts, room +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase + + +class HTTPPusherTests(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + receipts.register_servlets, + notifications.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.store = homeserver.get_datastores().main + self.module_api = homeserver.get_module_api() + self.event_creation_handler = homeserver.get_event_creation_handler() + self.sync_handler = homeserver.get_sync_handler() + self.auth_handler = homeserver.get_auth_handler() + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # Mock out the calls over federation. 
+ fed_transport_client = Mock(spec=["send_transaction"]) + fed_transport_client.send_transaction = simple_async_mock({}) + + return self.setup_test_homeserver( + federation_transport_client=fed_transport_client, + ) + + def test_notify_for_local_invites(self) -> None: + """ + Local users will get notified for invites + """ + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Check we start with no pushes + channel = self.make_request( + "GET", + "/notifications", + access_token=other_access_token, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body) + + # Send an invite + self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token) + + # We should have a notification now + channel = self.make_request( + "GET", + "/notifications", + access_token=other_access_token, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + self.assertEqual( + channel.json_body["notifications"][0]["event"]["content"]["membership"], + "invite", + channel.json_body, + ) From 7f92ac4c1cbd7fae815c02b4920eb02ddf9458cb Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Mon, 30 May 2022 16:51:37 +0200 Subject: [PATCH 42/74] Add a migration step to cleanup potential leftovers of bug 11833 (#12784) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/12784.bugfix | 1 + .../delta/70/01clean_table_purged_rooms.sql | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) create mode 100644 changelog.d/12784.bugfix create mode 100644 synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql diff --git a/changelog.d/12784.bugfix b/changelog.d/12784.bugfix new file mode 100644 index 0000000000..a958f9a16b --- /dev/null +++ b/changelog.d/12784.bugfix @@ -0,0 +1 @@ +Delete events from the `federation_inbound_events_staging` table when a room is purged through the admin API. diff --git a/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql new file mode 100644 index 0000000000..aed79635b2 --- /dev/null +++ b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql @@ -0,0 +1,19 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Clean up left over rows from bug #11833, which was fixed in #12770. 
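+-- (Illustrative only, not the statement this migration runs: the same cleanup
+-- could be expressed with a correlated NOT EXISTS, which both supported
+-- database engines accept:
+--   DELETE FROM federation_inbound_events_staging
+--   WHERE NOT EXISTS (
+--       SELECT 1 FROM rooms
+--       WHERE rooms.room_id = federation_inbound_events_staging.room_id
+--   );
+-- i.e. drop staged inbound events whose room no longer exists.)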
+DELETE FROM federation_inbound_events_staging WHERE room_id not in ( + SELECT room_id FROM rooms +); From 1fd1856afc4155ee47b3be9829b02d740f9a64d6 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Mon, 30 May 2022 17:41:24 +0200 Subject: [PATCH 43/74] demo: check if we are in a virtualenv before overriding PYTHONPATH (#12916) --- changelog.d/12916.misc | 1 + demo/start.sh | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 changelog.d/12916.misc diff --git a/changelog.d/12916.misc b/changelog.d/12916.misc new file mode 100644 index 0000000000..347eb096db --- /dev/null +++ b/changelog.d/12916.misc @@ -0,0 +1 @@ +Check if we are in a virtual environment before overriding the `PYTHONPATH` environment variable in the demo script. diff --git a/demo/start.sh b/demo/start.sh index 96b3a2ceab..fdd75816fb 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -6,11 +6,12 @@ CWD=$(pwd) cd "$DIR/.." || exit -PYTHONPATH=$(readlink -f "$(pwd)") -export PYTHONPATH - - -echo "$PYTHONPATH" +# Do not override PYTHONPATH if we are in a virtual env +if [ "$VIRTUAL_ENV" = "" ]; then + PYTHONPATH=$(readlink -f "$(pwd)") + export PYTHONPATH + echo "$PYTHONPATH" +fi # Create servers which listen on HTTP at 808x and HTTPS at 848x. for port in 8080 8081 8082; do From af7db19e1e89e9b4ac4818c47b7f389ad46a7c9b Mon Sep 17 00:00:00 2001 From: David Teller Date: Mon, 30 May 2022 18:24:56 +0200 Subject: [PATCH 44/74] Uniformize spam-checker API, part 3: Expand check_event_for_spam with the ability to return additional fields (#12846) Signed-off-by: David Teller --- changelog.d/12808.feature | 1 + changelog.d/12846.misc | 1 + synapse/api/errors.py | 23 +++++++++++++---------- synapse/events/spamcheck.py | 20 +++++++++++++------- synapse/handlers/message.py | 15 +++++++++++++++ 5 files changed, 43 insertions(+), 17 deletions(-) create mode 100644 changelog.d/12808.feature create mode 100644 changelog.d/12846.misc diff --git a/changelog.d/12808.feature b/changelog.d/12808.feature new file mode 100644 index 0000000000..561c8b9d34 --- /dev/null +++ b/changelog.d/12808.feature @@ -0,0 +1 @@ +Update to `check_event_for_spam`. Deprecate the current callback signature, replace it with a new signature that is both less ambiguous (replacing booleans with explicit allow/block) and more powerful (ability to return explicit error codes). \ No newline at end of file diff --git a/changelog.d/12846.misc b/changelog.d/12846.misc new file mode 100644 index 0000000000..f72d3d2bea --- /dev/null +++ b/changelog.d/12846.misc @@ -0,0 +1 @@ +Experimental: expand `check_event_for_spam` with ability to return additional fields. This enables spam-checker implementations to experiment with mechanisms to give users more information about why they are blocked and whether any action is needed from them to be unblocked. \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 05e96843cf..54268e0889 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -146,7 +146,13 @@ class SynapseError(CodeMessageException): errcode: Matrix error code e.g 'M_FORBIDDEN' """ - def __init__(self, code: int, msg: str, errcode: str = Codes.UNKNOWN): + def __init__( + self, + code: int, + msg: str, + errcode: str = Codes.UNKNOWN, + additional_fields: Optional[Dict] = None, + ): """Constructs a synapse error. 
Args: @@ -156,9 +162,13 @@ class SynapseError(CodeMessageException): """ super().__init__(code, msg) self.errcode = errcode + if additional_fields is None: + self._additional_fields: Dict = {} + else: + self._additional_fields = dict(additional_fields) def error_dict(self) -> "JsonDict": - return cs_error(self.msg, self.errcode) + return cs_error(self.msg, self.errcode, **self._additional_fields) class InvalidAPICallError(SynapseError): @@ -183,14 +193,7 @@ class ProxiedRequestError(SynapseError): errcode: str = Codes.UNKNOWN, additional_fields: Optional[Dict] = None, ): - super().__init__(code, msg, errcode) - if additional_fields is None: - self._additional_fields: Dict = {} - else: - self._additional_fields = dict(additional_fields) - - def error_dict(self) -> "JsonDict": - return cs_error(self.msg, self.errcode, **self._additional_fields) + super().__init__(code, msg, errcode, additional_fields) class ConsentNotGivenError(SynapseError): diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 7984874e21..82998ca490 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -21,6 +21,7 @@ from typing import ( Awaitable, Callable, Collection, + Dict, List, Optional, Tuple, @@ -41,13 +42,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ["synapse.events.EventBase"], Awaitable[ Union[ Allow, Codes, + # Highly experimental, not officially part of the spamchecker API, may + # disappear without warning depending on the results of ongoing + # experiments. + # Use this to return additional information as part of an error. + Tuple[Codes, Dict], # Deprecated bool, # Deprecated @@ -270,7 +275,7 @@ class SpamChecker: async def check_event_for_spam( self, event: "synapse.events.EventBase" - ) -> Union[Decision, str]: + ) -> Union[Decision, Tuple[Codes, Dict], str]: """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if @@ -293,9 +298,9 @@ class SpamChecker: with Measure( self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) ): - res: Union[Decision, str, bool] = await delay_cancellation( - callback(event) - ) + res: Union[ + Decision, Tuple[Codes, Dict], str, bool + ] = await delay_cancellation(callback(event)) if res is False or res is Allow.ALLOW: # This spam-checker accepts the event. # Other spam-checkers may reject it, though. @@ -305,8 +310,9 @@ class SpamChecker: # return value `True` return Codes.FORBIDDEN else: - # This spam-checker rejects the event either with a `str` - # or with a `Codes`. In either case, we stop here. + # This spam-checker rejects the event either with a `str`, + # with a `Codes` or with a `Tuple[Codes, Dict]`. In either + # case, we stop here. return res # No spam-checker has rejected the event, let it pass. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7ca126dbd1..38b71a2c96 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -895,6 +895,21 @@ class EventCreationHandler: spam_check = await self.spam_checker.check_event_for_spam(event) if spam_check is not synapse.spam_checker_api.Allow.ALLOW: + if isinstance(spam_check, tuple): + try: + [code, dict] = spam_check + raise SynapseError( + 403, + "This message had been rejected as probable spam", + code, + dict, + ) + except ValueError: + logger.error( + "Spam-check module returned invalid error value. 
Expecting [code, dict], got %s", + spam_check, + ) + spam_check = Codes.FORBIDDEN raise SynapseError( 403, "This message had been rejected as probable spam", spam_check ) From cd9fc058dea3a5f5ef282ebdcb48aa6caf1eb722 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 30 May 2022 18:37:52 +0200 Subject: [PATCH 45/74] Document the Synapse version of a new module API method (#12917) --- changelog.d/12917.feature | 1 + synapse/module_api/__init__.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12917.feature diff --git a/changelog.d/12917.feature b/changelog.d/12917.feature new file mode 100644 index 0000000000..b24489aaad --- /dev/null +++ b/changelog.d/12917.feature @@ -0,0 +1 @@ +Add storage and module API methods to get monthly active users (and their corresponding appservices) within an optionally specified time range. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index edcf59aa0b..6668f64c90 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1149,7 +1149,10 @@ class ModuleApi: ) async def sleep(self, seconds: float) -> None: - """Sleeps for the given number of seconds.""" + """Sleeps for the given number of seconds. + + Added in Synapse v1.49.0. + """ await self._clock.sleep(seconds) @@ -1435,6 +1438,8 @@ class ModuleApi: """Generates list of monthly active users and their services. Please see corresponding storage docstring for more details. + Added in Synapse v1.61.0. + Arguments: start_timestamp: If specified, only include users that were first active at or after this point From c4f548e05d9a1858787d3a0883a5393d315473d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacek=20Ku=C5=9Bnierz?= Date: Mon, 30 May 2022 22:03:52 +0200 Subject: [PATCH 46/74] Don't return `end` from `/messages` if there are no more events (#12903) Signed-off-by: Jacek Kusnierz --- changelog.d/12903.bugfix | 1 + synapse/handlers/pagination.py | 23 +++++++++++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12903.bugfix diff --git a/changelog.d/12903.bugfix b/changelog.d/12903.bugfix new file mode 100644 index 0000000000..f264399483 --- /dev/null +++ b/changelog.d/12903.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which caused the `/messages` endpoint to return an incorrect `end` attribute when there were no more events. Contributed by @Vetchu. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 6f4820c240..35afe6b855 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -515,14 +515,25 @@ class PaginationHandler: next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key) - if events: - if event_filter: - events = await event_filter.filter(events) + # if no events are returned from pagination, that implies + # we have reached the end of the available events. + # In that case we do not return end, to tell the client + # there is no need for further queries. 
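+        # For example, a client paginating /messages can keep passing the
+        # previous response's "end" value as the next "from" parameter, and
+        # stop cleanly once a response arrives without an "end" field.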
+ if not events: + return { + "chunk": [], + "start": await from_token.to_string(self.store), + } - events = await filter_events_for_client( - self.storage, user_id, events, is_peeking=(member_event_id is None) - ) + if event_filter: + events = await event_filter.filter(events) + events = await filter_events_for_client( + self.storage, user_id, events, is_peeking=(member_event_id is None) + ) + + # if after the filter applied there are no more events + # return immediately - but there might be more in next_token batch if not events: return { "chunk": [], From e0fae823e9938618a260adadb82bfee6e4c2f907 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 30 May 2022 20:27:19 -0600 Subject: [PATCH 47/74] Fix M_USER_ACCOUNT_SUSPENDED error code for spec compliance (#12922) `M_` is a reserved namespace. --- changelog.d/12845.feature | 2 +- changelog.d/12922.feature | 1 + synapse/api/errors.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12922.feature diff --git a/changelog.d/12845.feature b/changelog.d/12845.feature index 628fb16d08..815a1f10ea 100644 --- a/changelog.d/12845.feature +++ b/changelog.d/12845.feature @@ -1 +1 @@ -Support the new error code "M_ORG_MATRIX_MSC3823_USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file +Support the new error code "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file diff --git a/changelog.d/12922.feature b/changelog.d/12922.feature new file mode 100644 index 0000000000..815a1f10ea --- /dev/null +++ b/changelog.d/12922.feature @@ -0,0 +1 @@ +Support the new error code "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 54268e0889..cc7b785472 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -84,7 +84,7 @@ class Codes(str, Enum): # By opposition to `USER_DEACTIVATED`, this is a reversible measure # that can possibly be appealed and reverted. # Part of MSC3823. - USER_ACCOUNT_SUSPENDED = "M_ORG_MATRIX_MSC3823_USER_ACCOUNT_SUSPENDED" + USER_ACCOUNT_SUSPENDED = "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" BAD_ALIAS = "M_BAD_ALIAS" # For restricted join rules. From bcfdfeb65df820ba460de5fb96a554f694de361b Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 30 May 2022 20:27:19 -0600 Subject: [PATCH 48/74] Revert "Fix M_USER_ACCOUNT_SUSPENDED error code for spec compliance (#12922)" This reverts commit e0fae823e9938618a260adadb82bfee6e4c2f907. --- changelog.d/12845.feature | 2 +- changelog.d/12922.feature | 1 - synapse/api/errors.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) delete mode 100644 changelog.d/12922.feature diff --git a/changelog.d/12845.feature b/changelog.d/12845.feature index 815a1f10ea..628fb16d08 100644 --- a/changelog.d/12845.feature +++ b/changelog.d/12845.feature @@ -1 +1 @@ -Support the new error code "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file +Support the new error code "M_ORG_MATRIX_MSC3823_USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). 
\ No newline at end of file diff --git a/changelog.d/12922.feature b/changelog.d/12922.feature deleted file mode 100644 index 815a1f10ea..0000000000 --- a/changelog.d/12922.feature +++ /dev/null @@ -1 +0,0 @@ -Support the new error code "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index cc7b785472..54268e0889 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -84,7 +84,7 @@ class Codes(str, Enum): # By opposition to `USER_DEACTIVATED`, this is a reversible measure # that can possibly be appealed and reverted. # Part of MSC3823. - USER_ACCOUNT_SUSPENDED = "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + USER_ACCOUNT_SUSPENDED = "M_ORG_MATRIX_MSC3823_USER_ACCOUNT_SUSPENDED" BAD_ALIAS = "M_BAD_ALIAS" # For restricted join rules. From d0e40dfe29fdf068972fc9a63f50fc94daaa06b3 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 31 May 2022 01:42:18 -0600 Subject: [PATCH 49/74] Fix M_USER_ACCOUNT_SUSPENDED error code for spec compliance (#12923) --- changelog.d/12845.feature | 2 +- changelog.d/12923.feature | 1 + synapse/api/errors.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12923.feature diff --git a/changelog.d/12845.feature b/changelog.d/12845.feature index 628fb16d08..815a1f10ea 100644 --- a/changelog.d/12845.feature +++ b/changelog.d/12845.feature @@ -1 +1 @@ -Support the new error code "M_ORG_MATRIX_MSC3823_USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file +Support the new error code "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file diff --git a/changelog.d/12923.feature b/changelog.d/12923.feature new file mode 100644 index 0000000000..815a1f10ea --- /dev/null +++ b/changelog.d/12923.feature @@ -0,0 +1 @@ +Support the new error code "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" from [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823). \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 54268e0889..cc7b785472 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -84,7 +84,7 @@ class Codes(str, Enum): # By opposition to `USER_DEACTIVATED`, this is a reversible measure # that can possibly be appealed and reverted. # Part of MSC3823. - USER_ACCOUNT_SUSPENDED = "M_ORG_MATRIX_MSC3823_USER_ACCOUNT_SUSPENDED" + USER_ACCOUNT_SUSPENDED = "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" BAD_ALIAS = "M_BAD_ALIAS" # For restricted join rules. From e541bb9eed964e6840ddf2cd859af3f94150dc85 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 31 May 2022 07:42:50 -0400 Subject: [PATCH 50/74] Rework stream token to stop caring about groups. (#12897) --- changelog.d/12897.removal | 1 + synapse/streams/events.py | 4 ++-- synapse/types.py | 6 +++++- 3 files changed, 8 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12897.removal diff --git a/changelog.d/12897.removal b/changelog.d/12897.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12897.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. 
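The change below keeps slot 9 of the serialized `StreamToken` as a placeholder rather than removing it, so tokens minted before and after this release stay mutually parseable. A rough sketch of the underscore-joined wire format (field values are illustrative, taken from the docstring example):

    # Nine underscore-separated fields; the trailing groups_key slot is still
    # written, but its value is now meaningless and ignored on read.
    token = "s2633508_17_338_6732159_1082514_541479_274711_265584_1"
    fields = token.split("_")
    assert len(fields) == 9
    room_key, groups_key = fields[0], fields[-1]  # groups_key: bogus, unused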
diff --git a/synapse/streams/events.py b/synapse/streams/events.py index acf17ba623..54e0b1a23b 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -54,7 +54,6 @@ class EventSources: push_rules_key = self.store.get_max_push_rules_stream_id() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() - groups_key = self.store.get_group_stream_token() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -65,7 +64,8 @@ class EventSources: push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, - groups_key=groups_key, + # Groups key is unused. + groups_key=0, ) return token diff --git a/synapse/types.py b/synapse/types.py index 091cc611ab..0586d2cbb9 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -639,7 +639,7 @@ class StreamToken: 6. `push_rules_key`: `541479` 7. `to_device_key`: `274711` 8. `device_list_key`: `265584` - 9. `groups_key`: `1` + 9. `groups_key`: `1` (note that this key is now unused) You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -691,6 +691,7 @@ class StreamToken: push_rules_key: int to_device_key: int device_list_key: int + # Note that the groups key is no longer used and may have bogus values. groups_key: int _SEPARATOR = "_" @@ -722,6 +723,9 @@ class StreamToken: str(self.push_rules_key), str(self.to_device_key), str(self.device_list_key), + # Note that the groups key is no longer used, but it is still + # serialized so that there will not be confusion in the future + # if additional tokens are added. str(self.groups_key), ] ) From 1e453053cb12ff084fdcdc2f75c08ced274dff21 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 May 2022 13:17:50 +0100 Subject: [PATCH 51/74] Rename storage classes (#12913) --- changelog.d/12913.misc | 1 + synapse/events/snapshot.py | 10 +- synapse/federation/federation_server.py | 1 - synapse/handlers/admin.py | 12 +- synapse/handlers/device.py | 4 +- synapse/handlers/events.py | 4 +- synapse/handlers/federation.py | 30 +- synapse/handlers/federation_event.py | 27 +- synapse/handlers/initial_sync.py | 17 +- synapse/handlers/message.py | 30 +- synapse/handlers/pagination.py | 17 +- synapse/handlers/relations.py | 7 +- synapse/handlers/room.py | 11 +- synapse/handlers/room_batch.py | 4 +- synapse/handlers/search.py | 14 +- synapse/handlers/sync.py | 26 +- synapse/notifier.py | 4 +- synapse/push/httppusher.py | 6 +- synapse/push/mailer.py | 14 +- synapse/push/push_tools.py | 4 +- synapse/replication/http/federation.py | 4 +- synapse/replication/http/send_event.py | 6 +- synapse/server.py | 7 +- synapse/state/__init__.py | 51 ++- synapse/storage/__init__.py | 35 +- synapse/storage/controllers/__init__.py | 46 +++ .../{ => controllers}/persist_events.py | 2 +- .../storage/{ => controllers}/purge_events.py | 2 +- synapse/storage/controllers/state.py | 351 ++++++++++++++++++ synapse/storage/state.py | 320 ---------------- synapse/visibility.py | 10 +- tests/events/test_snapshot.py | 4 +- tests/handlers/test_federation.py | 6 +- tests/handlers/test_federation_event.py | 9 +- tests/handlers/test_message.py | 14 +- tests/handlers/test_user_directory.py | 2 +- tests/replication/slave/storage/_base.py | 2 +- .../replication/slave/storage/test_events.py | 10 +- .../slave/storage/test_receipts.py | 12 +- tests/rest/admin/test_user.py | 4 +- tests/rest/client/test_retention.py | 4 +- tests/rest/client/test_room_batch.py | 6 +- tests/storage/test_event_chain.py | 3 +- 
tests/storage/test_events.py | 12 +- tests/storage/test_purge.py | 14 +- tests/storage/test_redaction.py | 14 +- tests/storage/test_room.py | 4 +- tests/storage/test_room_search.py | 4 +- tests/storage/test_state.py | 2 +- tests/test_state.py | 6 +- tests/test_utils/event_injection.py | 2 +- tests/test_visibility.py | 46 ++- tests/utils.py | 2 +- 53 files changed, 708 insertions(+), 551 deletions(-) create mode 100644 changelog.d/12913.misc create mode 100644 synapse/storage/controllers/__init__.py rename synapse/storage/{ => controllers}/persist_events.py (99%) rename synapse/storage/{ => controllers}/purge_events.py (99%) create mode 100644 synapse/storage/controllers/state.py diff --git a/changelog.d/12913.misc b/changelog.d/12913.misc new file mode 100644 index 0000000000..a2bc940557 --- /dev/null +++ b/changelog.d/12913.misc @@ -0,0 +1 @@ +Rename storage classes. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 7a91544119..b700cbbfa1 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -22,7 +22,7 @@ from synapse.events import EventBase from synapse.types import JsonDict, StateMap if TYPE_CHECKING: - from synapse.storage import Storage + from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore from synapse.storage.state import StateFilter @@ -84,7 +84,7 @@ class EventContext: incomplete state. """ - _storage: "Storage" + _storage: "StorageControllers" rejected: Union[Literal[False], str] = False _state_group: Optional[int] = None state_group_before_event: Optional[int] = None @@ -97,7 +97,7 @@ class EventContext: @staticmethod def with_state( - storage: "Storage", + storage: "StorageControllers", state_group: Optional[int], state_group_before_event: Optional[int], state_delta_due_to_event: Optional[StateMap[str]], @@ -117,7 +117,7 @@ class EventContext: @staticmethod def for_outlier( - storage: "Storage", + storage: "StorageControllers", ) -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" return EventContext(storage=storage) @@ -147,7 +147,7 @@ class EventContext: } @staticmethod - def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": + def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext": """Converts a dict that was produced by `serialize` back into a EventContext. 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5b227b85fd..3ecede22d9 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -109,7 +109,6 @@ class FederationServer(FederationBase): super().__init__(hs) self.handler = hs.get_federation_handler() - self.storage = hs.get_storage() self._spam_checker = hs.get_spam_checker() self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 50e34743b7..d4fe7df533 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,8 +30,8 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state async def get_whois(self, user: UserID) -> JsonDict: connections = [] @@ -197,7 +197,9 @@ class AdminHandler: from_key = events[-1].internal_metadata.after - events = await filter_events_for_client(self.storage, user_id, events) + events = await filter_events_for_client( + self._storage_controllers, user_id, events + ) writer.write_events(room_id, events) @@ -233,7 +235,9 @@ class AdminHandler: for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = await self.state_storage.get_state_for_event(event_id) + state = await self._state_storage_controller.get_state_for_event( + event_id + ) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2a56473dc6..72faf2ee38 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -71,7 +71,7 @@ class DeviceWorkerHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() - self.state_storage = hs.get_storage().state + self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() self.server_name = hs.hostname @@ -204,7 +204,7 @@ class DeviceWorkerHandler: continue # mapping from event_id -> state_dict - prev_state_ids = await self.state_storage.get_state_ids_for_events( + prev_state_ids = await self._state_storage.get_state_ids_for_events( event_ids ) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index cb7e0ca7a8..ac13340d3a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -139,7 +139,7 @@ class EventStreamHandler: class EventHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() async def get_event( self, @@ -177,7 +177,7 @@ class EventHandler: is_peeking = user.to_string() not in users filtered = await filter_events_for_client( - self.storage, user.to_string(), [event], is_peeking=is_peeking + self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c8233270d7..80ee7e7b4e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -125,8 +125,8 @@ class FederationHandler: self.hs = hs self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = 
self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -324,7 +324,7 @@ class FederationHandler: # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = await filter_events_for_server( - self.storage, + self._storage_controllers, self.server_name, events_to_check, redact=False, @@ -660,7 +660,7 @@ class FederationHandler: # in the invitee's sync stream. It is stripped out for all other local users. event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -849,7 +849,7 @@ class FederationHandler: ) ) - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -878,7 +878,7 @@ class FederationHandler: await self.federation_client.send_leave(host_list, event) - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1027,7 +1027,7 @@ class FederationHandler: if event.internal_metadata.outlier: raise NotFoundError("State not known at event %s" % (event_id,)) - state_groups = await self.state_storage.get_state_groups_ids( + state_groups = await self._state_storage_controller.get_state_groups_ids( room_id, [event_id] ) @@ -1078,7 +1078,9 @@ class FederationHandler: ], ) - events = await filter_events_for_server(self.storage, origin, events) + events = await filter_events_for_server( + self._storage_controllers, origin, events + ) return events @@ -1109,7 +1111,9 @@ class FederationHandler: if not in_room: raise AuthError(403, "Host not in room.") - events = await filter_events_for_server(self.storage, origin, [event]) + events = await filter_events_for_server( + self._storage_controllers, origin, [event] + ) event = events[0] return event else: @@ -1138,7 +1142,7 @@ class FederationHandler: ) missing_events = await filter_events_for_server( - self.storage, origin, missing_events + self._storage_controllers, origin, missing_events ) return missing_events @@ -1480,9 +1484,11 @@ class FederationHandler: # clear the lazy-loading flag. 
logger.info("Updating current state for %s", room_id) assert ( - self.storage.persistence is not None + self._storage_controllers.persistence is not None ), "TODO(faster_joins): support for workers" - await self.storage.persistence.update_current_state(room_id) + await self._storage_controllers.persistence.update_current_state( + room_id + ) logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index a1361af272..b908674529 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -98,8 +98,8 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self._store = hs.get_datastores().main - self._storage = hs.get_storage() - self._state_storage = self._storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._state_handler = hs.get_state_handler() self._event_creation_handler = hs.get_event_creation_handler() @@ -535,7 +535,9 @@ class FederationEventHandler: ) return await self._store.update_state_for_partial_state_event(event, context) - self._state_storage.notify_event_un_partial_stated(event.event_id) + self._state_storage_controller.notify_event_un_partial_stated( + event.event_id + ) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] @@ -835,7 +837,9 @@ class FederationEventHandler: try: # Get the state of the events we know about - ours = await self._state_storage.get_state_groups_ids(room_id, seen) + ours = await self._state_storage_controller.get_state_groups_ids( + room_id, seen + ) # state_maps is a list of mappings from (type, state_key) to event_id state_maps: List[StateMap[str]] = list(ours.values()) @@ -1436,7 +1440,7 @@ class FederationEventHandler: # we're not bothering about room state, so flag the event as an outlier. event.internal_metadata.outlier = True - context = EventContext.for_outlier(self._storage) + context = EventContext.for_outlier(self._storage_controllers) try: validate_event_for_room_version(room_version_obj, event) check_auth_rules_for_event(room_version_obj, event, auth) @@ -1613,7 +1617,7 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_storage.get_state_groups_ids( + state_sets_d = await self._state_storage_controller.get_state_groups_ids( event.room_id, extrem_ids ) state_sets: List[StateMap[str]] = list(state_sets_d.values()) @@ -1885,7 +1889,7 @@ class FederationEventHandler: # create a new state group as a delta from the existing one. 
prev_group = context.state_group - state_group = await self._state_storage.store_state_group( + state_group = await self._state_storage_controller.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -1894,7 +1898,7 @@ class FederationEventHandler: ) return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group, state_group_before_event=context.state_group_before_event, state_delta_due_to_event=state_updates, @@ -1984,11 +1988,14 @@ class FederationEventHandler: ) return result["max_stream_id"] else: - assert self._storage.persistence + assert self._storage_controllers.persistence # Note that this returns the events that were persisted, which may not be # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self._storage.persistence.persist_events( + ( + events, + max_stream_token, + ) = await self._storage_controllers.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index fbdbeeedfd..d2b489e816 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -67,8 +67,8 @@ class InitialSyncHandler: ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state async def snapshot_all_rooms( self, @@ -198,7 +198,8 @@ class InitialSyncHandler: event.stream_ordering, ) deferred_room_state = run_in_background( - self.state_storage.get_state_for_events, [event.event_id] + self._state_storage_controller.get_state_for_events, + [event.event_id], ).addCallback( lambda states: cast(StateMap[EventBase], states[event.event_id]) ) @@ -218,7 +219,7 @@ class InitialSyncHandler: ).addErrback(unwrapFirstError) messages = await filter_events_for_client( - self.storage, user_id, messages + self._storage_controllers, user_id, messages ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) @@ -355,7 +356,9 @@ class InitialSyncHandler: member_event_id: str, is_peeking: bool, ) -> JsonDict: - room_state = await self.state_storage.get_state_for_event(member_event_id) + room_state = await self._state_storage_controller.get_state_for_event( + member_event_id + ) limit = pagin_config.limit if pagin_config else None if limit is None: @@ -369,7 +372,7 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self.storage, user_id, messages, is_peeking=is_peeking + self._storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) @@ -474,7 +477,7 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self.storage, user_id, messages, is_peeking=is_peeking + self._storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 38b71a2c96..f377769071 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -84,8 +84,8 @@ class MessageHandler: self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state 
+ self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages @@ -132,7 +132,7 @@ class MessageHandler: assert ( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" - room_state = await self.state_storage.get_state_for_events( + room_state = await self._state_storage_controller.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -193,7 +193,7 @@ class MessageHandler: # check whether the user is in the room at that time to determine # whether they should be treated as peeking. - state_map = await self.state_storage.get_state_for_event( + state_map = await self._state_storage_controller.get_state_for_event( last_event.event_id, StateFilter.from_types([(EventTypes.Member, user_id)]), ) @@ -206,7 +206,7 @@ class MessageHandler: is_peeking = not joined visible_events = await filter_events_for_client( - self.storage, + self._storage_controllers, user_id, [last_event], filter_send_to_client=False, @@ -214,8 +214,10 @@ class MessageHandler: ) if visible_events: - room_state_events = await self.state_storage.get_state_for_events( - [last_event.event_id], state_filter=state_filter + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [last_event.event_id], state_filter=state_filter + ) ) room_state: Mapping[Any, EventBase] = room_state_events[ last_event.event_id @@ -244,8 +246,10 @@ class MessageHandler: assert ( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" - room_state_events = await self.state_storage.get_state_for_events( - [membership_event_id], state_filter=state_filter + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [membership_event_id], state_filter=state_filter + ) ) room_state = room_state_events[membership_event_id] @@ -402,7 +406,7 @@ class EventCreationHandler: self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -1032,7 +1036,7 @@ class EventCreationHandler: # after it is created if builder.internal_metadata.outlier: event.internal_metadata.outlier = True - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) elif ( event.type == EventTypes.MSC2716_INSERTION and state_event_ids @@ -1445,7 +1449,7 @@ class EventCreationHandler: """ extra_users = extra_users or [] - assert self.storage.persistence is not None + assert self._storage_controllers.persistence is not None assert self._events_shard_config.should_handle( self._instance_name, event.room_id ) @@ -1679,7 +1683,7 @@ class EventCreationHandler: event, event_pos, max_stream_token, - ) = await self.storage.persistence.persist_event( + ) = await self._storage_controllers.persistence.persist_event( event, context=context, backfilled=backfilled ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 35afe6b855..6262a35822 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -129,8 +129,8 @@ class PaginationHandler: self.hs = hs 
self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() @@ -352,7 +352,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: async with self.pagination_lock.write(room_id): - await self.storage.purge_events.purge_history( + await self._storage_controllers.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") @@ -414,7 +414,7 @@ class PaginationHandler: if joined: raise SynapseError(400, "Users are still joined to this room") - await self.storage.purge_events.purge_room(room_id) + await self._storage_controllers.purge_events.purge_room(room_id) async def get_messages( self, @@ -529,7 +529,10 @@ class PaginationHandler: events = await event_filter.filter(events) events = await filter_events_for_client( - self.storage, user_id, events, is_peeking=(member_event_id is None) + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), ) # if after the filter applied there are no more events @@ -550,7 +553,7 @@ class PaginationHandler: (EventTypes.Member, event.sender) for event in events ) - state_ids = await self.state_storage.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) @@ -664,7 +667,7 @@ class PaginationHandler: 400, "Users are still joined to this room" ) - await self.storage.purge_events.purge_room(room_id) + await self._storage_controllers.purge_events.purge_room(room_id) logger.info("complete") self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ab7e54857d..9a1cc11bb3 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -69,7 +69,7 @@ class BundledAggregations: class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() @@ -143,7 +143,10 @@ class RelationsHandler: ) events = await filter_events_for_client( - self._storage, user_id, events, is_peeking=(member_event_id is None) + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), ) now = self._clock.time_msec() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e2775b34f1..5c91d33f58 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1192,8 +1192,8 @@ class RoomContextHandler: self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._relations_handler = hs.get_relations_handler() async def get_event_context( @@ -1236,7 +1236,10 @@ class RoomContextHandler: if use_admin_priviledge: return events return await filter_events_for_client( - self.storage, user.to_string(), events, is_peeking=is_peeking + self._storage_controllers, + user.to_string(), + events, + 
is_peeking=is_peeking, ) event = await self.store.get_event( @@ -1293,7 +1296,7 @@ class RoomContextHandler: # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = await self.state_storage.get_state_for_events( + state = await self._state_storage_controller.get_state_for_events( [last_event_id], state_filter=state_filter ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 7ce32f2e9c..1414e575d6 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -17,7 +17,7 @@ class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main - self.state_storage = hs.get_storage().state + self._state_storage_controller = hs.get_storage_controllers().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -141,7 +141,7 @@ class RoomBatchHandler: ) = await self.store.get_max_depth_of(event_ids) # mapping from (type, state_key) -> state_event_id assert most_recent_event_id is not None - prev_state_map = await self.state_storage.get_state_ids_for_event( + prev_state_map = await self._state_storage_controller.get_state_ids_for_event( most_recent_event_id ) # List of state event ID's diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index e02c915248..659f99f7e2 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -55,8 +55,8 @@ class SearchHandler: self.hs = hs self._event_serializer = hs.get_event_client_serializer() self._relations_handler = hs.get_relations_handler() - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.auth = hs.get_auth() async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: @@ -460,7 +460,7 @@ class SearchHandler: filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + self._storage_controllers, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -559,7 +559,7 @@ class SearchHandler: filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + self._storage_controllers, user.to_string(), filtered_events ) room_events.extend(events) @@ -644,11 +644,11 @@ class SearchHandler: ) events_before = await filter_events_for_client( - self.storage, user.to_string(), res.events_before + self._storage_controllers, user.to_string(), res.events_before ) events_after = await filter_events_for_client( - self.storage, user.to_string(), res.events_after + self._storage_controllers, user.to_string(), res.events_after ) context: JsonDict = { @@ -677,7 +677,7 @@ class SearchHandler: [(EventTypes.Member, sender) for sender in senders] ) - state = await self.state_storage.get_state_for_event( + state = await self._state_storage_controller.get_state_for_event( last_event_id, state_filter ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c5c538e0c3..b5859dcb28 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -238,8 +238,8 @@ class SyncHandler: self.clock = hs.get_clock() self.state = hs.get_state_handler() self.auth = hs.get_auth() - 
self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state # TODO: flush cache entries on subsequent sync request. # Once we get the next /sync request (ie, one with the same access token @@ -512,7 +512,7 @@ class SyncHandler: current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( - self.storage, + self._storage_controllers, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -580,7 +580,7 @@ class SyncHandler: current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( - self.storage, + self._storage_controllers, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -630,7 +630,7 @@ class SyncHandler: event: event of interest state_filter: The state filter used to fetch state from the database. """ - state_ids = await self.state_storage.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( event.event_id, state_filter=state_filter or StateFilter.all() ) if event.is_state(): @@ -710,7 +710,7 @@ class SyncHandler: return None last_event = last_events[-1] - state_ids = await self.state_storage.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -889,13 +889,15 @@ class SyncHandler: if full_state: if batch: current_state_ids = ( - await self.state_storage.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) ) - state_ids = await self.state_storage.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + state_ids = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) ) else: @@ -915,7 +917,7 @@ class SyncHandler: elif batch.limited: if batch: state_at_timeline_start = ( - await self.state_storage.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) ) @@ -950,7 +952,7 @@ class SyncHandler: if batch: current_state_ids = ( - await self.state_storage.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) ) @@ -982,7 +984,7 @@ class SyncHandler: # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = await self.state_storage.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, # we only want members! 
state_filter=StateFilter.from_types( diff --git a/synapse/notifier.py b/synapse/notifier.py index c2b66eec62..1100434b3f 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -221,7 +221,7 @@ class Notifier: self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.hs = hs - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.event_sources = hs.get_event_sources() self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] @@ -623,7 +623,7 @@ class Notifier: if name == "room": new_events = await filter_events_for_client( - self.storage, + self._storage_controllers, user.to_string(), new_events, is_peeking=is_peeking, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index d5603596c0..e96fb45e9f 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -65,7 +65,7 @@ class HttpPusher(Pusher): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): super().__init__(hs, pusher_config) - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() self.app_display_name = pusher_config.app_display_name self.device_display_name = pusher_config.device_display_name self.pushkey_ts = pusher_config.ts @@ -343,7 +343,9 @@ class HttpPusher(Pusher): } return d - ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id) + ctx = await push_tools.get_context_for_event( + self._storage_controllers, event, self.user_id + ) d = { "notification": { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 84124af965..63aefd07f5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -114,10 +114,10 @@ class Mailer: self.send_email_handler = hs.get_send_email_handler() self.store = self.hs.get_datastores().main - self.state_storage = self.hs.get_storage().state + self._state_storage_controller = self.hs.get_storage_controllers().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.app_name = app_name self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects @@ -456,7 +456,7 @@ class Mailer: } the_events = await filter_events_for_client( - self.storage, user_id, results.events_before + self._storage_controllers, user_id, results.events_before ) the_events.append(notif_event) @@ -494,7 +494,7 @@ class Mailer: ) else: # Attempt to check the historical state for the room. - historical_state = await self.state_storage.get_state_for_event( + historical_state = await self._state_storage_controller.get_state_for_event( event.event_id, StateFilter.from_types((type_state_key,)) ) sender_state_event = historical_state.get(type_state_key) @@ -767,8 +767,10 @@ class Mailer: member_event_ids.append(sender_state_event_id) else: # Attempt to check the historical state for the room. 
- historical_state = await self.state_storage.get_state_for_event( - event_id, StateFilter.from_types((type_state_key,)) + historical_state = ( + await self._state_storage_controller.get_state_for_event( + event_id, StateFilter.from_types((type_state_key,)) + ) ) sender_state_event = historical_state.get(type_state_key) if sender_state_event: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a1bf5b20dd..8397229ccb 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,7 +16,7 @@ from typing import Dict from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event -from synapse.storage import Storage +from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore @@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - async def get_context_for_event( - storage: Storage, ev: EventBase, user_id: str + storage: StorageControllers, ev: EventBase, user_id: str ) -> Dict[str, str]: ctx = {} diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 3e7300b4a1..eed29cd597 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super().__init__(hs) self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event.internal_metadata.outlier = event_payload["outlier"] context = EventContext.deserialize( - self.storage, event_payload["context"] + self._storage_controllers, event_payload["context"] ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index ce78176836..c2b2588ea5 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() @staticmethod @@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event.internal_metadata.outlier = content["outlier"] requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.storage, content["context"]) + context = EventContext.deserialize( + self._storage_controllers, content["context"] + ) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/server.py b/synapse/server.py index 3fd23aaf52..a66ec228db 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -123,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, Storage +from synapse.storage import Databases +from synapse.storage.controllers import StorageControllers from synapse.streams.events import EventSources from synapse.types 
import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -729,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta): return PasswordPolicyHandler(self) @cache_in_self - def get_storage(self) -> Storage: - return Storage(self, self.get_datastores()) + def get_storage_controllers(self) -> StorageControllers: + return StorageControllers(self, self.get_datastores()) @cache_in_self def get_replication_streamer(self) -> ReplicationStreamer: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 9c9d946f38..bf09f5128a 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -127,10 +127,10 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state_storage = hs.get_storage().state + self._state_storage_controller = hs.get_storage_controllers().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() @overload async def get_current_state( @@ -337,12 +337,14 @@ class StateHandler: # if not state_group_before_event: - state_group_before_event = await self.state_storage.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - current_state_ids=state_ids_before_event, + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) ) # Assign the new state group to the cached state entry. @@ -359,7 +361,7 @@ class StateHandler: if not event.is_state(): return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group_before_event=state_group_before_event, state_group=state_group_before_event, state_delta_due_to_event={}, @@ -382,16 +384,18 @@ class StateHandler: state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = await self.state_storage.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=state_ids_after_event, + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=state_ids_after_event, + ) ) return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group_after_event, state_group_before_event=state_group_before_event, state_delta_due_to_event=delta_ids, @@ -416,7 +420,9 @@ class StateHandler: """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_storage.get_state_group_for_events(event_ids) + state_groups = await self._state_storage_controller.get_state_group_for_events( + event_ids + ) state_group_ids = state_groups.values() @@ -424,8 +430,13 @@ class StateHandler: state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self.state_storage.get_state_for_groups(state_group_ids_set) - prev_group, delta_ids = await self.state_storage.get_state_group_delta( + state = await self._state_storage_controller.get_state_for_groups( + state_group_ids_set + ) + ( + 
prev_group, + delta_ids, + ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) return _StateCacheEntry( @@ -439,7 +450,7 @@ class StateHandler: room_version = await self.store.get_room_version_id(room_id) - state_to_resolve = await self.state_storage.get_state_for_groups( + state_to_resolve = await self._state_storage_controller.get_state_for_groups( state_group_ids_set ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 105e4e1fec..bac21ecf9c 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple databases). The `DatabasePool` class represents connections to a single physical database. The `databases` are classes that talk directly to a `DatabasePool` -instance and have associated schemas, background updates, etc. On top of those -there are classes that provide high level interfaces that combine calls to -multiple `databases`. +instance and have associated schemas, background updates, etc. + +On top of the databases are the StorageControllers, located in the +`synapse.storage.controllers` module. These classes provide high level +interfaces that combine calls to multiple `databases`. They are bundled into the +`StorageControllers` singleton for ease of use, and exposed via +`HomeServer.get_storage_controllers()`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from typing import TYPE_CHECKING from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore -from synapse.storage.persist_events import EventsPersistenceStorage -from synapse.storage.purge_events import PurgeEventsStorage -from synapse.storage.state import StateGroupStorage - -if TYPE_CHECKING: - from synapse.server import HomeServer - __all__ = ["Databases", "DataStore"] - - -class Storage: - """The high level interfaces for talking to various storage layers.""" - - def __init__(self, hs: "HomeServer", stores: Databases): - # We include the main data store here mainly so that we don't have to - # rewrite all the existing code to split it into high vs low level - # interfaces. - self.main = stores.main - - self.purge_events = PurgeEventsStorage(hs, stores) - self.state = StateGroupStorage(hs, stores) - - self.persistence = None - if stores.persist_events: - self.persistence = EventsPersistenceStorage(hs, stores) diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py new file mode 100644 index 0000000000..992261d07b --- /dev/null +++ b/synapse/storage/controllers/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
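As the rewritten `synapse/storage/__init__.py` docstring above describes, call sites now reach the high level interfaces through a single bundle. A minimal usage sketch of that pattern (illustration only, not part of the patch: `ExampleHandler` and `state_ids_at` are hypothetical names, while `get_storage_controllers()` and the `.state` controller are additions from this patch):

from typing import TYPE_CHECKING

from synapse.types import StateMap

if TYPE_CHECKING:
    from synapse.server import HomeServer


class ExampleHandler:
    def __init__(self, hs: "HomeServer"):
        # One singleton bundles the high level storage interfaces; handlers
        # hold private references, mirroring the renames made in this patch.
        self._storage_controllers = hs.get_storage_controllers()
        self._state_storage_controller = self._storage_controllers.state

    async def state_ids_at(self, event_id: str) -> StateMap[str]:
        # Delegates to StateGroupStorageController.get_state_ids_for_event,
        # which may block until full state is available for the event.
        return await self._state_storage_controller.get_state_ids_for_event(
            event_id
        )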
+ +from typing import TYPE_CHECKING + +from synapse.storage.controllers.persist_events import ( + EventsPersistenceStorageController, +) +from synapse.storage.controllers.purge_events import PurgeEventsStorageController +from synapse.storage.controllers.state import StateGroupStorageController +from synapse.storage.databases import Databases +from synapse.storage.databases.main import DataStore + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +__all__ = ["Databases", "DataStore"] + + +class StorageControllers: + """The high level interfaces (storage controllers) for talking to the various storage layers.""" + + def __init__(self, hs: "HomeServer", stores: Databases): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.purge_events = PurgeEventsStorageController(hs, stores) + self.state = StateGroupStorageController(hs, stores) + + self.persistence = None + if stores.persist_events: + self.persistence = EventsPersistenceStorageController(hs, stores) diff --git a/synapse/storage/persist_events.py b/synapse/storage/controllers/persist_events.py similarity index 99% rename from synapse/storage/persist_events.py rename to synapse/storage/controllers/persist_events.py index a21dea91c8..ef8c135b12 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -272,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): pass -class EventsPersistenceStorage: +class EventsPersistenceStorageController: """High level interface for handling persisting newly received events. Takes care of batching up events by room, and calculating the necessary diff --git a/synapse/storage/purge_events.py b/synapse/storage/controllers/purge_events.py similarity index 99% rename from synapse/storage/purge_events.py rename to synapse/storage/controllers/purge_events.py index 30669beb7c..9ca50d6a09 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/controllers/purge_events.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class PurgeEventsStorage: +class PurgeEventsStorageController: """High level interface for purging rooms and event history.""" def __init__(self, hs: "HomeServer", stores: Databases): diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py new file mode 100644 index 0000000000..0f09953086 --- /dev/null +++ b/synapse/storage/controllers/state.py @@ -0,0 +1,351 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
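A design note on the `StorageControllers` bundle defined above: `persistence` is only populated when the process's stores can persist events, which is why handlers assert on it before use (see the `assert self._storage_controllers.persistence` calls earlier in this patch). A hedged sketch of that guard, with `persist_batch` and `event_and_contexts` as hypothetical names:

from typing import TYPE_CHECKING, List, Tuple

from synapse.events import EventBase
from synapse.events.snapshot import EventContext

if TYPE_CHECKING:
    from synapse.server import HomeServer


async def persist_batch(
    hs: "HomeServer", event_and_contexts: List[Tuple[EventBase, EventContext]]
) -> None:
    controllers = hs.get_storage_controllers()
    # Only a process configured to persist events gets a persistence
    # controller; on other workers the attribute stays None.
    assert controllers.persistence is not None
    # persist_events returns the events actually persisted (some may have
    # been deduplicated) and the resulting maximum stream token.
    events, max_stream_token = await controllers.persistence.persist_events(
        event_and_contexts, backfilled=False
    )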
+import logging +from typing import ( + TYPE_CHECKING, + Awaitable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, +) + +from synapse.events import EventBase +from synapse.storage.state import StateFilter +from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.types import MutableStateMap, StateMap + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases import Databases + +logger = logging.getLogger(__name__) + + +class StateGroupStorageController: + """High level interface for fetching state for events.""" + + def __init__(self, hs: "HomeServer", stores: "Databases"): + self._is_mine_id = hs.is_mine_id + self.stores = stores + self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + + def notify_event_un_partial_stated(self, event_id: str) -> None: + self._partial_state_events_tracker.notify_un_partial_stated(event_id) + + async def get_state_group_delta( + self, state_group: int + ) -> Tuple[Optional[int], Optional[StateMap[str]]]: + """Given a state group, try to return a previous group and a delta between + the old and the new. + + Args: + state_group: The state group used to retrieve state deltas. + + Returns: + A tuple of the previous group and a state map of the event IDs which + make up the delta between the old and new state groups. + """ + + state_group_delta = await self.stores.state.get_state_group_delta(state_group) + return state_group_delta.prev_group, state_group_delta.delta_ids + + async def get_state_groups_ids( + self, _room_id: str, event_ids: Collection[str] + ) -> Dict[int, MutableStateMap[str]]: + """Get the event IDs of all the state in the state groups for the given events + + Args: + _room_id: ID of the room for these events + event_ids: IDs of the events + + Returns: + dict of state_group_id -> (dict of (type, state_key) -> event id) + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + if not event_ids: + return {} + + event_to_groups = await self.get_state_group_for_events(event_ids) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups(groups) + + return group_to_state + + async def get_state_ids_for_group( + self, state_group: int, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: + """Get the event IDs of all the state in the given state group + + Args: + state_group: A state group for which we want to get the state IDs. + state_filter: Specifies the types of state events to fetch from the database (e.g. EventTypes.JoinRules). + + Returns: + Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = await self.get_state_for_groups((state_group,), state_filter) + + return group_to_state[state_group] + + async def get_state_groups( + self, room_id: str, event_ids: Collection[str] + ) -> Dict[int, List[EventBase]]: + """Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + + Returns: + dict of state_group_id -> list of state events.
+ """ + if not event_ids: + return {} + + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) + + state_event_map = await self.stores.main.get_events( + [ + ev_id + for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in event_id_map.values() + if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() + } + + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ) -> Awaitable[Dict[int, StateMap[str]]]: + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups: list of state group IDs to query + state_filter: The state filter used to fetch state + from the database. + + Returns: + Dict of state group to state map. + """ + + return self.stores.state._get_state_groups_from_groups(groups, state_filter) + + async def get_state_for_events( + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None + ) -> Dict[str, StateMap[EventBase]]: + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + + Args: + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + + Returns: + A dict of (event_id) -> (type, state_key) -> [state_events] + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + state_event_map = await self.stores.main.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + + async def get_state_ids_for_events( + self, + event_ids: Collection[str], + state_filter: Optional[StateFilter] = None, + ) -> Dict[str, StateMap[str]]: + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. 
+ + Returns: + A dict from event_id -> (type, state_key) -> event_id + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + + async def get_state_for_event( + self, event_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[EventBase]: + """ + Get the state dict corresponding to a particular event + + Args: + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from (type, state_key) -> state_event + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) + """ + state_map = await self.get_state_for_events( + [event_id], state_filter or StateFilter.all() + ) + return state_map[event_id] + + async def get_state_ids_for_event( + self, event_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: + """ + Get the state dict corresponding to a particular event + + Args: + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from (type, state_key) -> state_event_id + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) + """ + state_map = await self.get_state_ids_for_events( + [event_id], state_filter or StateFilter.all() + ) + return state_map[event_id] + + def get_state_for_groups( + self, groups: Iterable[int], state_filter: Optional[StateFilter] = None + ) -> Awaitable[Dict[int, MutableStateMap[str]]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state + from the database. + + Returns: + Dict of state group to state map. + """ + return self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + async def get_state_group_for_events( + self, + event_ids: Collection[str], + await_full_state: bool = True, + ) -> Mapping[str, int]: + """Returns mapping event_id -> state_group + + Args: + event_ids: events to get state groups for + await_full_state: if true, will block if we do not yet have complete + state at these events. + """ + if await_full_state: + await self._partial_state_events_tracker.await_full_state(event_ids) + + return await self.stores.main._get_state_group_for_events(event_ids) + + async def store_state_group( + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[StateMap[str]], + current_state_ids: StateMap[str], + ) -> int: + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional.
+ delta_ids: The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids: The state to store. Map of (type, state_key) + to event_id. + + Returns: + The state group ID + """ + return await self.stores.state.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids + ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ab630953ac..96aaffb53c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Awaitable, Callable, Collection, Dict, @@ -32,15 +31,11 @@ import attr from frozendict import frozendict from synapse.api.constants import EventTypes -from synapse.events import EventBase -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad - from synapse.server import HomeServer - from synapse.storage.databases import Databases logger = logging.getLogger(__name__) @@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter( types=frozendict({EventTypes.Member: frozenset()}), include_others=True ) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) - - -class StateGroupStorage: - """High level interface to fetching state for event.""" - - def __init__(self, hs: "HomeServer", stores: "Databases"): - self._is_mine_id = hs.is_mine_id - self.stores = stores - self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) - - def notify_event_un_partial_stated(self, event_id: str) -> None: - self._partial_state_events_tracker.notify_un_partial_stated(event_id) - - async def get_state_group_delta( - self, state_group: int - ) -> Tuple[Optional[int], Optional[StateMap[str]]]: - """Given a state group try to return a previous group and a delta between - the old and the new. - - Args: - state_group: The state group used to retrieve state deltas. - - Returns: - A tuple of the previous group and a state map of the event IDs which - make up the delta between the old and new state groups. - """ - - state_group_delta = await self.stores.state.get_state_group_delta(state_group) - return state_group_delta.prev_group, state_group_delta.delta_ids - - async def get_state_groups_ids( - self, _room_id: str, event_ids: Collection[str] - ) -> Dict[int, MutableStateMap[str]]: - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id: id of the room for these events - event_ids: ids of the events - - Returns: - dict of state_group_id -> (dict of (type, state_key) -> event id) - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - if not event_ids: - return {} - - event_to_groups = await self.get_state_group_for_events(event_ids) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups(groups) - - return group_to_state - - async def get_state_ids_for_group( - self, state_group: int, state_filter: Optional[StateFilter] = None - ) -> StateMap[str]: - """Get the event IDs of all the state in the given state group - - Args: - state_group: A state group for which we want to get the state IDs. 
- state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules - - Returns: - Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = await self.get_state_for_groups((state_group,), state_filter) - - return group_to_state[state_group] - - async def get_state_groups( - self, room_id: str, event_ids: Collection[str] - ) -> Dict[int, List[EventBase]]: - """Get the state groups for the given list of event_ids - - Args: - room_id: ID of the room for these events. - event_ids: The event IDs to retrieve state for. - - Returns: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - - state_event_map = await self.stores.main.get_events( - [ - ev_id - for group_ids in group_to_ids.values() - for ev_id in group_ids.values() - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in event_id_map.values() - if v in state_event_map - ] - for group, event_id_map in group_to_ids.items() - } - - def _get_state_groups_from_groups( - self, groups: List[int], state_filter: StateFilter - ) -> Awaitable[Dict[int, StateMap[str]]]: - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups: list of state group IDs to query - state_filter: The state filter used to fetch state - from the database. - - Returns: - Dict of state group to state map. - """ - - return self.stores.state._get_state_groups_from_groups(groups, state_filter) - - async def get_state_for_events( - self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None - ) -> Dict[str, StateMap[EventBase]]: - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids: The events to fetch the state of. - state_filter: The state filter used to fetch state. - - Returns: - A dict of (event_id) -> (type, state_key) -> [state_events] - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False - - event_to_groups = await self.get_state_group_for_events( - event_ids, await_full_state=await_full_state - ) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - state_event_map = await self.stores.main.get_events( - [ev_id for sd in group_to_state.values() for ev_id in sd.values()], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in group_to_state[group].items() - if v in state_event_map - } - for event_id, group in event_to_groups.items() - } - - return {event: event_to_state[event] for event in event_ids} - - async def get_state_ids_for_events( - self, - event_ids: Collection[str], - state_filter: Optional[StateFilter] = None, - ) -> Dict[str, StateMap[str]]: - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids: events whose state should be returned - state_filter: The state filter used to fetch state from the database. 
- - Returns: - A dict from event_id -> (type, state_key) -> event_id - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False - - event_to_groups = await self.get_state_group_for_events( - event_ids, await_full_state=await_full_state - ) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in event_to_groups.items() - } - - return {event: event_to_state[event] for event in event_ids} - - async def get_state_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None - ) -> StateMap[EventBase]: - """ - Get the state dict corresponding to a particular event - - Args: - event_id: event whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from (type, state_key) -> state_event - - Raises: - RuntimeError if we don't have a state group for the event (ie it is an - outlier or is unknown) - """ - state_map = await self.get_state_for_events( - [event_id], state_filter or StateFilter.all() - ) - return state_map[event_id] - - async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None - ) -> StateMap[str]: - """ - Get the state dict corresponding to a particular event - - Args: - event_id: event whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from (type, state_key) -> state_event_id - - Raises: - RuntimeError if we don't have a state group for the event (ie it is an - outlier or is unknown) - """ - state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() - ) - return state_map[event_id] - - def get_state_for_groups( - self, groups: Iterable[int], state_filter: Optional[StateFilter] = None - ) -> Awaitable[Dict[int, MutableStateMap[str]]]: - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups: list of state groups for which we want to get the state. - state_filter: The state filter used to fetch state. - from the database. - - Returns: - Dict of state group to state map. - """ - return self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - async def get_state_group_for_events( - self, - event_ids: Collection[str], - await_full_state: bool = True, - ) -> Mapping[str, int]: - """Returns mapping event_id -> state_group - - Args: - event_ids: events to get state groups for - await_full_state: if true, will block if we do not yet have complete - state at these events. - """ - if await_full_state: - await self._partial_state_events_tracker.await_full_state(event_ids) - - return await self.stores.main._get_state_group_for_events(event_ids) - - async def store_state_group( - self, - event_id: str, - room_id: str, - prev_group: Optional[int], - delta_ids: Optional[StateMap[str]], - current_state_ids: StateMap[str], - ) -> int: - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id: The event ID for which the state was calculated. - room_id: ID of the room for which the state was calculated. - prev_group: A previous state group for the room, optional. 
- delta_ids: The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids: The state to store. Map of (type, state_key) - to event_id. - - Returns: - The state group ID - """ - return await self.stores.state.store_state_group( - event_id, room_id, prev_group, delta_ids, current_state_ids - ) diff --git a/synapse/visibility.py b/synapse/visibility.py index da4af02796..97548c14e3 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -20,7 +20,7 @@ from typing_extensions import Final from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase from synapse.events.utils import prune_event -from synapse.storage import Storage +from synapse.storage.controllers import StorageControllers from synapse.storage.state import StateFilter from synapse.types import RetentionPolicy, StateMap, get_domain_from_id @@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "" async def filter_events_for_client( - storage: Storage, + storage: StorageControllers, user_id: str, events: List[EventBase], is_peeking: bool = False, @@ -268,7 +268,7 @@ async def filter_events_for_client( async def filter_events_for_server( - storage: Storage, + storage: StorageControllers, server_name: str, events: List[EventBase], redact: bool = True, @@ -360,7 +360,7 @@ async def filter_events_for_server( async def _event_to_history_vis( - storage: Storage, events: Collection[EventBase] + storage: StorageControllers, events: Collection[EventBase] ) -> Dict[str, str]: """Get the history visibility at each of the given events @@ -407,7 +407,7 @@ async def _event_to_history_vis( async def _event_to_memberships( - storage: Storage, events: Collection[EventBase], server_name: str + storage: StorageControllers, events: Collection[EventBase], server_name: str ) -> Dict[str, StateMap[EventBase]]: """Get the remote membership list at each of the given events diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index defbc68c18..8ddce83b83 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.user_id = self.register_user("u1", "pass") self.user_tok = self.login("u1", "pass") @@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase): def _check_serialize_deserialize(self, event, context): serialized = self.get_success(context.serialize(event, self.store)) - d_context = EventContext.deserialize(self.storage, serialized) + d_context = EventContext.deserialize(self._storage_controllers, serialized) self.assertEqual(context.state_group, d_context.state_group) self.assertEqual(context.rejected, d_context.rejected) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index ec00900621..500c9ccfbc 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_storage = hs.get_storage().state + self.state_storage_controller = hs.get_storage_controllers().state 
self._event_auth_handler = hs.get_event_auth_handler() return hs @@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # mapping from (type, state_key) -> state_event_id assert most_recent_prev_event_id is not None prev_state_map = self.get_success( - self.state_storage.get_state_ids_for_event(most_recent_prev_event_id) + self.state_storage_controller.get_state_ids_for_event( + most_recent_prev_event_id + ) ) # List of state event ID's prev_state_ids = list(prev_state_map.values()) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index e64b28f28b..1d5b2492c0 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) -> None: OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" main_store = self.hs.get_datastores().main - state_storage = self.hs.get_storage().state + state_storage_controller = self.hs.get_storage_controllers().state # create the room user_id = self.register_user("kermit", "test") @@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) if prev_exists_as_outlier: prev_event.internal_metadata.outlier = True - persistence = self.hs.get_storage().persistence + persistence = self.hs.get_storage_controllers().persistence self.get_success( persistence.persist_event( - prev_event, EventContext.for_outlier(self.hs.get_storage()) + prev_event, + EventContext.for_outlier(self.hs.get_storage_controllers()), ) ) else: @@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # check that the state at that event is as expected state = self.get_success( - state_storage.get_state_ids_for_event(pulled_event.event_id) + state_storage_controller.get_state_ids_for_event(pulled_event.event_id) ) expected_state = { (e.type, e.state_key): e.event_id for e in state_at_prev_event diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index f4f7ab4845..44da96c792 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_event_creation_handler() - self.persist_event_storage = self.hs.get_storage().persistence + self._persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) self.user_id = self.register_user("tester", "foobar") self.access_token = self.login("tester", "foobar") @@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self._persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) return memberEvent, memberEventContext @@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( - self.persist_event_storage.persist_event(event3, context) + self._persist_event_storage_controller.persist_event(event3, context) ) # Assert that the returned values match those from the initial event @@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( - self.persist_event_storage.persist_events([(event3, context)]) + 
self._persist_event_storage_controller.persist_events([(event3, context)]) ) ret_event4 = events[0] @@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event2.event_id) events, _ = self.get_success( - self.persist_event_storage.persist_events( + self._persist_event_storage_controller.persist_events( [(event1, context1), (event2, context2)] ) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 4d658d29ca..a68c2ffd45 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) + self.hs.get_storage_controllers().persistence.persist_event(event, context) ) def test_local_user_leaving_room_remains_in_user_directory(self) -> None: diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 85be79d19d..c5705256e6 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): self.master_store = hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 297a9e77f8..6d3d4afe52 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) msg, msgctx = self.build_event() self.get_success( - self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + self._storage_controllers.persistence.persist_events( + [(j2, j2ctx), (msg, msgctx)] + ) ) self.replicate() @@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.storage.persistence.persist_events( + self._storage_controllers.persistence.persist_events( [(event, context)], backfilled=True ) ) else: - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index 5bbbd5fbcb..19f57115a1 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): def prepare(self, reactor, clock, homeserver): super().prepare(reactor, clock, homeserver) self.room_creator = homeserver.get_room_creation_handler() - self.persist_event_storage = self.hs.get_storage().persistence + self.persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) # Create a test user self.ourUser = UserID.from_string(OUR_USER_ID) @@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) # Join the second user to the second room @@ 
-76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) def test_return_empty_with_no_data(self): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 0cdf1dec40..0d44102237 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() # Create two rooms, one with a local user only and one with both a local # and remote user. @@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): event_creation_handler.create_new_client_event(builder) ) - self.get_success(storage.persistence.persist_event(event, context)) + self.get_success(storage_controllers.persistence.persist_event(event, context)) # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 2cd7a9e6c5..ac9c113354 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): We do this by setting a very long time between purge jobs. """ store = self.hs.get_datastores().main - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Send a first event, which should be filtered out at the end of the test. @@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( - filter_events_for_client(storage, self.user_id, events) + filter_events_for_client(storage_controllers, self.user_id, events) ) # We should only get one event back. diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 41a1bf6d89..1b7ee08ab2 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.clock = clock - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.virtual_user_id, _ = self.register_appservice_user( "as_user_potato", self.appservice.token @@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): # Fetch the state_groups state_group_map = self.get_success( - self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + self._storage_controllers.state.get_state_groups_ids( + room_id, historical_event_ids + ) ) # We expect all of the historical events to be using the same state_group diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index c7661e7186..a0ce077a99 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase): # We need to persist the events to the events and state_events # tables. 
persist_events_store._store_event_txn( - txn, [(e, EventContext(self.hs.get_storage())) for e in events] + txn, + [(e, EventContext(self.hs.get_storage_controllers())) for e in events], ) # Actually call the function that calculates the auth chain stuff. diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index aaa3189b16..a76718e8f9 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main self.register_user("user", "pass") @@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase): context = self.get_success( self.state.compute_event_context(event, state_ids_before_event=state) ) - self.get_success(self.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) def assert_extremities(self, expected_extremities): """Assert the current extremities for the room""" @@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) ) - self.get_success(self.persistence.persist_event(remote_event_2, context)) + self.get_success(self._persistence.persist_event(remote_event_2, context)) # Check that we haven't dropped the old extremity. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): @@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_rooms_for_user` to add the remote user to the cache rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) @@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_users_in_room` to add the remote user to the cache users = self.get_success(self.store.get_users_in_room(room_id)) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 08cc60237e..92cd0dfc05 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id) self.store = hs.get_datastores().main - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() def test_purge_history(self): """ @@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase): # Purge everything before this topological token self.get_success( - self.storage.purge_events.purge_history(self.room_id, token_str, True) + self._storage_controllers.purge_events.purge_history( + self.room_id, 
token_str, True + ) ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted @@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase): # Purge everything before this topological token f = self.get_failure( - self.storage.purge_events.purge_history(self.room_id, event, True), + self._storage_controllers.purge_events.purge_history( + self.room_id, event, True + ), SynapseError, ) self.assertIn("greater than forward", f.value.args[0]) @@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase): self.assertIsNotNone(create_event) # Purge everything before this topological token - self.get_success(self.storage.purge_events.purge_room(self.room_id)) + self.get_success( + self._storage_controllers.purge_events.purge_room(self.room_id) + ) # The events aren't found. self.store._invalidate_get_event_cache(create_event.event_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index d8d17ef379..6c4e63b77c 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage = hs.get_storage_controllers() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.storage.persistence.persist_event(event_1, context_1)) + self.get_success(self._storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.storage.persistence.persist_event(event_2, context_2)) + self.get_success(self._storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.storage.persistence.persist_event(redaction_event, context) + self._storage.persistence.persist_event(redaction_event, context) ) # Now lets jump to the future where we have censored the redaction event diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 5b011e18cd..d497a19f63 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): # Room events need the full datastore, for persist_event() 
and # get_room_state() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage = hs.get_storage_controllers() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def inject_room_event(self, **kwargs): self.get_success( - self.storage.persistence.persist_event( + self._storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) ) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8dfc1e1db9..e747c6b50e 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase): prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) prev_event = self.get_success(store.get_event(prev_event_ids[0])) prev_state_map = self.get_success( - self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + self.hs.get_storage_controllers().state.get_state_ids_for_event( + prev_event_ids[0] + ) ) event_dict = { diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f88f1c55fc..8043bdbde2 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/test_state.py b/tests/test_state.py index 84694d368d..95f81bebae 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -179,12 +179,12 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self): self.dummy_store = _DummyStore() - storage = Mock(main=self.dummy_store, state=self.dummy_store) + storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", "get_datastores", - "get_storage", + "get_storage_controllers", "get_auth", "get_state_handler", "get_clock", @@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) - hs.get_storage.return_value = storage + hs.get_storage_controllers.return_value = storage_controllers self.state = StateHandler(hs) self.event_id = 0 diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index c654e36ee4..8027c7a856 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -70,7 +70,7 @@ async def inject_event( """ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - persistence = hs.get_storage().persistence + persistence = hs.get_storage_controllers().persistence assert persistence is not None await persistence.persist_event(event, context) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 7a9b01ef9d..f338af6c36 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): super(FilterEventsForServerTestCase, self).setUp() self.event_creation_handler = self.hs.get_event_creation_handler() 
self.event_builder_factory = self.hs.get_event_builder_factory() - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): events_to_filter.append(evt) filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self._storage_controllers, "test_server", events_to_filter + ) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): outlier = self._inject_outlier() self.assertEqual( self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier]) + filter_events_for_server( + self._storage_controllers, "remote_hs", [outlier] + ) ), [outlier], ) @@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): evt = self._inject_message("@unerased:local_hs") filtered = self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) + filter_events_for_server( + self._storage_controllers, "remote_hs", [outlier, evt] + ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") self.assertEqual(filtered[0], outlier) @@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... but other servers should only be able to see the outlier (the other should # be redacted) filtered = self.get_success( - filter_events_for_server(self.storage, "other_server", [outlier, evt]) + filter_events_for_server( + self._storage_controllers, "other_server", [outlier, evt] + ) ) self.assertEqual(filtered[0], outlier) self.assertEqual(filtered[1].event_id, evt.event_id) @@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... and the filtering happens. 
filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self._storage_controllers, "test_server", events_to_filter + ) ) for i in range(0, len(events_to_filter)): @@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_room_member( @@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_message( @@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_outlier(self) -> EventBase: @@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage.persistence.persist_event( - event, EventContext.for_outlier(self.storage) + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) ) ) return event @@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@user:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@user:test", + [invite_event, reject_event], ) ), [invite_event, reject_event], @@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@other:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@other:test", + [invite_event, reject_event], ) ), [], diff --git a/tests/utils.py b/tests/utils.py index d4ba3a9b99..3059c453d5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -264,7 +264,7 @@ class MockClock: async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" - persistence_store = hs.get_storage().persistence + persistence_store = hs.get_storage_controllers().persistence store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() From 5984ada6bb340c736376ba94d766bf76ceeaf514 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 May 2022 13:41:49 +0100 Subject: [PATCH 52/74] 1.60.0 --- CHANGES.md | 13 +++++++++++-- changelog.d/12918.bugfix | 1 - debian/changelog | 6 ++++++ pyproject.toml | 2 +- 4 files changed, 18 insertions(+), 4 deletions(-) delete mode 100644 changelog.d/12918.bugfix diff --git a/CHANGES.md b/CHANGES.md index 40e362e920..de6bf95e43 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,5 @@ -Synapse 1.60.0rc2 (2022-05-27) 
-============================== +Synapse 1.60.0 (2022-05-31) +=========================== This release of Synapse adds a unique index to the `state_group_edges` table, in order to prevent accidentally introducing duplicate information (for example, @@ -14,6 +14,15 @@ should update their modules to use the new signature where possible. See [the upgrade notes](https://github.com/matrix-org/synapse/blob/develop/docs/upgrade.md#upgrading-to-v1600) for more details. +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.60.0rc1 that would break some imports from `synapse.module_api`. ([\#12918](https://github.com/matrix-org/synapse/issues/12918)) + + +Synapse 1.60.0rc2 (2022-05-27) +============================== + Features -------- diff --git a/changelog.d/12918.bugfix b/changelog.d/12918.bugfix deleted file mode 100644 index 38bdd80700..0000000000 --- a/changelog.d/12918.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 1.60.0rc1 that would break some imports from `synapse.module_api`. diff --git a/debian/changelog b/debian/changelog index b6a51d6903..5d332cedef 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.60.0) stable; urgency=medium + + * New Synapse release 1.60.0. + + -- Synapse Packaging team Tue, 31 May 2022 13:41:22 +0100 + matrix-synapse-py3 (1.60.0~rc2) stable; urgency=medium * New Synapse release 1.60.0rc2. diff --git a/pyproject.toml b/pyproject.toml index 59cff590b5..75251c863d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ skip_gitignore = true [tool.poetry] name = "matrix-synapse" -version = "1.60.0rc2" +version = "1.60.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors "] license = "Apache-2.0" From 5e17922ef715e3c68911955b32f42a70d6728831 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 31 May 2022 13:51:49 +0100 Subject: [PATCH 53/74] Stop reading from `event_edges.room_id`. (#12914) event_edges.room_id is implied by the event id, so there is no need to join on the room id. --- changelog.d/12914.misc | 1 + .../databases/main/event_federation.py | 7 +--- .../storage/databases/main/events_worker.py | 41 +++++++++---------- synapse/storage/schema/__init__.py | 5 ++- 4 files changed, 26 insertions(+), 28 deletions(-) create mode 100644 changelog.d/12914.misc diff --git a/changelog.d/12914.misc b/changelog.d/12914.misc new file mode 100644 index 0000000000..07d819932a --- /dev/null +++ b/changelog.d/12914.misc @@ -0,0 +1 @@ +Preparation for database schema simplifications: stop reading from `event_edges.room_id`. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 562dcbe94d..eec55b6478 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1318,17 +1318,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas query = ( "SELECT prev_event_id FROM event_edges " - "WHERE room_id = ? AND event_id = ? AND is_state = ? " + "WHERE event_id = ? AND NOT is_state " "LIMIT ?" 
) while front and len(event_results) < limit: new_front = set() for event_id in front: - txn.execute( - query, (room_id, event_id, False, limit - len(event_results)) - ) - + txn.execute(query, (event_id, limit - len(event_results))) new_results = {t[0] for t in txn} - seen_events new_front |= new_results diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a97d7e1664..b99b107784 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1928,23 +1928,6 @@ class EventsWorkerStore(SQLBaseStore): LIMIT 1 """ - # Check to see whether the event in question is already referenced - # by another event. If we don't see any edges, we're next to a - # forward gap. - forward_edge_query = """ - SELECT 1 FROM event_edges - /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id = rejections.event_id - WHERE - event_edges.room_id = ? - AND event_edges.prev_event_id = ? - /* It's not a valid edge if the event referencing our event in - * question is rejected. - */ - AND rejections.event_id IS NULL - LIMIT 1 - """ - # We consider any forward extremity as the latest in the room and # not a forward gap. # @@ -1954,16 +1937,30 @@ class EventsWorkerStore(SQLBaseStore): # is useless. The new latest messages will just be federated as # usual. txn.execute(forward_extremity_query, (event.room_id, event.event_id)) - forward_extremities = txn.fetchall() - if len(forward_extremities): + if txn.fetchone(): return False + # Check to see whether the event in question is already referenced + # by another event. If we don't see any edges, we're next to a + # forward gap. + forward_edge_query = """ + SELECT 1 FROM event_edges + /* Check to make sure the event referencing our event in question is not rejected */ + LEFT JOIN rejections ON event_edges.event_id = rejections.event_id + WHERE + event_edges.prev_event_id = ? + /* It's not a valid edge if the event referencing our event in + * question is rejected. + */ + AND rejections.event_id IS NULL + LIMIT 1 + """ + # If there are no forward edges to the event in question (another # event hasn't referenced this event in their prev_events), then we # assume there is a forward gap in the history. - txn.execute(forward_edge_query, (event.room_id, event.event_id)) - forward_edges = txn.fetchall() - if not len(forward_edges): + txn.execute(forward_edge_query, (event.event_id,)) + if not txn.fetchone(): return True return False diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index da98f05e03..19466150d4 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 70 # remember to update the list below when updating +SCHEMA_VERSION = 71 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -67,6 +67,9 @@ Changes in SCHEMA_VERSION = 69: Changes in SCHEMA_VERSION = 70: - event_reference_hashes is no longer written to. + +Changes in SCHEMA_VERSION = 71: + - event_edges.room_id is no longer read from. 
""" From c8684e67924fceed44bcbc4a607502764905ba1e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 May 2022 14:01:05 +0100 Subject: [PATCH 54/74] Reduce DB load of /sync when using presence (#12885) While the query was fast, we were calling it *a lot*. --- changelog.d/12885.misc | 1 + synapse/storage/databases/main/presence.py | 75 ++++++++++++++-------- 2 files changed, 49 insertions(+), 27 deletions(-) create mode 100644 changelog.d/12885.misc diff --git a/changelog.d/12885.misc b/changelog.d/12885.misc new file mode 100644 index 0000000000..2524056307 --- /dev/null +++ b/changelog.d/12885.misc @@ -0,0 +1 @@ +Reduce database load of `/sync` when presence is enabled. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index b47c511450..9769a18a9d 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -22,6 +22,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection from synapse.storage.util.id_generators import ( @@ -56,7 +57,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): ) -class PresenceStore(PresenceBackgroundUpdateStore): +class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore): def __init__( self, database: DatabasePool, @@ -281,20 +282,30 @@ class PresenceStore(PresenceBackgroundUpdateStore): True if the user should have full presence sent to them, False otherwise. """ - def _should_user_receive_full_presence_with_token_txn( - txn: LoggingTransaction, - ) -> bool: - sql = """ - SELECT 1 FROM users_to_send_full_presence_to - WHERE user_id = ? - AND presence_stream_id >= ? - """ - txn.execute(sql, (user_id, from_token)) - return bool(txn.fetchone()) + token = await self._get_full_presence_stream_token_for_user(user_id) + if token is None: + return False - return await self.db_pool.runInteraction( - "should_user_receive_full_presence_with_token", - _should_user_receive_full_presence_with_token_txn, + return from_token <= token + + @cached() + async def _get_full_presence_stream_token_for_user( + self, user_id: str + ) -> Optional[int]: + """Get the presence token corresponding to the last full presence update + for this user. + + If the user presents a sync token with a presence stream token at least + as old as the result, then we need to send them a full presence update. + + If this user has never needed a full presence update, returns `None`. 
+ """ + return await self.db_pool.simple_select_one_onecol( + table="users_to_send_full_presence_to", + keyvalues={"user_id": user_id}, + retcol="presence_stream_id", + allow_none=True, + desc="_get_full_presence_stream_token_for_user", ) async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: @@ -307,18 +318,28 @@ class PresenceStore(PresenceBackgroundUpdateStore): # Add user entries to the table, updating the presence_stream_id column if the user already # exists in the table. presence_stream_id = self._presence_id_gen.get_current_token() - await self.db_pool.simple_upsert_many( - table="users_to_send_full_presence_to", - key_names=("user_id",), - key_values=[(user_id,) for user_id in user_ids], - value_names=("presence_stream_id",), - # We save the current presence stream ID token along with the user ID entry so - # that when a user /sync's, even if they syncing multiple times across separate - # devices at different times, each device will receive full presence once - when - # the presence stream ID in their sync token is less than the one in the table - # for their user ID. - value_values=[(presence_stream_id,) for _ in user_ids], - desc="add_users_to_send_full_presence_to", + + def _add_users_to_send_full_presence_to(txn: LoggingTransaction) -> None: + self.db_pool.simple_upsert_many_txn( + txn, + table="users_to_send_full_presence_to", + key_names=("user_id",), + key_values=[(user_id,) for user_id in user_ids], + value_names=("presence_stream_id",), + # We save the current presence stream ID token along with the user ID entry so + # that when a user /sync's, even if they syncing multiple times across separate + # devices at different times, each device will receive full presence once - when + # the presence stream ID in their sync token is less than the one in the table + # for their user ID. + value_values=[(presence_stream_id,) for _ in user_ids], + ) + for user_id in user_ids: + self._invalidate_cache_and_stream( + txn, self._get_full_presence_stream_token_for_user, (user_id,) + ) + + return await self.db_pool.runInteraction( + "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to ) async def get_presence_for_all_users( From bf01e51554ad87528d3e8612d0f02c52f8c0e562 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 31 May 2022 14:02:00 +0100 Subject: [PATCH 55/74] Test Synapse against Complement with workers. (#12810) Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- .ci/scripts/checkout_complement.sh | 25 ++++++++++++ .github/workflows/tests.yml | 62 ++++++++++++++++++++---------- changelog.d/12810.misc | 1 + 3 files changed, 67 insertions(+), 21 deletions(-) create mode 100755 .ci/scripts/checkout_complement.sh create mode 100644 changelog.d/12810.misc diff --git a/.ci/scripts/checkout_complement.sh b/.ci/scripts/checkout_complement.sh new file mode 100755 index 0000000000..379f5d4387 --- /dev/null +++ b/.ci/scripts/checkout_complement.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# +# Fetches a version of complement which best matches the current build. +# +# The tarball is unpacked into `./complement`. + +set -e +mkdir -p complement + +# Pick an appropriate version of complement. Depending on whether this is a PR or release, +# etc. we need to use different fallbacks: +# +# 1. First check if there's a similarly named branch (GITHUB_HEAD_REF +# for pull requests, otherwise GITHUB_REF). +# 2. Attempt to use the base branch, e.g. when merging into release-vX.Y +# (GITHUB_BASE_REF for pull requests). +# 3. 
Use the default complement branch ("HEAD"). +for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "HEAD"; do + # Skip empty branch names and merge commits. + if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then + continue + fi + + (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break +done diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index efa35b71df..3693cf06c3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -306,7 +306,7 @@ jobs: - run: .ci/scripts/test_synapse_port_db.sh complement: - if: ${{ !failure() && !cancelled() }} + if: "${{ !failure() && !cancelled() }}" needs: linting-done runs-on: ubuntu-latest @@ -333,26 +333,7 @@ jobs: # Attempt to check out the same branch of Complement as the PR. If it # doesn't exist, fallback to HEAD. - name: Checkout complement - shell: bash - run: | - mkdir -p complement - # Attempt to use the version of complement which best matches the current - # build. Depending on whether this is a PR or release, etc. we need to - # use different fallbacks. - # - # 1. First check if there's a similarly named branch (GITHUB_HEAD_REF - # for pull requests, otherwise GITHUB_REF). - # 2. Attempt to use the base branch, e.g. when merging into release-vX.Y - # (GITHUB_BASE_REF for pull requests). - # 3. Use the default complement branch ("HEAD"). - for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "HEAD"; do - # Skip empty branch names and merge commits. - if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then - continue - fi - - (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break - done + run: synapse/.ci/scripts/checkout_complement.sh - run: | set -o pipefail @@ -360,6 +341,45 @@ jobs: shell: bash name: Run Complement Tests + # We only run the workers tests on `develop` for now, because they're too slow to wait for on PRs. + # Sadly, you can't have an `if` condition on the value of a matrix, so this is a temporary, separate job for now. + # GitHub Actions doesn't support YAML anchors, so it's full-on duplication for now. + complement-developonly: + if: "${{ !failure() && !cancelled() && (github.ref == 'refs/heads/develop') }}" + needs: linting-done + runs-on: ubuntu-latest + + steps: + # The path is set via a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on the path to run Complement. + # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path + - name: "Set Go Version" + run: | + # Add Go 1.17 to the PATH: see https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md#environment-variables-2 + echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH + # Add the Go path to the PATH: We need this so we can call gotestfmt + echo "~/go/bin" >> $GITHUB_PATH + + - name: "Install Complement Dependencies" + run: | + sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev + go get -v github.com/haveyoudebuggedit/gotestfmt/v2/cmd/gotestfmt@latest + + - name: Run actions/checkout@v2 for synapse + uses: actions/checkout@v2 + with: + path: synapse + + # Attempt to check out the same branch of Complement as the PR. If it + # doesn't exist, fallback to HEAD. 
+ - name: Checkout complement + run: .ci/scripts/checkout_complement.sh + + - run: | + set -o pipefail + WORKERS=1 COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | gotestfmt + shell: bash + name: Run Complement Tests + # a job which marks all the other jobs as complete, thus allowing PRs to be merged. tests-done: if: ${{ always() }} diff --git a/changelog.d/12810.misc b/changelog.d/12810.misc new file mode 100644 index 0000000000..fe5fb81d5e --- /dev/null +++ b/changelog.d/12810.misc @@ -0,0 +1 @@ +Test Synapse against Complement with workers. \ No newline at end of file From b2b5279a3f1b4012de664b424f9e9db13ce3c774 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 May 2022 14:25:46 +0100 Subject: [PATCH 56/74] Update changelog --- CHANGES.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index de6bf95e43..2bf8cdea75 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,7 @@ Bugfixes - Fix a bug introduced in Synapse 1.30.0 where empty rooms could be automatically created if a monthly active users limit is set. ([\#12713](https://github.com/matrix-org/synapse/issues/12713)) - Fix push to dismiss notifications when read on another client. Contributed by @SpiritCroc @ Beeper. ([\#12721](https://github.com/matrix-org/synapse/issues/12721)) - Fix poor database performance when reading the cache invalidation stream for large servers with lots of workers. ([\#12747](https://github.com/matrix-org/synapse/issues/12747)) +- Fix a long-standing bug where the user directory background process would fail to make forward progress if a user included a null codepoint in their display name or avatar. ([\#12762](https://github.com/matrix-org/synapse/issues/12762)) - Delete events from the `federation_inbound_events_staging` table when a room is purged through the admin API. ([\#12770](https://github.com/matrix-org/synapse/issues/12770)) - Give a meaningful error message when a client tries to create a room with an invalid alias localpart. ([\#12779](https://github.com/matrix-org/synapse/issues/12779)) - Fix a bug introduced in 1.43.0 where a file (`providers.json`) was never closed. Contributed by @arkamar. ([\#12794](https://github.com/matrix-org/synapse/issues/12794)) @@ -124,7 +125,6 @@ Internal Changes - Drop the logging level of status messages for the URL preview cache expiry job from INFO to DEBUG. ([\#12720](https://github.com/matrix-org/synapse/issues/12720)) - Downgrade some OIDC errors to warnings in the logs, to reduce the noise of Sentry reports. ([\#12723](https://github.com/matrix-org/synapse/issues/12723)) - Update configs used by Complement to allow more invites/3PID validations during tests. ([\#12731](https://github.com/matrix-org/synapse/issues/12731)) -- Fix a long-standing bug where the user directory background process would fail to make forward progress if a user included a null codepoint in their display name or avatar. ([\#12762](https://github.com/matrix-org/synapse/issues/12762)) - Tweak the mypy plugin so that `@cached` can accept `on_invalidate=None`. ([\#12769](https://github.com/matrix-org/synapse/issues/12769)) - Move methods that call `add_push_rule` to the `PushRuleStore` class. ([\#12772](https://github.com/matrix-org/synapse/issues/12772)) - Make handling of federation Authorization header (more) compliant with RFC7230. 
([\#12774](https://github.com/matrix-org/synapse/issues/12774)) @@ -231,7 +231,7 @@ Deprecations and Removals ------------------------- - Remove unstable identifiers from [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069). ([\#12596](https://github.com/matrix-org/synapse/issues/12596)) -- Remove the unspecified `m.login.jwt` login type and the unstable `uk.half-shot.msc2778.login.application_service` from +- Remove the unspecified `m.login.jwt` login type and the unstable `uk.half-shot.msc2778.login.application_service` from [MSC2778](https://github.com/matrix-org/matrix-doc/pull/2778). ([\#12597](https://github.com/matrix-org/synapse/issues/12597)) - Synapse now requires at least Python 3.7.1 (up from 3.7.0), for compatibility with the latest Twisted trunk. ([\#12613](https://github.com/matrix-org/synapse/issues/12613)) From 2fba1076c56e76410fd901120f0e8df2ef33d1c4 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 31 May 2022 15:50:29 +0100 Subject: [PATCH 57/74] Faster room joins: Try other destinations when resyncing the state of a partial-state room (#12812) Signed-off-by: Sean Quah --- changelog.d/12812.misc | 1 + synapse/federation/federation_client.py | 5 +- synapse/handlers/federation.py | 86 ++++++++++++++++++++++--- synapse/handlers/federation_event.py | 11 ++++ 4 files changed, 94 insertions(+), 9 deletions(-) create mode 100644 changelog.d/12812.misc diff --git a/changelog.d/12812.misc b/changelog.d/12812.misc new file mode 100644 index 0000000000..53cb936a02 --- /dev/null +++ b/changelog.d/12812.misc @@ -0,0 +1 @@ +Try other homeservers when re-syncing state for rooms with partial state. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 17eff60909..b60b8983ea 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -405,6 +405,9 @@ class FederationClient(FederationBase): Returns: a tuple of (state event_ids, auth event_ids) + + Raises: + InvalidResponseError: if fields in the response have the wrong type. 
""" result = await self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id @@ -416,7 +419,7 @@ class FederationClient(FederationBase): if not isinstance(state_event_ids, list) or not isinstance( auth_event_ids, list ): - raise Exception("invalid response from /state_ids") + raise InvalidResponseError("invalid response from /state_ids") return state_event_ids, auth_event_ids diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 80ee7e7b4e..b4b63a342a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -20,7 +20,16 @@ import itertools import logging from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) import attr from signedjson.key import decode_verify_key_bytes @@ -34,6 +43,7 @@ from synapse.api.errors import ( CodeMessageException, Codes, FederationDeniedError, + FederationError, HttpResponseException, NotFoundError, RequestSendFailed, @@ -545,7 +555,8 @@ class FederationHandler: run_as_background_process( desc="sync_partial_state_room", func=self._sync_partial_state_room, - destination=origin, + initial_destination=origin, + other_destinations=ret.servers_in_room, room_id=room_id, ) @@ -1454,13 +1465,16 @@ class FederationHandler: async def _sync_partial_state_room( self, - destination: str, + initial_destination: Optional[str], + other_destinations: Collection[str], room_id: str, ) -> None: """Background process to resync the state of a partial-state room Args: - destination: homeserver to pull the state from + initial_destination: the initial homeserver to pull the state from + other_destinations: other homeservers to try to pull the state from, if + `initial_destination` is unavailable room_id: room to be resynced """ @@ -1472,8 +1486,29 @@ class FederationHandler: # really leave, that might mean we have difficulty getting the room state over # federation. # - # TODO(faster_joins): try other destinations if the one we have fails + # TODO(faster_joins): we need some way of prioritising which homeservers in + # `other_destinations` to try first, otherwise we'll spend ages trying dead + # homeservers for large rooms. + if initial_destination is None and len(other_destinations) == 0: + raise ValueError( + f"Cannot resync state of {room_id}: no destinations provided" + ) + + # Make an infinite iterator of destinations to try. Once we find a working + # destination, we'll stick with it until it flakes. + if initial_destination is not None: + # Move `initial_destination` to the front of the list. + destinations = list(other_destinations) + if initial_destination in destinations: + destinations.remove(initial_destination) + destinations = [initial_destination] + destinations + destination_iter = itertools.cycle(destinations) + else: + destination_iter = itertools.cycle(other_destinations) + + # `destination` is the current remote homeserver we're pulling from. + destination = next(destination_iter) logger.info("Syncing state for room %s via %s", room_id, destination) # we work through the queue in order of increasing stream ordering. 
@@ -1511,6 +1546,41 @@ class FederationHandler: allow_rejected=True, ) for event in events: - await self._federation_event_handler.update_state_for_partial_state_event( - destination, event - ) + for attempt in itertools.count(): + try: + await self._federation_event_handler.update_state_for_partial_state_event( + destination, event + ) + break + except FederationError as e: + if attempt == len(destinations) - 1: + # We have tried every remote server for this event. Give up. + # TODO(faster_joins) giving up isn't the right thing to do + # if there's a temporary network outage. retrying + # indefinitely is also not the right thing to do if we can + # reach all homeservers and they all claim they don't have + # the state we want. + logger.error( + "Failed to get state for %s at %s from %s because %s, " + "giving up!", + room_id, + event, + destination, + e, + ) + raise + + # Try the next remote server. + logger.info( + "Failed to get state for %s at %s from %s because %s", + room_id, + event, + destination, + e, + ) + destination = next(destination_iter) + logger.info( + "Syncing state for room %s via %s instead", + room_id, + destination, + ) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index b908674529..549b066dd9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -505,6 +505,9 @@ class FederationEventHandler: Args: destination: server to request full state from event: partial-state event to be de-partial-stated + + Raises: + FederationError if we fail to request state from the remote server. """ logger.info("Updating state for %s", event.event_id) with nested_logging_context(suffix=event.event_id): @@ -815,6 +818,10 @@ class FederationEventHandler: Returns: if we already had all the prev events, `None`. Otherwise, returns the event ids of the state at `event`. + + Raises: + FederationError if we fail to get the state from the remote server after any + missing `prev_event`s. """ room_id = event.room_id event_id = event.event_id @@ -901,6 +908,10 @@ class FederationEventHandler: Returns: The event ids of the state *after* the given event. + + Raises: + InvalidResponseError: if the remote homeserver's response contains fields + of the wrong type. """ ( state_event_ids, From 641908f72f94357049eca1cab632918d252da3e0 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 31 May 2022 16:15:08 +0100 Subject: [PATCH 58/74] Faster room joins: Resume state re-syncing after a Synapse restart (#12813) Signed-off-by: Sean Quah --- changelog.d/12813.misc | 1 + synapse/handlers/federation.py | 27 ++++++++++++++++++++++++-- synapse/storage/databases/main/room.py | 27 ++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12813.misc diff --git a/changelog.d/12813.misc b/changelog.d/12813.misc new file mode 100644 index 0000000000..8be9f3eb44 --- /dev/null +++ b/changelog.d/12813.misc @@ -0,0 +1 @@ +Resume state re-syncing for rooms with partial state after a Synapse restart. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b4b63a342a..659f279441 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -169,6 +169,14 @@ class FederationHandler: self.third_party_event_rules = hs.get_third_party_event_rules() + # if this is the main process, fire off a background process to resume + # any partial-state-resync operations which were in flight when we + # were shut down. 
+ if not hs.config.worker.worker_app: + run_as_background_process( + "resume_sync_partial_state_room", self._resume_sync_partial_state_room + ) + async def maybe_backfill( self, room_id: str, current_depth: int, limit: int ) -> bool: @@ -470,6 +478,8 @@ class FederationHandler: """ # TODO: We should be able to call this on workers, but the upgrading of # room stuff after join currently doesn't work on workers. + # TODO: Before we relax this condition, we need to allow re-syncing of + # partial room state to happen on workers. assert self.config.worker.worker_app is None logger.debug("Joining %s to %s", joinee, room_id) @@ -550,8 +560,6 @@ class FederationHandler: if ret.partial_state: # Kick off the process of asynchronously fetching the state for this # room. - # - # TODO(faster_joins): pick this up again on restart run_as_background_process( desc="sync_partial_state_room", func=self._sync_partial_state_room, @@ -1463,6 +1471,20 @@ class FederationHandler: # well. return None + async def _resume_sync_partial_state_room(self) -> None: + """Resumes resyncing of all partial-state rooms after a restart.""" + assert not self.config.worker.worker_app + + partial_state_rooms = await self.store.get_partial_state_rooms_and_servers() + for room_id, servers_in_room in partial_state_rooms.items(): + run_as_background_process( + desc="sync_partial_state_room", + func=self._sync_partial_state_room, + initial_destination=None, + other_destinations=servers_in_room, + room_id=room_id, + ) + async def _sync_partial_state_room( self, initial_destination: Optional[str], @@ -1477,6 +1499,7 @@ class FederationHandler: `initial_destination` is unavailable room_id: room to be resynced """ + assert not self.config.worker.worker_app # TODO(faster_joins): do we need to lock to avoid races? What happens if other # worker processes kick off a resync in parallel? Perhaps we should just elect diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 10f2ceb50b..cfd8ce1624 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -23,6 +23,7 @@ from typing import ( Collection, Dict, List, + Mapping, Optional, Tuple, Union, @@ -1081,6 +1082,32 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_rooms_for_retention_period_in_range_txn, ) + async def get_partial_state_rooms_and_servers( + self, + ) -> Mapping[str, Collection[str]]: + """Get all rooms containing events with partial state, and the servers known + to be in the room. + + Returns: + A dictionary of rooms with partial state, with room IDs as keys and + lists of servers in rooms as values. + """ + room_servers: Dict[str, List[str]] = {} + + rows = await self.db_pool.simple_select_list( + "partial_state_rooms_servers", + keyvalues=None, + retcols=("room_id", "server_name"), + desc="get_partial_state_rooms", + ) + + for row in rows: + room_id = row["room_id"] + server_name = row["server_name"] + room_servers.setdefault(room_id, []).append(server_name) + + return room_servers + async def clear_partial_state_room(self, room_id: str) -> bool: # this can race with incoming events, so we watch out for FK errors. 
# TODO(faster_joins): this still doesn't completely fix the race, since the persist process From 2fc787c341ff540e5880932f116498ec0ed7a2c2 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 31 May 2022 17:35:29 +0100 Subject: [PATCH 59/74] Add config options for media retention (#12732) --- changelog.d/12732.feature | 1 + .../configuration/config_documentation.md | 29 ++- synapse/config/repository.py | 16 ++ synapse/rest/media/v1/media_repository.py | 71 +++++- tests/rest/media/test_media_retention.py | 238 ++++++++++++++++++ 5 files changed, 353 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12732.feature create mode 100644 tests/rest/media/test_media_retention.py diff --git a/changelog.d/12732.feature b/changelog.d/12732.feature new file mode 100644 index 0000000000..3c73363d28 --- /dev/null +++ b/changelog.d/12732.feature @@ -0,0 +1 @@ +Add new `media_retention` options to the homeserver config for routinely cleaning up non-recently accessed media. \ No newline at end of file diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 88b9e5744d..1c75a23a36 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1459,7 +1459,7 @@ federation_rr_transactions_per_room_per_second: 40 ``` --- ## Media Store ## -Config options relating to Synapse media store. +Config options related to Synapse's media store. --- Config option: `enable_media_repo` @@ -1563,6 +1563,33 @@ thumbnail_sizes: height: 600 method: scale ``` +--- +Config option: `media_retention` + +Controls whether local media and entries in the remote media cache +(media that is downloaded from other homeservers) should be removed +under certain conditions, typically for the purpose of saving space. + +Purging media files will be carried out by the media worker +(that is, the worker that has the `enable_media_repo` homeserver config +option set to 'true'). This may be the main process. + +The `media_retention.local_media_lifetime` and +`media_retention.remote_media_lifetime` config options control whether +media will be purged if it has not been accessed in a given amount of +time. Note that media is 'accessed' when loaded in a room in a client, or +otherwise downloaded by a local or remote user. If the media has never +been accessed, the media's creation time is used instead. Both thumbnails +and the original media will be removed. If either of these options is unset, +then media of that type will not be purged. + +Example configuration: +```yaml +media_retention: + local_media_lifetime: 90d + remote_media_lifetime: 14d +``` +--- Config option: `url_preview_enabled` This setting determines whether the preview URL API is enabled.
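The `media_retention` lifetimes documented above are parsed into millisecond values by the config class in the next diff, and the media repository compares them against each item's last-access time. As a rough standalone sketch of that eligibility check (the helper below is hypothetical, not part of Synapse's interface):

```python
import time
from typing import Optional

def is_purge_eligible(
    lifetime_ms: Optional[int],
    last_access_ms: Optional[int],
    created_ms: int,
    now_ms: Optional[int] = None,
) -> bool:
    """Return True if a media item should be purged under the given lifetime.

    `lifetime_ms` stands in for the parsed `local_media_lifetime` /
    `remote_media_lifetime` value, in milliseconds.
    """
    if lifetime_ms is None:
        # Option unset: media of this type is never purged.
        return False
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    threshold_ms = now_ms - lifetime_ms
    # Media is "accessed" when downloaded; fall back to creation time
    # for media that has never been accessed, as the docs describe.
    reference_ms = last_access_ms if last_access_ms is not None else created_ms
    return reference_ms < threshold_ms

# With `local_media_lifetime: 90d`, the parsed lifetime would be:
ninety_days_ms = 90 * 24 * 60 * 60 * 1000  # 7_776_000_000 ms
```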
diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 98d8a16621..f9c55143c3 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -223,6 +223,22 @@ class ContentRepositoryConfig(Config): "url_preview_accept_language" ) or ["en"] + media_retention = config.get("media_retention") or {} + + self.media_retention_local_media_lifetime_ms = None + local_media_lifetime = media_retention.get("local_media_lifetime") + if local_media_lifetime is not None: + self.media_retention_local_media_lifetime_ms = self.parse_duration( + local_media_lifetime + ) + + self.media_retention_remote_media_lifetime_ms = None + remote_media_lifetime = media_retention.get("remote_media_lifetime") + if remote_media_lifetime is not None: + self.media_retention_remote_media_lifetime_ms = self.parse_duration( + remote_media_lifetime + ) + def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str: assert data_dir_path is not None media_store = os.path.join(data_dir_path, "media_store") diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 3e5d6c6294..20af366538 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -65,7 +65,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 +# How often to run the background job to update the "recently accessed" +# attribute of local and remote media. +UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute +# How often to run the background job to check for local and remote media +# that should be purged according to the configured media retention settings. +MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour class MediaRepository: @@ -122,11 +127,36 @@ class MediaRepository: self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS ) + # Media retention configuration options + self._media_retention_local_media_lifetime_ms = ( + hs.config.media.media_retention_local_media_lifetime_ms + ) + self._media_retention_remote_media_lifetime_ms = ( + hs.config.media.media_retention_remote_media_lifetime_ms + ) + + # Check whether local or remote media retention is configured + if ( + hs.config.media.media_retention_local_media_lifetime_ms is not None + or hs.config.media.media_retention_remote_media_lifetime_ms is not None + ): + # Run the background job to apply media retention rules routinely, + # with the duration between runs dictated by the homeserver config. + self.clock.looping_call( + self._start_apply_media_retention_rules, + MEDIA_RETENTION_CHECK_PERIOD_MS, + ) + def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( "update_recently_accessed_media", self._update_recently_accessed ) + def _start_apply_media_retention_rules(self) -> Deferred: + return run_as_background_process( + "apply_media_retention_rules", self._apply_media_retention_rules + ) + async def _update_recently_accessed(self) -> None: remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() @@ -835,6 +865,45 @@ class MediaRepository: return {"width": m_width, "height": m_height} + async def _apply_media_retention_rules(self) -> None: + """ + Purge old local and remote media according to the media retention rules + defined in the homeserver config. + """ + # Purge remote media + if self._media_retention_remote_media_lifetime_ms is not None: + # Calculate a threshold timestamp derived from the configured lifetime. 
Any + # media that has not been accessed since this timestamp will be removed. + remote_media_threshold_timestamp_ms = ( + self.clock.time_msec() - self._media_retention_remote_media_lifetime_ms + ) + + logger.info( + "Purging remote media last accessed before" + f" {remote_media_threshold_timestamp_ms}" + ) + + await self.delete_old_remote_media( + before_ts=remote_media_threshold_timestamp_ms + ) + + # And now do the same for local media + if self._media_retention_local_media_lifetime_ms is not None: + # This works the same as the remote media threshold + local_media_threshold_timestamp_ms = ( + self.clock.time_msec() - self._media_retention_local_media_lifetime_ms + ) + + logger.info( + "Purging local media last accessed before" + f" {local_media_threshold_timestamp_ms}" + ) + + await self.delete_old_local_media( + before_ts=local_media_threshold_timestamp_ms, + keep_profiles=True, + ) + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: old_media = await self.store.get_remote_media_before(before_ts) diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py new file mode 100644 index 0000000000..b98a5cd586 --- /dev/null +++ b/tests/rest/media/test_media_retention.py @@ -0,0 +1,238 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from typing import Iterable, Optional, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, register, room +from synapse.server import HomeServer +from synapse.types import UserID +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config +from tests.utils import MockClock + + +class MediaRetentionTestCase(unittest.HomeserverTestCase): + + ONE_DAY_IN_MS = 24 * 60 * 60 * 1000 + THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS + + servlets = [ + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # We need to be able to test advancing time in the homeserver, so we + # replace the test homeserver's default clock with a MockClock, which + # supports advancing time. 
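+        # (MockClock is the test double from tests.utils imported above. One
+        # unit mismatch to keep in mind when reading these tests:
+        # reactor.advance() takes seconds, while the media access timestamps
+        # below are in milliseconds.)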
+ return self.setup_test_homeserver(clock=MockClock()) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.remote_server_name = "remote.homeserver" + self.store = hs.get_datastores().main + + # Create a user to upload media with + test_user_id = self.register_user("alice", "password") + + # Inject media into both the local store and the remote cache: three + # local items (recently accessed, old access, never accessed) and two + # remote ones (remote media always has a last-access time; see below) + media_repository = hs.get_media_repository() + test_media_content = b"example string" + + def _create_media_and_set_last_accessed( + last_accessed_ms: Optional[int], + ) -> str: + # "Upload" some media to the local media store + mxc_uri = self.get_success( + media_repository.create_content( + media_type="text/plain", + upload_name=None, + content=io.BytesIO(test_media_content), + content_length=len(test_media_content), + auth_user=UserID.from_string(test_user_id), + ) + ) + + media_id = mxc_uri.split("/")[-1] + + # Set the last recently accessed time for this media + if last_accessed_ms is not None: + self.get_success( + self.store.update_cached_last_access_time( + local_media=(media_id,), + remote_media=(), + time_ms=last_accessed_ms, + ) + ) + + return media_id + + def _cache_remote_media_and_set_last_accessed( + media_id: str, last_accessed_ms: Optional[int] + ) -> str: + # Pretend to cache some remote media + self.get_success( + self.store.store_cached_remote_media( + origin=self.remote_server_name, + media_id=media_id, + media_type="text/plain", + media_length=1, + time_now_ms=clock.time_msec(), + upload_name="testfile.txt", + filesystem_id="abcdefg12345", + ) + ) + + # Set the last recently accessed time for this media + if last_accessed_ms is not None: + self.get_success( + hs.get_datastores().main.update_cached_last_access_time( + local_media=(), + remote_media=((self.remote_server_name, media_id),), + time_ms=last_accessed_ms, + ) + ) + + return media_id + + # Start with the local media store + self.local_recently_accessed_media = _create_media_and_set_last_accessed( + self.THIRTY_DAYS_IN_MS + ) + self.local_not_recently_accessed_media = _create_media_and_set_last_accessed( + self.ONE_DAY_IN_MS + ) + self.local_never_accessed_media = _create_media_and_set_last_accessed(None) + + # And now the remote media store + self.remote_recently_accessed_media = _cache_remote_media_and_set_last_accessed( + "a", self.THIRTY_DAYS_IN_MS + ) + self.remote_not_recently_accessed_media = ( + _cache_remote_media_and_set_last_accessed("b", self.ONE_DAY_IN_MS) + ) + # Remote media will always have a "last accessed" attribute, as it would not + # be fetched from the remote homeserver unless instigated by a user. + + @override_config( + { + "media_retention": { + # Enable retention for local media + "local_media_lifetime": "30d" + # Cached remote media should not be purged + } + } + ) + def test_local_media_retention(self) -> None: + """ + Tests that local media that has not been accessed recently is purged, while + cached remote media is unaffected. + """ + # Advance 31 days (in seconds) + self.reactor.advance(31 * 24 * 60 * 60) + + # Check that media has been correctly purged. + # Local media accessed <30 days ago should still exist. + # Remote media should be unaffected.
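+        # (The _assert_if_mxc_uris_purged helper defined further down checks
+        # the datastore directly: a purged MXC URI no longer has a local or
+        # cached remote media record.)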
+ self._assert_if_mxc_uris_purged( + purged=[ + ( + self.hs.config.server.server_name, + self.local_not_recently_accessed_media, + ), + (self.hs.config.server.server_name, self.local_never_accessed_media), + ], + not_purged=[ + (self.hs.config.server.server_name, self.local_recently_accessed_media), + (self.remote_server_name, self.remote_recently_accessed_media), + (self.remote_server_name, self.remote_not_recently_accessed_media), + ], + ) + + @override_config( + { + "media_retention": { + # Enable retention for cached remote media + "remote_media_lifetime": "30d" + # Local media should not be purged + } + } + ) + def test_remote_media_cache_retention(self) -> None: + """ + Tests that entries from the remote media cache that have not been accessed + recently are purged, while local media is unaffected. + """ + # Advance 31 days (in seconds) + self.reactor.advance(31 * 24 * 60 * 60) + + # Check that media has been correctly purged. + # Local media should be unaffected. + # Remote media accessed <30 days ago should still exist. + self._assert_if_mxc_uris_purged( + purged=[ + (self.remote_server_name, self.remote_not_recently_accessed_media), + ], + not_purged=[ + (self.remote_server_name, self.remote_recently_accessed_media), + (self.hs.config.server.server_name, self.local_recently_accessed_media), + ( + self.hs.config.server.server_name, + self.local_not_recently_accessed_media, + ), + (self.hs.config.server.server_name, self.local_never_accessed_media), + ], + ) + + def _assert_if_mxc_uris_purged( + self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]] + ) -> None: + def _assert_mxc_uri_purge_state( + server_name: str, media_id: str, expect_purged: bool + ) -> None: + """Given an MXC URI, assert whether it has been purged or not.""" + if server_name == self.hs.config.server.server_name: + found_media_dict = self.get_success( + self.store.get_local_media(media_id) + ) + else: + found_media_dict = self.get_success( + self.store.get_cached_remote_media(server_name, media_id) + ) + + mxc_uri = f"mxc://{server_name}/{media_id}" + + if expect_purged: + self.assertIsNone( + found_media_dict, msg=f"{mxc_uri} unexpectedly not purged" + ) + else: + self.assertIsNotNone( + found_media_dict, + msg=f"{mxc_uri} unexpectedly purged", + ) + + # Assert that the given MXC URIs have either been correctly purged or not. + for server_name, media_id in purged: + _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True) + for server_name, media_id in not_purged: + _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False) From cf05258f7672dd0dc054723e866c86f5e171b552 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 31 May 2022 13:04:08 -0400 Subject: [PATCH 60/74] Remove groups replication code. (#12900) The replication logic for groups is no longer used, so the message passing infrastructure can be removed.
--- changelog.d/12900.removal | 1 + synapse/app/admin_cmd.py | 2 - synapse/app/generic_worker.py | 2 - synapse/replication/slave/storage/groups.py | 58 --------------------- synapse/replication/tcp/client.py | 5 -- synapse/replication/tcp/streams/__init__.py | 3 -- synapse/replication/tcp/streams/_base.py | 20 ------- 7 files changed, 1 insertion(+), 90 deletions(-) create mode 100644 changelog.d/12900.removal delete mode 100644 synapse/replication/slave/storage/groups.py diff --git a/changelog.d/12900.removal b/changelog.d/12900.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12900.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 2a4c2e59cd..6fedf681f8 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -37,7 +37,6 @@ from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore @@ -55,7 +54,6 @@ class AdminCmdSlavedStore( SlavedApplicationServiceStore, SlavedRegistrationStore, SlavedFilteringStore, - SlavedGroupServerStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedPushRuleStore, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 0a6dd618f6..89f8998f0e 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -58,7 +58,6 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.slave.storage.keys import SlavedKeyStore from synapse.replication.slave.storage.profile import SlavedProfileStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore @@ -233,7 +232,6 @@ class GenericWorkerSlavedStore( SlavedDeviceStore, SlavedReceiptsStore, SlavedPushRuleStore, - SlavedGroupServerStore, SlavedAccountDataStore, SlavedPusherStore, CensorEventsStore, diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py deleted file mode 100644 index d6f37d7479..0000000000 --- a/synapse/replication/slave/storage/groups.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import TYPE_CHECKING, Any, Iterable - -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams import GroupServerStream -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection -from synapse.storage.databases.main.group_server import GroupServerWorkerStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - self.hs = hs - - self._group_updates_id_gen = SlavedIdTracker( - db_conn, "local_group_updates", "stream_id" - ) - self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", - self._group_updates_id_gen.get_current_token(), - ) - - def get_group_stream_token(self) -> int: - return self._group_updates_id_gen.get_current_token() - - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] - ) -> None: - if stream_name == GroupServerStream.NAME: - self._group_updates_id_gen.advance(instance_name, token) - for row in rows: - self._group_updates_stream_cache.entity_has_changed(row.user_id, token) - - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index a52e25c1af..2f59245058 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -30,7 +30,6 @@ from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, - GroupServerStream, PushersStream, PushRulesStream, ReceiptsStream, @@ -185,10 +184,6 @@ class ReplicationDataHandler: self.notifier.on_new_event( StreamKeyType.DEVICE_LIST, token, rooms=all_room_ids ) - elif stream_name == GroupServerStream.NAME: - self.notifier.on_new_event( - "groups_key", token, users=[row.user_id for row in rows] - ) elif stream_name == PushersStream.NAME: for row in rows: if row.deleted: diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index f41eabd85e..b1cd55bf6f 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -29,7 +29,6 @@ from synapse.replication.tcp.streams._base import ( BackfillStream, CachesStream, DeviceListsStream, - GroupServerStream, PresenceFederationStream, PresenceStream, PushersStream, @@ -61,7 +60,6 @@ STREAMS_MAP = { FederationStream, TagAccountDataStream, AccountDataStream, - GroupServerStream, UserSignatureStream, ) } @@ -81,6 +79,5 @@ __all__ = [ "ToDeviceStream", "TagAccountDataStream", "AccountDataStream", - "GroupServerStream", "UserSignatureStream", ] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 495f2f0285..398bebeaa6 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -585,26 +585,6 @@ class AccountDataStream(Stream): return updates, to_token, limited -class GroupServerStream(Stream): - @attr.s(slots=True, frozen=True, auto_attribs=True) - class GroupsStreamRow: - group_id: str - user_id: str - type: str - content: JsonDict 
- - NAME = "groups" - ROW_TYPE = GroupsStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_group_stream_token), - store.get_all_groups_changes, - ) - - class UserSignatureStream(Stream): """A user has signed their own device with their user-signing key""" From f0aec0abefceae36eff2ab08848e7a576535ee4e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 31 May 2022 23:32:56 +0100 Subject: [PATCH 61/74] Improve logging when signature checks fail (#12925) * Raise a dedicated `InvalidEventSignatureError` from `_check_sigs_on_pdu` * Downgrade logging about redactions to DEBUG; this can be very spammy during a room join, and it's not very useful. * Raise `InvalidEventSignatureError` from `_check_sigs_and_hash` ... and, more importantly, move the logging out to the callers. * changelog --- changelog.d/12925.misc | 1 + synapse/federation/federation_base.py | 89 ++++++++++++------------- synapse/federation/federation_client.py | 45 +++++++++---- synapse/federation/federation_server.py | 25 +++++-- 4 files changed, 95 insertions(+), 65 deletions(-) create mode 100644 changelog.d/12925.misc diff --git a/changelog.d/12925.misc b/changelog.d/12925.misc new file mode 100644 index 0000000000..71ca956dc5 --- /dev/null +++ b/changelog.d/12925.misc @@ -0,0 +1 @@ +Improve the logging when signature checks on events fail. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 7bc54b9988..a6232e048b 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -32,6 +32,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class InvalidEventSignatureError(RuntimeError): + """Raised when the signature on an event is invalid. + + The stringification of this exception is just the error message without reference + to the event id. The event id is available as a property. + """ + + def __init__(self, message: str, event_id: str): + super().__init__(message) + self.event_id = event_id + + class FederationBase: def __init__(self, hs: "HomeServer"): self.hs = hs @@ -59,20 +71,13 @@ class FederationBase: Returns: * the original event if the checks pass * a redacted version of the event (if the signature - matched but the hash did not) + matched but the hash did not). In this case a warning will be logged. Raises: - SynapseError if the signature check failed. + InvalidEventSignatureError if the signature check failed. Nothing + will be logged in this case. """ - try: - await _check_sigs_on_pdu(self.keyring, room_version, pdu) - except SynapseError as e: - logger.warning( - "Signature check failed for %s: %s", - pdu.event_id, - e, - ) - raise + await _check_sigs_on_pdu(self.keyring, room_version, pdu) if not check_event_content_hash(pdu): # let's try to distinguish between failures because the event was @@ -87,7 +92,7 @@ class FederationBase: if set(redacted_event.keys()) == set(pdu.keys()) and set( redacted_event.content.keys() ) == set(pdu.content.keys()): - logger.info( + logger.debug( "Event %s seems to have been redacted; using our redacted copy", pdu.event_id, ) @@ -116,12 +121,13 @@ async def _check_sigs_on_pdu( ) -> None: """Check that the given events are correctly signed - Raise a SynapseError if the event wasn't correctly signed.
- Args: keyring: keyring object to do the checks room_version: the room version of the PDUs pdus: the events to be checked + + Raises: + InvalidEventSignatureError if the event wasn't correctly signed. """ # we want to check that the event is signed by: @@ -147,44 +153,38 @@ async def _check_sigs_on_pdu( # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) + sender_domain = get_domain_from_id(pdu.sender) if not _is_invite_via_3pid(pdu): try: await keyring.verify_event_for_server( - get_domain_from_id(pdu.sender), + sender_domain, pdu, pdu.origin_server_ts if room_version.enforce_key_validity else 0, ) except Exception as e: - errmsg = "event id %s: unable to verify signature for sender %s: %s" % ( + raise InvalidEventSignatureError( + f"unable to verify signature for sender domain {sender_domain}: {e}", pdu.event_id, - get_domain_from_id(pdu.sender), - e, - ) - raise SynapseError(403, errmsg, Codes.FORBIDDEN) + ) from None # now let's look for events where the sender's domain is different to the # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain # (ie, the room version uses old-style non-hash event IDs). - if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id( - pdu.event_id - ) != get_domain_from_id(pdu.sender): - try: - await keyring.verify_event_for_server( - get_domain_from_id(pdu.event_id), - pdu, - pdu.origin_server_ts if room_version.enforce_key_validity else 0, - ) - except Exception as e: - errmsg = ( - "event id %s: unable to verify signature for event id domain %s: %s" - % ( - pdu.event_id, - get_domain_from_id(pdu.event_id), - e, + if room_version.event_format == EventFormatVersions.V1: + event_domain = get_domain_from_id(pdu.event_id) + if event_domain != sender_domain: + try: + await keyring.verify_event_for_server( + event_domain, + pdu, + pdu.origin_server_ts if room_version.enforce_key_validity else 0, ) - ) - raise SynapseError(403, errmsg, Codes.FORBIDDEN) + except Exception as e: + raise InvalidEventSignatureError( + f"unable to verify signature for event domain {event_domain}: {e}", + pdu.event_id, + ) from None # If this is a join event for a restricted room it may have been authorised # via a different server from the sending server. Check those signatures. 
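Each verification branch in `_check_sigs_on_pdu` now follows the same shape: attempt the keyring check, and on any failure re-raise as `InvalidEventSignatureError` with `from None`, which suppresses the noisy implicit exception chaining while keeping the underlying message; the callers (changed later in this patch) then decide how to log and surface it. A distilled, self-contained sketch of that shape, with hypothetical names standing in for Synapse's keyring and error classes:

```python
import logging

logger = logging.getLogger(__name__)


class InvalidSignatureError(RuntimeError):
    """Stringifies to the bare message; the event id rides along as an attribute."""

    def __init__(self, message: str, event_id: str):
        super().__init__(message)
        self.event_id = event_id


def _verify(domain: str) -> None:
    # Stand-in for a keyring verification call; raises on failure.
    raise ValueError(f"no known signing keys for {domain}")


def check_signature(event_id: str, domain: str) -> None:
    try:
        _verify(domain)
    except Exception as e:
        # `from None` drops the implicit "During handling of the above
        # exception..." chaining, so the caller logs one concise line.
        raise InvalidSignatureError(
            f"unable to verify signature for {domain}: {e}", event_id
        ) from None


# A call site logs once and converts to a client-visible error, much as the
# federation client/server changes below convert to a 403 SynapseError:
try:
    check_signature("$event_id", "other.example.org")
except InvalidSignatureError as e:
    logger.warning("event id %s: %s", e.event_id, e)
```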
@@ -204,15 +204,10 @@ async def _check_sigs_on_pdu( pdu.origin_server_ts if room_version.enforce_key_validity else 0, ) except Exception as e: - errmsg = ( - "event id %s: unable to verify signature for authorising server %s: %s" - % ( - pdu.event_id, - authorising_server, - e, - ) - ) - raise SynapseError(403, errmsg, Codes.FORBIDDEN) + raise InvalidEventSignatureError( + f"unable to verify signature for authorising server {authorising_server}: {e}", + pdu.event_id, + ) from None def _is_invite_via_3pid(event: EventBase) -> bool: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b60b8983ea..ad475a913b 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -54,7 +54,11 @@ from synapse.api.room_versions import ( RoomVersions, ) from synapse.events import EventBase, builder -from synapse.federation.federation_base import FederationBase, event_from_pdu_json +from synapse.federation.federation_base import ( + FederationBase, + InvalidEventSignatureError, + event_from_pdu_json, +) from synapse.federation.transport.client import SendJoinResponse from synapse.http.types import QueryParams from synapse.types import JsonDict, UserID, get_domain_from_id @@ -319,7 +323,13 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - signed_pdu = await self._check_sigs_and_hash(room_version, pdu) + try: + signed_pdu = await self._check_sigs_and_hash(room_version, pdu) + except InvalidEventSignatureError as e: + errmsg = f"event id {pdu.event_id}: {e}" + logger.warning("%s", errmsg) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) + return signed_pdu return None @@ -555,20 +565,24 @@ class FederationClient(FederationBase): Returns: The PDU (possibly redacted) if it has valid signatures and hashes. + None if no valid copy could be found. """ - res = None try: - res = await self._check_sigs_and_hash(room_version, pdu) - except SynapseError: - pass - - if not res: - # Check local db. - res = await self.store.get_event( - pdu.event_id, allow_rejected=True, allow_none=True + return await self._check_sigs_and_hash(room_version, pdu) + except InvalidEventSignatureError as e: + logger.warning( + "Signature on retrieved event %s was invalid (%s). " + "Checking local store/origin server", + pdu.event_id, + e, ) + # Check local db. + res = await self.store.get_event( + pdu.event_id, allow_rejected=True, allow_none=True + ) + pdu_origin = get_domain_from_id(pdu.sender) if not res and pdu_origin != origin: try: @@ -1043,9 +1057,14 @@ class FederationClient(FederationBase): pdu = event_from_pdu_json(pdu_dict, room_version) # Check signatures are correct. - pdu = await self._check_sigs_and_hash(room_version, pdu) + try: + pdu = await self._check_sigs_and_hash(room_version, pdu) + except InvalidEventSignatureError as e: + errmsg = f"event id {pdu.event_id}: {e}" + logger.warning("%s", errmsg) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) - # FIXME: We should handle signature failures more gracefully. + # FIXME: We should handle signature failures more gracefully.
return pdu diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3ecede22d9..12591dc8db 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -48,7 +48,11 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.crypto.event_signing import compute_event_signature from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.federation.federation_base import FederationBase, event_from_pdu_json +from synapse.federation.federation_base import ( + FederationBase, + InvalidEventSignatureError, + event_from_pdu_json, +) from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.servlet import assert_params_in_dict @@ -631,7 +635,12 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) - pdu = await self._check_sigs_and_hash(room_version, pdu) + try: + pdu = await self._check_sigs_and_hash(room_version, pdu) + except InvalidEventSignatureError as e: + errmsg = f"event id {pdu.event_id}: {e}" + logger.warning("%s", errmsg) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} @@ -864,7 +873,12 @@ class FederationServer(FederationBase): ) ) - event = await self._check_sigs_and_hash(room_version, event) + try: + event = await self._check_sigs_and_hash(room_version, event) + except InvalidEventSignatureError as e: + errmsg = f"event id {event.event_id}: {e}" + logger.warning("%s", errmsg) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) return await self._federation_event_handler.on_send_membership_event( origin, event @@ -1016,8 +1030,9 @@ class FederationServer(FederationBase): # Check signature. try: pdu = await self._check_sigs_and_hash(room_version, pdu) - except SynapseError as e: - raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) + except InvalidEventSignatureError as e: + logger.warning("event id %s: %s", pdu.event_id, e) + raise FederationError("ERROR", 403, str(e), affected=pdu.event_id) if await self._spam_checker.should_drop_federated_event(pdu): logger.warning( From 2e8763ec96d2a8b03111e3d5e73924c4a23d8239 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 31 May 2022 20:28:17 -0400 Subject: [PATCH 62/74] Remove most groups datastore code. (#12895) The remaining piece is a background update that is needed for backwards compatibility. --- changelog.d/12895.removal | 1 + synapse/_scripts/synapse_port_db.py | 4 +- .../storage/databases/main/group_server.py | 1398 +---------------- 3 files changed, 7 insertions(+), 1396 deletions(-) create mode 100644 changelog.d/12895.removal diff --git a/changelog.d/12895.removal b/changelog.d/12895.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12895.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. 
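The "remaining piece" kept by this patch is the background index update registered in `GroupServerStore.__init__` (visible in the diff below). Synapse records pending background updates by name in a database table and matches them against registered handlers, so the registration has to outlive the feature until every deployment has drained its queue. A toy illustration of that constraint (not Synapse code; the update name is an assumed example):

```python
# Toy illustration: pending background updates are stored by name, so a
# handler removed before all deployments have run it leaves an orphan.
from typing import Callable, Dict

pending_in_db = ["local_group_updates_index"]  # assumed example name

registered_handlers: Dict[str, Callable[[], None]] = {
    # "local_group_updates_index": run_index_update,  # must stay registered
}

for name in pending_in_db:
    if name not in registered_handlers:
        raise RuntimeError(f"no handler for pending background update {name!r}")
```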
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 12ff79f6e2..d7dfa92bd1 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -62,7 +62,7 @@ from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackground from synapse.storage.databases.main.events_bg_updates import ( EventsBackgroundUpdatesStore, ) -from synapse.storage.databases.main.group_server import GroupServerWorkerStore +from synapse.storage.databases.main.group_server import GroupServerStore from synapse.storage.databases.main.media_repository import ( MediaRepositoryBackgroundUpdateStore, ) @@ -211,7 +211,7 @@ class Store( PushRuleStore, PusherWorkerStore, PresenceBackgroundUpdateStore, - GroupServerWorkerStore, + GroupServerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 04efad9e9a..da21a50144 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -13,36 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING -from typing_extensions import TypedDict - -from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import ( - DatabasePool, - LoggingDatabaseConnection, - LoggingTransaction, -) -from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection if TYPE_CHECKING: from synapse.server import HomeServer -# The category ID for the "default" category. We don't store as null in the -# database to avoid the fun of null != null -_DEFAULT_CATEGORY_ID = "" -_DEFAULT_ROLE_ID = "" - -# A room in a group. 
-class _RoomInGroup(TypedDict): - room_id: str - is_public: bool - - -class GroupServerWorkerStore(SQLBaseStore): +class GroupServerStore(SQLBaseStore): def __init__( self, database: DatabasePool, @@ -57,1373 +37,3 @@ class GroupServerWorkerStore(SQLBaseStore): unique=True, ) super().__init__(database, db_conn, hs) - - async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( - table="groups", - keyvalues={"group_id": group_id}, - retcols=( - "name", - "short_description", - "long_description", - "avatar_url", - "is_public", - "join_policy", - ), - allow_none=True, - desc="get_group", - ) - - async def get_users_in_group( - self, group_id: str, include_private: bool = False - ) -> List[Dict[str, Any]]: - # TODO: Pagination - - keyvalues: JsonDict = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - return await self.db_pool.simple_select_list( - table="group_users", - keyvalues=keyvalues, - retcols=("user_id", "is_public", "is_admin"), - desc="get_users_in_group", - ) - - async def get_invited_users_in_group(self, group_id: str) -> List[str]: - # TODO: Pagination - - return await self.db_pool.simple_select_onecol( - table="group_invites", - keyvalues={"group_id": group_id}, - retcol="user_id", - desc="get_invited_users_in_group", - ) - - async def get_rooms_in_group( - self, group_id: str, include_private: bool = False - ) -> List[_RoomInGroup]: - """Retrieve the rooms that belong to a given group. Does not return rooms that - lack members. - - Args: - group_id: The ID of the group to query for rooms - include_private: Whether to return private rooms in results - - Returns: - A list of dictionaries, each in the form of: - - { - "room_id": "!a_room_id:example.com", # The ID of the room - "is_public": False # Whether this is a public room or not - } - """ - - # TODO: Pagination - - def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]: - sql = """ - SELECT room_id, is_public FROM group_rooms - WHERE group_id = ? - AND room_id IN ( - SELECT group_rooms.room_id FROM group_rooms - LEFT JOIN room_stats_current ON - group_rooms.room_id = room_stats_current.room_id - AND joined_members > 0 - AND local_users_in_room > 0 - LEFT JOIN rooms ON - group_rooms.room_id = rooms.room_id - AND (room_version <> '') = ? - ) - """ - args = [group_id, False] - - if not include_private: - sql += " AND is_public = ?" 
- args += [True] - - txn.execute(sql, args) - - return [ - {"room_id": room_id, "is_public": is_public} - for room_id, is_public in txn - ] - - return await self.db_pool.runInteraction( - "get_rooms_in_group", _get_rooms_in_group_txn - ) - - async def get_rooms_for_summary_by_category( - self, - group_id: str, - include_private: bool = False, - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - """Get the rooms and categories that should be included in a summary request - - Args: - group_id: The ID of the group to query the summary for - include_private: Whether to return private rooms in results - - Returns: - A tuple containing: - - * A list of dictionaries with the keys: - * "room_id": str, the room ID - * "is_public": bool, whether the room is public - * "category_id": str|None, the category ID if set, else None - * "order": int, the sort order of rooms - - * A dictionary with the key: - * category_id (str): a dictionary with the keys: - * "is_public": bool, whether the category is public - * "profile": str, the category profile - * "order": int, the sort order of rooms in this category - """ - - def _get_rooms_for_summary_txn( - txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - keyvalues: JsonDict = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT room_id, is_public, category_id, room_order - FROM group_summary_rooms - WHERE group_id = ? - AND room_id IN ( - SELECT group_rooms.room_id FROM group_rooms - LEFT JOIN room_stats_current ON - group_rooms.room_id = room_stats_current.room_id - AND joined_members > 0 - AND local_users_in_room > 0 - LEFT JOIN rooms ON - group_rooms.room_id = rooms.room_id - AND (room_version <> '') = ? - ) - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, False, True)) - else: - txn.execute(sql, (group_id, False)) - - rooms = [ - { - "room_id": row[0], - "is_public": row[1], - "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT category_id, is_public, profile, cat_order - FROM group_summary_room_categories - INNER JOIN group_room_categories USING (group_id, category_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" 
- txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - categories = { - row[0]: { - "is_public": row[1], - "profile": db_to_json(row[2]), - "order": row[3], - } - for row in txn - } - - return rooms, categories - - return await self.db_pool.runInteraction( - "get_rooms_for_summary", _get_rooms_for_summary_txn - ) - - async def get_group_categories(self, group_id: str) -> JsonDict: - rows = await self.db_pool.simple_select_list( - table="group_room_categories", - keyvalues={"group_id": group_id}, - retcols=("category_id", "is_public", "profile"), - desc="get_group_categories", - ) - - return { - row["category_id"]: { - "is_public": row["is_public"], - "profile": db_to_json(row["profile"]), - } - for row in rows - } - - async def get_group_category(self, group_id: str, category_id: str) -> JsonDict: - category = await self.db_pool.simple_select_one( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcols=("is_public", "profile"), - desc="get_group_category", - ) - - category["profile"] = db_to_json(category["profile"]) - - return category - - async def get_group_roles(self, group_id: str) -> JsonDict: - rows = await self.db_pool.simple_select_list( - table="group_roles", - keyvalues={"group_id": group_id}, - retcols=("role_id", "is_public", "profile"), - desc="get_group_roles", - ) - - return { - row["role_id"]: { - "is_public": row["is_public"], - "profile": db_to_json(row["profile"]), - } - for row in rows - } - - async def get_group_role(self, group_id: str, role_id: str) -> JsonDict: - role = await self.db_pool.simple_select_one( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcols=("is_public", "profile"), - desc="get_group_role", - ) - - role["profile"] = db_to_json(role["profile"]) - - return role - - async def get_local_groups_for_room(self, room_id: str) -> List[str]: - """Get all of the local group that contain a given room - Args: - room_id: The ID of a room - Returns: - A list of group ids containing this room - """ - return await self.db_pool.simple_select_onecol( - table="group_rooms", - keyvalues={"room_id": room_id}, - retcol="group_id", - desc="get_local_groups_for_room", - ) - - async def get_users_for_summary_by_role( - self, group_id: str, include_private: bool = False - ) -> Tuple[List[JsonDict], JsonDict]: - """Get the users and roles that should be included in a summary request - - Returns: - ([users], [roles]) - """ - - def _get_users_for_summary_txn( - txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], JsonDict]: - keyvalues: JsonDict = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT user_id, is_public, role_id, user_order - FROM group_summary_users - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - users = [ - { - "user_id": row[0], - "is_public": row[1], - "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT role_id, is_public, profile, role_order - FROM group_summary_roles - INNER JOIN group_roles USING (group_id, role_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" 
- txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - roles = { - row[0]: { - "is_public": row[1], - "profile": db_to_json(row[2]), - "order": row[3], - } - for row in txn - } - - return users, roles - - return await self.db_pool.runInteraction( - "get_users_for_summary_by_role", _get_users_for_summary_txn - ) - - async def is_user_in_group(self, user_id: str, group_id: str) -> bool: - result = await self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="is_user_in_group", - ) - return bool(result) - - async def is_user_admin_in_group( - self, group_id: str, user_id: str - ) -> Optional[bool]: - return await self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="is_admin", - allow_none=True, - desc="is_user_admin_in_group", - ) - - async def is_user_invited_to_local_group( - self, group_id: str, user_id: str - ) -> Optional[bool]: - """Has the group server invited a user?""" - return await self.db_pool.simple_select_one_onecol( - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - desc="is_user_invited_to_local_group", - allow_none=True, - ) - - async def get_users_membership_info_in_group( - self, group_id: str, user_id: str - ) -> JsonDict: - """Get a dict describing the membership of a user in a group. - - Example if joined: - - { - "membership": "join", - "is_public": True, - "is_privileged": False, - } - - Returns: - An empty dict if the user is not join/invite/etc - """ - - def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict: - row = self.db_pool.simple_select_one_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("is_admin", "is_public"), - allow_none=True, - ) - - if row: - return { - "membership": "join", - "is_public": row["is_public"], - "is_privileged": row["is_admin"], - } - - row = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - - if row: - return {"membership": "invite"} - - return {} - - return await self.db_pool.runInteraction( - "get_users_membership_info_in_group", _get_users_membership_in_group_txn - ) - - async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: - """Get all groups a user is publicising""" - return await self.db_pool.simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, - retcol="group_id", - desc="get_publicised_groups_for_user", - ) - - async def get_attestations_need_renewals( - self, valid_until_ms: int - ) -> List[Dict[str, Any]]: - """Get all attestations that need to be renewed until givent time""" - - def _get_attestations_need_renewals_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, Any]]: - sql = """ - SELECT group_id, user_id FROM group_attestations_renewals - WHERE valid_until_ms <= ? - """ - txn.execute(sql, (valid_until_ms,)) - return self.db_pool.cursor_to_dict(txn) - - return await self.db_pool.runInteraction( - "get_attestations_need_renewals", _get_attestations_need_renewals_txn - ) - - async def get_remote_attestation( - self, group_id: str, user_id: str - ) -> Optional[JsonDict]: - """Get the attestation that proves the remote agrees that the user is - in the group. 
- """ - row = await self.db_pool.simple_select_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("valid_until_ms", "attestation_json"), - desc="get_remote_attestation", - allow_none=True, - ) - - now = int(self._clock.time_msec()) - if row and now < row["valid_until_ms"]: - return db_to_json(row["attestation_json"]) - - return None - - async def get_joined_groups(self, user_id: str) -> List[str]: - return await self.db_pool.simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join"}, - retcol="group_id", - desc="get_joined_groups", - ) - - async def get_all_groups_for_user( - self, user_id: str, now_token: int - ) -> List[JsonDict]: - def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: - sql = """ - SELECT group_id, type, membership, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND membership != 'leave' - AND stream_id <= ? - """ - txn.execute(sql, (user_id, now_token)) - return [ - { - "group_id": row[0], - "type": row[1], - "membership": row[2], - "content": db_to_json(row[3]), - } - for row in txn - ] - - return await self.db_pool.runInteraction( - "get_all_groups_for_user", _get_all_groups_for_user_txn - ) - - async def get_groups_changes_for_user( - self, user_id: str, from_token: int, to_token: int - ) -> List[JsonDict]: - has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined] - user_id, from_token - ) - if not has_changed: - return [] - - def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: - sql = """ - SELECT group_id, membership, type, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND ? < stream_id AND stream_id <= ? - """ - txn.execute(sql, (user_id, from_token, to_token)) - return [ - { - "group_id": group_id, - "membership": membership, - "type": gtype, - "content": db_to_json(content_json), - } - for group_id, membership, gtype, content_json in txn - ] - - return await self.db_pool.runInteraction( - "get_groups_changes_for_user", _get_groups_changes_for_user_txn - ) - - async def get_all_groups_changes( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for groups replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. 
- - The updates are a list of 2-tuples of stream ID and the row data - """ - - last_id = int(last_id) - has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined] - - if not has_changed: - return [], current_id, False - - def _get_all_groups_changes_txn( - txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - sql = """ - SELECT stream_id, group_id, user_id, type, content - FROM local_group_updates - WHERE ? < stream_id AND stream_id <= ? - LIMIT ? - """ - txn.execute(sql, (last_id, current_id, limit)) - updates = cast( - List[Tuple[int, tuple]], - [ - (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) - for stream_id, group_id, user_id, gtype, content_json in txn - ], - ) - - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db_pool.runInteraction( - "get_all_groups_changes", _get_all_groups_changes_txn - ) - - -class GroupServerStore(GroupServerWorkerStore): - async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: - """Set the join policy of a group. - - join_policy can be one of: - * "invite" - * "open" - """ - await self.db_pool.simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues={"join_policy": join_policy}, - desc="set_group_join_policy", - ) - - async def add_room_to_summary( - self, - group_id: str, - room_id: str, - category_id: Optional[str], - order: Optional[int], - is_public: Optional[bool], - ) -> None: - """Add (or update) room's entry in summary. - - Args: - group_id - room_id - category_id: If not None then adds the category to the end of - the summary if its not already there. - order: If not None inserts the room at that position, e.g. an order - of 1 will put the room first. Otherwise, the room gets added to - the end. - is_public - """ - await self.db_pool.runInteraction( - "add_room_to_summary", - self._add_room_to_summary_txn, - group_id, - room_id, - category_id, - order, - is_public, - ) - - def _add_room_to_summary_txn( - self, - txn: LoggingTransaction, - group_id: str, - room_id: str, - category_id: Optional[str], - order: Optional[int], - is_public: Optional[bool], - ) -> None: - """Add (or update) room's entry in summary. - - Args: - txn - group_id - room_id - category_id: If not None then adds the category to the end of - the summary if its not already there. - order: If not None inserts the room at that position, e.g. an order - of 1 will put the room first. Otherwise, the room gets added to - the end. 
- is_public - """ - room_in_group = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - retcol="room_id", - allow_none=True, - ) - if not room_in_group: - raise SynapseError(400, "room not in group") - - if category_id is None: - category_id = _DEFAULT_CATEGORY_ID - else: - cat_exists = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcol="group_id", - allow_none=True, - ) - if not cat_exists: - raise SynapseError(400, "Category doesn't exist") - - # TODO: Check category is part of summary already - cat_exists = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_summary_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcol="group_id", - allow_none=True, - ) - if not cat_exists: - # If not, add it with an order larger than all others - txn.execute( - """ - INSERT INTO group_summary_room_categories - (group_id, category_id, cat_order) - SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 - FROM group_summary_room_categories - WHERE group_id = ? AND category_id = ? - """, - (group_id, category_id, group_id, category_id), - ) - - existing = self.db_pool.simple_select_one_txn( - txn, - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "room_id": room_id, - "category_id": category_id, - }, - retcols=("room_order", "is_public"), - allow_none=True, - ) - - if order is not None: - # Shuffle other room orders that come after the given order - sql = """ - UPDATE group_summary_rooms SET room_order = room_order + 1 - WHERE group_id = ? AND category_id = ? AND room_order >= ? - """ - txn.execute(sql, (group_id, category_id, order)) - elif not existing: - sql = """ - SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms - WHERE group_id = ? AND category_id = ? 
- """ - txn.execute(sql, (group_id, category_id)) - (order,) = cast(Tuple[int], txn.fetchone()) - - if existing: - to_update = {} - if order is not None: - to_update["room_order"] = order - if is_public is not None: - to_update["is_public"] = is_public - self.db_pool.simple_update_txn( - txn, - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - }, - updatevalues=to_update, - ) - else: - if is_public is None: - is_public = True - - self.db_pool.simple_insert_txn( - txn, - table="group_summary_rooms", - values={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - "room_order": order, - "is_public": is_public, - }, - ) - - async def remove_room_from_summary( - self, group_id: str, room_id: str, category_id: Optional[str] - ) -> int: - if category_id is None: - category_id = _DEFAULT_CATEGORY_ID - - return await self.db_pool.simple_delete( - table="group_summary_rooms", - keyvalues={ - "group_id": group_id, - "category_id": category_id, - "room_id": room_id, - }, - desc="remove_room_from_summary", - ) - - async def upsert_group_category( - self, - group_id: str, - category_id: str, - profile: Optional[JsonDict], - is_public: Optional[bool], - ) -> None: - """Add/update room category for group""" - insertion_values: JsonDict = {} - update_values: JsonDict = {"category_id": category_id} # This cannot be empty - - if profile is None: - insertion_values["profile"] = "{}" - else: - update_values["profile"] = json_encoder.encode(profile) - - if is_public is None: - insertion_values["is_public"] = True - else: - update_values["is_public"] = is_public - - await self.db_pool.simple_upsert( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - values=update_values, - insertion_values=insertion_values, - desc="upsert_group_category", - ) - - async def remove_group_category(self, group_id: str, category_id: str) -> int: - return await self.db_pool.simple_delete( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - desc="remove_group_category", - ) - - async def upsert_group_role( - self, - group_id: str, - role_id: str, - profile: Optional[JsonDict], - is_public: Optional[bool], - ) -> None: - """Add/remove user role""" - insertion_values: JsonDict = {} - update_values: JsonDict = {"role_id": role_id} # This cannot be empty - - if profile is None: - insertion_values["profile"] = "{}" - else: - update_values["profile"] = json_encoder.encode(profile) - - if is_public is None: - insertion_values["is_public"] = True - else: - update_values["is_public"] = is_public - - await self.db_pool.simple_upsert( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - values=update_values, - insertion_values=insertion_values, - desc="upsert_group_role", - ) - - async def remove_group_role(self, group_id: str, role_id: str) -> int: - return await self.db_pool.simple_delete( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - desc="remove_group_role", - ) - - async def add_user_to_summary( - self, - group_id: str, - user_id: str, - role_id: Optional[str], - order: Optional[int], - is_public: Optional[bool], - ) -> None: - """Add (or update) user's entry in summary. - - Args: - group_id - user_id - role_id: If not None then adds the role to the end of the summary if - its not already there. - order: If not None inserts the user at that position, e.g. 
an order - of 1 will put the user first. Otherwise, the user gets added to - the end. - is_public - """ - await self.db_pool.runInteraction( - "add_user_to_summary", - self._add_user_to_summary_txn, - group_id, - user_id, - role_id, - order, - is_public, - ) - - def _add_user_to_summary_txn( - self, - txn: LoggingTransaction, - group_id: str, - user_id: str, - role_id: Optional[str], - order: Optional[int], - is_public: Optional[bool], - ) -> None: - """Add (or update) user's entry in summary. - - Args: - txn - group_id - user_id - role_id: If not None then adds the role to the end of the summary if - its not already there. - order: If not None inserts the user at that position, e.g. an order - of 1 will put the user first. Otherwise, the user gets added to - the end. - is_public - """ - user_in_group = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - if not user_in_group: - raise SynapseError(400, "user not in group") - - if role_id is None: - role_id = _DEFAULT_ROLE_ID - else: - role_exists = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcol="group_id", - allow_none=True, - ) - if not role_exists: - raise SynapseError(400, "Role doesn't exist") - - # TODO: Check role is part of the summary already - role_exists = self.db_pool.simple_select_one_onecol_txn( - txn, - table="group_summary_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcol="group_id", - allow_none=True, - ) - if not role_exists: - # If not, add it with an order larger than all others - txn.execute( - """ - INSERT INTO group_summary_roles - (group_id, role_id, role_order) - SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 - FROM group_summary_roles - WHERE group_id = ? AND role_id = ? - """, - (group_id, role_id, group_id, role_id), - ) - - existing = self.db_pool.simple_select_one_txn( - txn, - table="group_summary_users", - keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, - retcols=("user_order", "is_public"), - allow_none=True, - ) - - if order is not None: - # Shuffle other users orders that come after the given order - sql = """ - UPDATE group_summary_users SET user_order = user_order + 1 - WHERE group_id = ? AND role_id = ? AND user_order >= ? - """ - txn.execute(sql, (group_id, role_id, order)) - elif not existing: - sql = """ - SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users - WHERE group_id = ? AND role_id = ? 
- """ - txn.execute(sql, (group_id, role_id)) - (order,) = cast(Tuple[int], txn.fetchone()) - - if existing: - to_update = {} - if order is not None: - to_update["user_order"] = order - if is_public is not None: - to_update["is_public"] = is_public - self.db_pool.simple_update_txn( - txn, - table="group_summary_users", - keyvalues={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - }, - updatevalues=to_update, - ) - else: - if is_public is None: - is_public = True - - self.db_pool.simple_insert_txn( - txn, - table="group_summary_users", - values={ - "group_id": group_id, - "role_id": role_id, - "user_id": user_id, - "user_order": order, - "is_public": is_public, - }, - ) - - async def remove_user_from_summary( - self, group_id: str, user_id: str, role_id: Optional[str] - ) -> int: - if role_id is None: - role_id = _DEFAULT_ROLE_ID - - return await self.db_pool.simple_delete( - table="group_summary_users", - keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, - desc="remove_user_from_summary", - ) - - async def add_group_invite(self, group_id: str, user_id: str) -> None: - """Record that the group server has invited a user""" - await self.db_pool.simple_insert( - table="group_invites", - values={"group_id": group_id, "user_id": user_id}, - desc="add_group_invite", - ) - - async def add_user_to_group( - self, - group_id: str, - user_id: str, - is_admin: bool = False, - is_public: bool = True, - local_attestation: Optional[dict] = None, - remote_attestation: Optional[dict] = None, - ) -> None: - """Add a user to the group server. - - Args: - group_id - user_id - is_admin - is_public - local_attestation: The attestation the GS created to give to the remote - server. Optional if the user and group are on the same server - remote_attestation: The attestation given to GS by remote server. 
- Optional if the user and group are on the same server - """ - - def _add_user_to_group_txn(txn: LoggingTransaction) -> None: - self.db_pool.simple_insert_txn( - txn, - table="group_users", - values={ - "group_id": group_id, - "user_id": user_id, - "is_admin": is_admin, - "is_public": is_public, - }, - ) - - self.db_pool.simple_delete_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - if local_attestation: - self.db_pool.simple_insert_txn( - txn, - table="group_attestations_renewals", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": local_attestation["valid_until_ms"], - }, - ) - if remote_attestation: - self.db_pool.simple_insert_txn( - txn, - table="group_attestations_remote", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json_encoder.encode(remote_attestation), - }, - ) - - await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) - - async def remove_user_from_group(self, group_id: str, user_id: str) -> None: - def _remove_user_from_group_txn(txn: LoggingTransaction) -> None: - self.db_pool.simple_delete_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_summary_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - await self.db_pool.runInteraction( - "remove_user_from_group", _remove_user_from_group_txn - ) - - async def add_room_to_group( - self, group_id: str, room_id: str, is_public: bool - ) -> None: - await self.db_pool.simple_insert( - table="group_rooms", - values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, - desc="add_room_to_group", - ) - - async def update_room_in_group_visibility( - self, group_id: str, room_id: str, is_public: bool - ) -> int: - return await self.db_pool.simple_update( - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - updatevalues={"is_public": is_public}, - desc="update_room_in_group_visibility", - ) - - async def remove_room_from_group(self, group_id: str, room_id: str) -> None: - def _remove_room_from_group_txn(txn: LoggingTransaction) -> None: - self.db_pool.simple_delete_txn( - txn, - table="group_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - ) - - self.db_pool.simple_delete_txn( - txn, - table="group_summary_rooms", - keyvalues={"group_id": group_id, "room_id": room_id}, - ) - - await self.db_pool.runInteraction( - "remove_room_from_group", _remove_room_from_group_txn - ) - - async def update_group_publicity( - self, group_id: str, user_id: str, publicise: bool - ) -> None: - """Update whether the user is publicising their membership of the group""" - await self.db_pool.simple_update_one( - table="local_group_membership", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={"is_publicised": publicise}, - desc="update_group_publicity", - ) - - async def register_user_group_membership( - self, - group_id: str, - user_id: str, - membership: str, - is_admin: bool 
= False, - content: Optional[JsonDict] = None, - local_attestation: Optional[dict] = None, - remote_attestation: Optional[dict] = None, - is_publicised: bool = False, - ) -> int: - """Registers that a local user is a member of a (local or remote) group. - - Args: - group_id: The group the member is being added to. - user_id: THe user ID to add to the group. - membership: The type of group membership. - is_admin: Whether the user should be added as a group admin. - content: Content of the membership, e.g. includes the inviter - if the user has been invited. - local_attestation: If remote group then store the fact that we - have given out an attestation, else None. - remote_attestation: If remote group then store the remote - attestation from the group, else None. - is_publicised: Whether this should be publicised. - """ - - content = content or {} - - def _register_user_group_membership_txn( - txn: LoggingTransaction, next_id: int - ) -> int: - # TODO: Upsert? - self.db_pool.simple_delete_txn( - txn, - table="local_group_membership", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_insert_txn( - txn, - table="local_group_membership", - values={ - "group_id": group_id, - "user_id": user_id, - "is_admin": is_admin, - "membership": membership, - "is_publicised": is_publicised, - "content": json_encoder.encode(content), - }, - ) - - self.db_pool.simple_insert_txn( - txn, - table="local_group_updates", - values={ - "stream_id": next_id, - "group_id": group_id, - "user_id": user_id, - "type": "membership", - "content": json_encoder.encode( - {"membership": membership, "content": content} - ), - }, - ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined] - - # TODO: Insert profile to ensure it comes down stream if its a join. 
- - if membership == "join": - if local_attestation: - self.db_pool.simple_insert_txn( - txn, - table="group_attestations_renewals", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": local_attestation["valid_until_ms"], - }, - ) - if remote_attestation: - self.db_pool.simple_insert_txn( - txn, - table="group_attestations_remote", - values={ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json_encoder.encode(remote_attestation), - }, - ) - else: - self.db_pool.simple_delete_txn( - txn, - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - ) - - return next_id - - async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined] - res = await self.db_pool.runInteraction( - "register_user_group_membership", - _register_user_group_membership_txn, - next_id, - ) - return res - - async def create_group( - self, - group_id: str, - user_id: str, - name: str, - avatar_url: str, - short_description: str, - long_description: str, - ) -> None: - await self.db_pool.simple_insert( - table="groups", - values={ - "group_id": group_id, - "name": name, - "avatar_url": avatar_url, - "short_description": short_description, - "long_description": long_description, - "is_public": True, - }, - desc="create_group", - ) - - async def update_group_profile(self, group_id: str, profile: JsonDict) -> None: - await self.db_pool.simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues=profile, - desc="update_group_profile", - ) - - async def update_attestation_renewal( - self, group_id: str, user_id: str, attestation: dict - ) -> None: - """Update an attestation that we have renewed""" - await self.db_pool.simple_update_one( - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, - desc="update_attestation_renewal", - ) - - async def update_remote_attestion( - self, group_id: str, user_id: str, attestation: dict - ) -> None: - """Update an attestation that a remote has renewed""" - await self.db_pool.simple_update_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - updatevalues={ - "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json_encoder.encode(attestation), - }, - desc="update_remote_attestion", - ) - - async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: - """Remove an attestation that we thought we should renew, but actually - shouldn't. Ideally this would never get called as we would never - incorrectly try and do attestations for local users on local groups. - - Args: - group_id - user_id - """ - return await self.db_pool.simple_delete( - table="group_attestations_renewals", - keyvalues={"group_id": group_id, "user_id": user_id}, - desc="remove_attestation_renewal", - ) - - def get_group_stream_token(self) -> int: - return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined] - - async def delete_group(self, group_id: str) -> None: - """Deletes a group fully from the database. - - Args: - group_id: The group ID to delete. 
- """ - - def _delete_group_txn(txn: LoggingTransaction) -> None: - tables = [ - "groups", - "group_users", - "group_invites", - "group_rooms", - "group_summary_rooms", - "group_summary_room_categories", - "group_room_categories", - "group_summary_users", - "group_summary_roles", - "group_roles", - "group_attestations_renewals", - "group_attestations_remote", - ] - - for table in tables: - self.db_pool.simple_delete_txn( - txn, table=table, keyvalues={"group_id": group_id} - ) - - await self.db_pool.runInteraction("delete_group", _delete_group_txn) From 5949ab86f8db0ef3dac2063e42210030f17786fb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Jun 2022 11:57:49 +0100 Subject: [PATCH 63/74] Fix potential thumbnail memory leaks. (#12932) --- changelog.d/12932.bugfix | 1 + synapse/rest/media/v1/media_repository.py | 269 ++++++++++++---------- synapse/rest/media/v1/thumbnailer.py | 71 +++++- 3 files changed, 204 insertions(+), 137 deletions(-) create mode 100644 changelog.d/12932.bugfix diff --git a/changelog.d/12932.bugfix b/changelog.d/12932.bugfix new file mode 100644 index 0000000000..506f92b427 --- /dev/null +++ b/changelog.d/12932.bugfix @@ -0,0 +1 @@ +Fix potential memory leak when generating thumbnails. diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 20af366538..a551458a9f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -587,15 +587,16 @@ class MediaRepository: ) return None - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -657,15 +658,16 @@ class MediaRepository: ) return None - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -749,119 +751,134 @@ class MediaRepository: ) return None - m_width = thumbnailer.width - m_height = thumbnailer.height + with thumbnailer: + m_width = thumbnailer.width + m_height = thumbnailer.height - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, - ) - return None - - if thumbnailer.transpose_method is not None: - m_width, m_height = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.transpose - ) - - # We deduplicate the thumbnail sizes by ignoring the cropped versions if - # they have the same dimensions of a scaled one. 
- thumbnails: Dict[Tuple[int, int, str], str] = {} - for requirement in requirements: - if requirement.method == "crop": - thumbnails.setdefault( - (requirement.width, requirement.height, requirement.media_type), - requirement.method, + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, ) - elif requirement.method == "scale": - t_width, t_height = thumbnailer.aspect( - requirement.width, requirement.height + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = await defer_to_thread( + self.hs.get_reactor(), thumbnailer.transpose ) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - thumbnails[ - (t_width, t_height, requirement.media_type) - ] = requirement.method - # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in thumbnails.items(): - # Generate the thumbnail - if t_method == "crop": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type - ) - elif t_method == "scale": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type - ) - else: - logger.error("Unrecognized method: %r", t_method) - continue - - if not t_byte_source: - continue - - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), - ) - - with self.media_storage.store_into_file(file_info) as (f, fname, finish): - try: - await self.media_storage.write_to_file(t_byte_source, f) - await finish() - finally: - t_byte_source.close() - - t_len = os.path.getsize(fname) - - # Write to database - if server_name: - # Multiple remote media download requests can race (when - # using multiple media repos), so this may throw a violation - # constraint exception. If it does we'll delete the newly - # generated thumbnail from disk (as we're in the ctx - # manager). - # - # However: we've already called `finish()` so we may have - # also written to the storage providers. This is preferable - # to the alternative where we call `finish()` *after* this, - # where we could end up having an entry in the DB but fail - # to write the files to the storage providers. - try: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - except Exception as e: - thumbnail_exists = await self.store.get_remote_media_thumbnail( - server_name, - media_id, - t_width, - t_height, - t_type, - ) - if not thumbnail_exists: - raise e - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. 
+ thumbnails: Dict[Tuple[int, int, str], str] = {} + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.items(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.crop, + t_width, + t_height, + t_type, + ) + elif t_method == "scale": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.scale, + t_width, + t_height, + t_type, + ) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + with self.media_storage.store_into_file(file_info) as ( + f, + fname, + finish, + ): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = ( + await self.store.get_remote_media_thumbnail( + server_name, + media_id, + t_width, + t_height, + t_type, + ) + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) return {"width": m_width, "height": m_height} diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 390491eb83..9b93b9b4f6 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -14,7 +14,8 @@ # limitations under the License. import logging from io import BytesIO -from typing import Tuple +from types import TracebackType +from typing import Optional, Tuple, Type from PIL import Image @@ -45,6 +46,9 @@ class Thumbnailer: Image.MAX_IMAGE_PIXELS = max_image_pixels def __init__(self, input_path: str): + # Have we closed the image? + self._closed = False + try: self.image = Image.open(input_path) except OSError as e: @@ -89,7 +93,8 @@ class Thumbnailer: # Safety: `transpose` takes an int rather than e.g. an IntEnum. # self.transpose_method is set above to be a value in # EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values. 
- self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type] + with self.image: + self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type] self.width, self.height = self.image.size self.transpose_method = None # We don't need EXIF any more @@ -122,9 +127,11 @@ class Thumbnailer: # If the image has transparency, use RGBA instead. if self.image.mode in ["1", "L", "P"]: if self.image.info.get("transparency", None) is not None: - self.image = self.image.convert("RGBA") + with self.image: + self.image = self.image.convert("RGBA") else: - self.image = self.image.convert("RGB") + with self.image: + self.image = self.image.convert("RGB") return self.image.resize((width, height), Image.ANTIALIAS) def scale(self, width: int, height: int, output_type: str) -> BytesIO: @@ -133,8 +140,8 @@ class Thumbnailer: Returns: BytesIO: the bytes of the encoded image ready to be written to disk """ - scaled = self._resize(width, height) - return self._encode_image(scaled, output_type) + with self._resize(width, height) as scaled: + return self._encode_image(scaled, output_type) def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving @@ -151,18 +158,21 @@ class Thumbnailer: BytesIO: the bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: + scaled_width = width scaled_height = (width * self.height) // self.width - scaled_image = self._resize(width, scaled_height) crop_top = (scaled_height - height) // 2 crop_bottom = height + crop_top - cropped = scaled_image.crop((0, crop_top, width, crop_bottom)) + crop = (0, crop_top, width, crop_bottom) else: scaled_width = (height * self.width) // self.height - scaled_image = self._resize(scaled_width, height) + scaled_height = height crop_left = (scaled_width - width) // 2 crop_right = width + crop_left - cropped = scaled_image.crop((crop_left, 0, crop_right, height)) - return self._encode_image(cropped, output_type) + crop = (crop_left, 0, crop_right, height) + + with self._resize(scaled_width, scaled_height) as scaled_image: + with scaled_image.crop(crop) as cropped: + return self._encode_image(cropped, output_type) def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() @@ -171,3 +181,42 @@ class Thumbnailer: output_image = output_image.convert("RGB") output_image.save(output_bytes_io, fmt, quality=80) return output_bytes_io + + def close(self) -> None: + """Closes the underlying image file. + + Once closed no other functions can be called. + + Can be called multiple times. + """ + + if self._closed: + return + + self._closed = True + + # Since we run this on the finalizer then we need to handle `__init__` + # raising an exception before it can define `self.image`. + image = getattr(self, "image", None) + if image is None: + return + + image.close() + + def __enter__(self) -> "Thumbnailer": + """Make `Thumbnailer` a context manager that calls `close` on + `__exit__`. + """ + return self + + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() + + def __del__(self) -> None: + # Make sure we actually do close the image, rather than leak data. 
+ self.close() From 79dadf7216836170af2ac5ef130bfc012b86821c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 1 Jun 2022 12:29:51 +0100 Subject: [PATCH 64/74] Fix 404 on `/sync` when the last event is a redaction of an unknown/purged event (#12905) Currently, we try to pull the event corresponding to a sync token from the database. However, when we fetch redaction events, we check the target of that redaction (because we aren't allowed to send redactions to clients without validating them). So, if the sync token points to a redaction of an event that we don't have, we have a problem. It turns out we don't really need that event, and can just work with its ID and metadata, which sidesteps the whole problem. --- changelog.d/12905.bugfix | 1 + synapse/handlers/message.py | 114 +++++++++++++++-------- synapse/handlers/sync.py | 27 ++++-- synapse/storage/databases/main/state.py | 12 ++- synapse/storage/databases/main/stream.py | 12 +-- synapse/visibility.py | 28 ++++-- 6 files changed, 129 insertions(+), 65 deletions(-) create mode 100644 changelog.d/12905.bugfix diff --git a/changelog.d/12905.bugfix b/changelog.d/12905.bugfix new file mode 100644 index 0000000000..67e95d0398 --- /dev/null +++ b/changelog.d/12905.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.58.0 where `/sync` would fail if the most recent event in a room was a redaction of an event that has since been purged. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index cf7c2d1979..ac911a2ddc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -28,6 +28,7 @@ from synapse.api.constants import ( EventContentFields, EventTypes, GuestAccess, + HistoryVisibility, Membership, RelationTypes, UserTypes, @@ -66,7 +67,7 @@ from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstErr from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func -from synapse.visibility import filter_events_for_client +from synapse.visibility import get_effective_room_visibility_from_state if TYPE_CHECKING: from synapse.events.third_party_rules import ThirdPartyEventRules @@ -182,51 +183,31 @@ class MessageHandler: state_filter = state_filter or StateFilter.all() if at_token: - last_event = await self.store.get_last_event_in_room_before_stream_ordering( - room_id, - end_token=at_token.room_key, + last_event_id = ( + await self.store.get_last_event_in_room_before_stream_ordering( + room_id, + end_token=at_token.room_key, + ) ) - if not last_event: + if not last_event_id: raise NotFoundError("Can't find event for token %s" % (at_token,)) - # check whether the user is in the room at that time to determine - # whether they should be treated as peeking. 
- state_map = await self._state_storage_controller.get_state_for_event( - last_event.event_id, - StateFilter.from_types([(EventTypes.Member, user_id)]), - ) - - joined = False - membership_event = state_map.get((EventTypes.Member, user_id)) - if membership_event: - joined = membership_event.membership == Membership.JOIN - - is_peeking = not joined - - visible_events = await filter_events_for_client( - self._storage_controllers, - user_id, - [last_event], - filter_send_to_client=False, - is_peeking=is_peeking, - ) - - if visible_events: - room_state_events = ( - await self._state_storage_controller.get_state_for_events( - [last_event.event_id], state_filter=state_filter - ) - ) - room_state: Mapping[Any, EventBase] = room_state_events[ - last_event.event_id - ] - else: + if not await self._user_can_see_state_at_event( + user_id, room_id, last_event_id + ): raise AuthError( 403, "User %s not allowed to view events in room %s at token %s" % (user_id, room_id, at_token), ) + + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [last_event_id], state_filter=state_filter + ) + ) + room_state: Mapping[Any, EventBase] = room_state_events[last_event_id] else: ( membership, @@ -256,6 +237,65 @@ class MessageHandler: events = self._event_serializer.serialize_events(room_state.values(), now) return events + async def _user_can_see_state_at_event( + self, user_id: str, room_id: str, event_id: str + ) -> bool: + # check whether the user was in the room, and the history visibility, + # at that time. + state_map = await self._state_storage_controller.get_state_for_event( + event_id, + StateFilter.from_types( + [ + (EventTypes.Member, user_id), + (EventTypes.RoomHistoryVisibility, ""), + ] + ), + ) + + membership = None + membership_event = state_map.get((EventTypes.Member, user_id)) + if membership_event: + membership = membership_event.membership + + # if the user was a member of the room at the time of the event, + # they can see it. + if membership == Membership.JOIN: + return True + + # otherwise, it depends on the history visibility. + visibility = get_effective_room_visibility_from_state(state_map) + + if visibility == HistoryVisibility.JOINED: + # we weren't a member at the time of the event, so we can't see this event. + return False + + # otherwise *invited* is good enough + if membership == Membership.INVITE: + return True + + if visibility == HistoryVisibility.INVITED: + # we weren't invited, so we can't see this event. + return False + + if visibility == HistoryVisibility.WORLD_READABLE: + return True + + # So it's SHARED, and the user was not a member at the time. The user cannot + # see history, unless they have *subsequently* joined the room. + # + # XXX: if the user has subsequently joined and then left again, + # ideally we would share history up to the point they left. But + # we don't know when they left. We just treat it as though they + # never joined, and restrict access. + + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + user_id, event_id + ) + return current_membership == Membership.JOIN + async def get_joined_members(self, requester: Requester, room_id: str) -> dict: """Get all the joined members in the room and their profile information. 
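
The `message.py` hunk above distills the permission check into `_user_can_see_state_at_event`, a cascade of early returns over the user's membership at the time of the event and the room's history visibility. A standalone sketch of that decision ladder, using plain strings in place of Synapse's `Membership` and `HistoryVisibility` constants and a boolean in place of the final current-membership lookup, may make the precedence easier to follow:

```python
def can_see_state_at_event(
    membership_at_event: str,   # "join", "invite", or "" if not in the room
    history_visibility: str,    # "joined", "invited", "shared", "world_readable"
    subsequently_joined: bool,  # is the user a member of the room *now*?
) -> bool:
    """Sketch of the visibility ladder in _user_can_see_state_at_event."""
    # Members at the time of the event can always see it.
    if membership_at_event == "join":
        return True
    # "joined" visibility demands membership at the time, which we lack.
    if history_visibility == "joined":
        return False
    # Below "joined", an invite at the time of the event is good enough.
    if membership_at_event == "invite":
        return True
    if history_visibility == "invited":
        return False
    if history_visibility == "world_readable":
        return True
    # "shared": history is visible to anyone who has since joined, even
    # though (as the XXX comment in the patch notes) we cannot tell
    # whether they later left again.
    return subsequently_joined
```

Each branch maps one-to-one onto the patch; only the names are simplified.
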
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b5859dcb28..a1d41358d9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -621,21 +621,32 @@ class SyncHandler: ) async def get_state_after_event( - self, event: EventBase, state_filter: Optional[StateFilter] = None + self, event_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """ Get the room state after the given event Args: - event: event of interest + event_id: event of interest state_filter: The state filter used to fetch state from the database. """ state_ids = await self._state_storage_controller.get_state_ids_for_event( - event.event_id, state_filter=state_filter or StateFilter.all() + event_id, state_filter=state_filter or StateFilter.all() ) - if event.is_state(): + + # using get_metadata_for_events here (instead of get_event) sidesteps an issue + # with redactions: if `event_id` is a redaction event, and we don't have the + # original (possibly because it got purged), get_event will refuse to return + # the redaction event, which isn't terribly helpful here. + # + # (To be fair, in that case we could assume it's *not* a state event, and + # therefore we don't need to worry about it. But still, it seems cleaner just + # to pull the metadata.) + m = (await self.store.get_metadata_for_events([event_id]))[event_id] + if m.state_key is not None and m.rejection_reason is None: state_ids = dict(state_ids) - state_ids[(event.type, event.state_key)] = event.event_id + state_ids[(m.event_type, m.state_key)] = event_id + return state_ids async def get_state_at( @@ -654,14 +665,14 @@ class SyncHandler: # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time # of the stream token if there were multiple forward extremities at the time. 
- last_event = await self.store.get_last_event_in_room_before_stream_ordering( + last_event_id = await self.store.get_last_event_in_room_before_stream_ordering( room_id, end_token=stream_position.room_key, ) - if last_event: + if last_event_id: state = await self.get_state_after_event( - last_event, state_filter=state_filter or StateFilter.all() + last_event_id, state_filter=state_filter or StateFilter.all() ) else: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a07ad85582..3f2be3854b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -54,6 +54,7 @@ class EventMetadata: room_id: str event_type: str state_key: Optional[str] + rejection_reason: Optional[str] def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: @@ -167,17 +168,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) sql = f""" - SELECT e.event_id, e.room_id, e.type, se.state_key FROM events AS e + SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason + FROM events AS e LEFT JOIN state_events se USING (event_id) + LEFT JOIN rejections r USING (event_id) WHERE {clause} """ txn.execute(sql, args) return { event_id: EventMetadata( - room_id=room_id, event_type=event_type, state_key=state_key + room_id=room_id, + event_type=event_type, + state_key=state_key, + rejection_reason=rejection_reason, ) - for event_id, room_id, event_type, state_key in txn + for event_id, room_id, event_type, state_key, rejection_reason in txn } result_map: Dict[str, EventMetadata] = {} diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 0e3a23a140..8e88784d3c 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -765,15 +765,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self, room_id: str, end_token: RoomStreamToken, - ) -> Optional[EventBase]: - """Returns the last event in a room at or before a stream ordering + ) -> Optional[str]: + """Returns the ID of the last event in a room at or before a stream ordering Args: room_id end_token: The token used to stream from Returns: - The most recent event. + The ID of the most recent event, or None if there are no events in the room + before this stream ordering. """ last_row = await self.get_room_event_before_stream_ordering( @@ -781,10 +782,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): stream_ordering=end_token.stream, ) if last_row: - _, _, event_id = last_row - event = await self.get_event(event_id, get_prev_content=True) - return event - + return last_row[2] return None async def get_current_room_stream_token_for_room_id( diff --git a/synapse/visibility.py b/synapse/visibility.py index 97548c14e3..8aaa8c709f 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -162,16 +162,7 @@ async def filter_events_for_client( state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. - visibility_event = state.get(_HISTORY_VIS_KEY, None) - if visibility_event: - visibility = visibility_event.content.get( - "history_visibility", HistoryVisibility.SHARED - ) - else: - visibility = HistoryVisibility.SHARED - - if visibility not in VISIBILITY_PRIORITY: - visibility = HistoryVisibility.SHARED + visibility = get_effective_room_visibility_from_state(state) # Always allow history visibility events on boundaries. 
This is done # by setting the effective visibility to the least restrictive @@ -267,6 +258,23 @@ async def filter_events_for_client( return [ev for ev in filtered_events if ev] +def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str: + """Get the actual history vis, from a state map including the history_visibility event + + Handles missing and invalid history visibility events. + """ + visibility_event = state.get(_HISTORY_VIS_KEY, None) + if not visibility_event: + return HistoryVisibility.SHARED + + visibility = visibility_event.content.get( + "history_visibility", HistoryVisibility.SHARED + ) + if visibility not in VISIBILITY_PRIORITY: + visibility = HistoryVisibility.SHARED + return visibility + + async def filter_events_for_server( storage: StorageControllers, server_name: str, From 88193f2125ad2e1dc1c83d6876757cc5eb3c467d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacek=20Ku=C5=9Bnierz?= Date: Wed, 1 Jun 2022 13:32:35 +0200 Subject: [PATCH 65/74] Remove direct refeferences to PyNaCl (use signedjson instead). (#12902) --- changelog.d/12902.misc | 1 + contrib/cmdclient/console.py | 9 ++++----- poetry.lock | 2 +- pyproject.toml | 1 - tests/crypto/test_event_signing.py | 17 +++++------------ tests/crypto/test_keyring.py | 2 +- 6 files changed, 12 insertions(+), 20 deletions(-) create mode 100644 changelog.d/12902.misc diff --git a/changelog.d/12902.misc b/changelog.d/12902.misc new file mode 100644 index 0000000000..3ee8f92552 --- /dev/null +++ b/changelog.d/12902.misc @@ -0,0 +1 @@ +Remove PyNaCl occurrences directly used in Synapse code. \ No newline at end of file diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 856dd437db..895b2a7af1 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -16,6 +16,7 @@ """ Starts a synapse client console. """ import argparse +import binascii import cmd import getpass import json @@ -26,9 +27,8 @@ import urllib from http import TwistedHttpClient from typing import Optional -import nacl.encoding -import nacl.signing import urlparse +from signedjson.key import NACL_ED25519, decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from twisted.internet import defer, reactor, threads @@ -41,7 +41,6 @@ TRUSTED_ID_SERVERS = ["localhost:8001"] class SynapseCmd(cmd.Cmd): - """Basic synapse command-line processor. This processes commands from the user and calls the relevant HTTP methods. @@ -420,8 +419,8 @@ class SynapseCmd(cmd.Cmd): pubKey = None pubKeyObj = yield self.http_client.do_request("GET", url) if "public_key" in pubKeyObj: - pubKey = nacl.signing.VerifyKey( - pubKeyObj["public_key"], encoder=nacl.encoding.HexEncoder + pubKey = decode_verify_key_bytes( + NACL_ED25519, binascii.unhexlify(pubKeyObj["public_key"]) ) else: print("No public key found in pubkey response!") diff --git a/poetry.lock b/poetry.lock index 6b4686545b..7c561e3182 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1563,7 +1563,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1" +content-hash = "539e5326f401472d1ffc8325d53d72e544cd70156b3f43f32f1285c4c131f831" [metadata.files] attrs = [ diff --git a/pyproject.toml b/pyproject.toml index 75251c863d..ec6e81f254 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,6 @@ unpaddedbase64 = ">=2.1.0" canonicaljson = ">=1.4.0" # we use the type definitions added in signedjson 1.1. 
signedjson = ">=1.1.0" -PyNaCl = ">=1.2.1" # validating SSL certs for IP addresses requires service_identity 18.1. service-identity = ">=18.1.0" # Twisted 18.9 introduces some logger improvements that the structured diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 06e0545a4f..8fa710c9dc 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import nacl.signing -import signedjson.types -from unpaddedbase64 import decode_base64 +from signedjson.key import decode_signing_key_base64 +from signedjson.types import SigningKey from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import add_hashes_and_signatures @@ -25,7 +23,7 @@ from tests import unittest # Perform these tests using given secret key so we get entirely deterministic # signatures output that we can test against. -SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1") +SIGNING_KEY_SEED = "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1" KEY_ALG = "ed25519" KEY_VER = "1" @@ -36,14 +34,9 @@ HOSTNAME = "domain" class EventSigningTestCase(unittest.TestCase): def setUp(self): - # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been - # monkeypatched to include new `alg` and `version` attributes. This is captured - # by the `signedjson.types.SigningKey` protocol. - self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment] - SIGNING_KEY_SEED + self.signing_key: SigningKey = decode_signing_key_base64( + KEY_ALG, KEY_VER, SIGNING_KEY_SEED ) - self.signing_key.alg = KEY_ALG - self.signing_key.version = KEY_VER def test_sign_minimal(self): event_dict = { diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d00ef24ca8..820a1a54e2 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,8 +19,8 @@ import attr import canonicaljson import signedjson.key import signedjson.sign -from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key +from signedjson.types import SigningKey from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred From 7bc08f320147a1d80371eb13258328c88073fad0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Jun 2022 09:41:25 -0400 Subject: [PATCH 66/74] Remove remaining bits of groups code. (#12936) * Update worker docs to remove group endpoints. * Removes an unused parameter to `ApplicationService`. * Break dependency between media repo and groups. * Avoid copying `m.room.related_groups` state events during room upgrades. 
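
Most of the churn in this patch is mechanical: `ApplicationService.__init__` loses its `hostname` parameter (which, per the comment removed from `tests/appservice/test_appservice.py`, was only consumed by the groups lookup), so every construction site simply drops one argument. A minimal sketch of a post-change construction, with purely illustrative values, looks like this:

```python
from synapse.appservice import ApplicationService

# Illustrative values only; the keyword arguments mirror the updated
# constructor as exercised by the test call sites in this patch.
appservice = ApplicationService(
    token="some_as_token",
    id="1234",
    sender="@as_bot:example.com",
    url="https://as.example.com",  # optional; defaults to None
    hs_token="some_hs_token",
    namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
    # No hostname= argument any more: the parameter is gone from the
    # signature, so stale call sites fail loudly with a TypeError.
)
```
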
--- changelog.d/12936.removal | 1 + docs/workers.md | 6 ------ synapse/api/constants.py | 1 - synapse/appservice/__init__.py | 2 -- synapse/config/appservice.py | 1 - synapse/handlers/room.py | 1 - synapse/storage/databases/main/media_repository.py | 4 ---- tests/api/test_auth.py | 2 -- tests/api/test_ratelimiting.py | 2 -- tests/appservice/test_api.py | 1 - tests/appservice/test_appservice.py | 1 - tests/handlers/test_appservice.py | 3 --- tests/handlers/test_user_directory.py | 1 - tests/rest/client/test_account.py | 1 - tests/rest/client/test_login.py | 2 -- tests/rest/client/test_register.py | 2 -- tests/rest/client/test_room_batch.py | 1 - tests/storage/test_user_directory.py | 1 - tests/test_mau.py | 3 --- 19 files changed, 1 insertion(+), 35 deletions(-) create mode 100644 changelog.d/12936.removal diff --git a/changelog.d/12936.removal b/changelog.d/12936.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12936.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/docs/workers.md b/docs/workers.md index 78973a498c..6969c424d8 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -191,7 +191,6 @@ information. ^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ - ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query ^/_matrix/federation/v1/hierarchy/ @@ -213,9 +212,6 @@ information. ^/_matrix/client/(r0|v3|unstable)/devices$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ - ^/_matrix/client/(r0|v3|unstable)/joined_groups$ - ^/_matrix/client/(r0|v3|unstable)/publicised_groups$ - ^/_matrix/client/(r0|v3|unstable)/publicised_groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ @@ -255,9 +251,7 @@ information. Additionally, the following REST endpoints can be handled for GET requests: - ^/_matrix/federation/v1/groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/ - ^/_matrix/client/(r0|v3|unstable)/groups/ Pagination requests can also be handled, but all requests for a given room must be routed to the same instance. 
Additionally, care must be taken to diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f03fdd6dae..e1d31cabed 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -95,7 +95,6 @@ class EventTypes: Aliases: Final = "m.room.aliases" Redaction: Final = "m.room.redaction" ThirdPartyInvite: Final = "m.room.third_party_invite" - RelatedGroups: Final = "m.room.related_groups" RoomHistoryVisibility: Final = "m.room.history_visibility" CanonicalAlias: Final = "m.room.canonical_alias" diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index ed92c2e910..0dfa00df44 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -70,7 +70,6 @@ class ApplicationService: def __init__( self, token: str, - hostname: str, id: str, sender: str, url: Optional[str] = None, @@ -88,7 +87,6 @@ class ApplicationService: ) # url must not end with a slash self.hs_token = hs_token self.sender = sender - self.server_name = hostname self.namespaces = self._check_namespaces(namespaces) self.id = id self.ip_range_whitelist = ip_range_whitelist diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 24498e7944..16f93273b3 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -179,7 +179,6 @@ def _load_appservice( return ApplicationService( token=as_info["as_token"], - hostname=hostname, url=as_info["url"], namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5c91d33f58..e1341dd9bb 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -468,7 +468,6 @@ class RoomCreationHandler: (EventTypes.RoomAvatar, ""), (EventTypes.RoomEncryption, ""), (EventTypes.ServerACL, ""), - (EventTypes.RelatedGroups, ""), (EventTypes.PowerLevels, ""), ] diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 40ac377ca9..deffdc19ce 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -276,10 +276,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): (SELECT 1 FROM profiles WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id) - AND NOT EXISTS - (SELECT 1 - FROM groups - WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id) AND NOT EXISTS (SELECT 1 FROM room_memberships diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index d547df8a64..bc75ddd3e9 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -404,7 +404,6 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -433,7 +432,6 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483d5463ad..f661a9ff8e 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -31,7 +31,6 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( None, - "example.com", id="foo", rate_limited=True, sender="@as:example.com", @@ -62,7 +61,6 @@ class 
TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( None, - "example.com", id="foo", rate_limited=False, sender="@as:example.com", diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 3e0db4dd98..532b676365 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -37,7 +37,6 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): url=URL, token="unused", hs_token=TOKEN, - hostname="myserver", ) def test_query_3pe_authenticates_token(self): diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 7135362f76..3018d3fc6f 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -33,7 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase): sender="@as:test", url="some_url", token="some_token", - hostname="matrix.org", # only used by get_groups_for_user ) self.event = Mock( event_id="$abc:xyz", diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 0e100c404d..d96d5aa138 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -697,7 +697,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Create an application service appservice = ApplicationService( token=random_string(10), - hostname="example.com", id=random_string(10), sender="@as:example.com", rate_limited=False, @@ -776,7 +775,6 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Create an appservice that is interested in "local_user" appservice = ApplicationService( token=random_string(10), - hostname="example.com", id=random_string(10), sender="@as:example.com", rate_limited=False, @@ -843,7 +841,6 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): self._service_token = "VERYSECRET" self._service = ApplicationService( self._service_token, - "as1.invalid", "as1", "@as.sender:test", namespaces={ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index a68c2ffd45..9e39cd97e5 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -60,7 +60,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, # Note: this user does not match the regex above, so that tests diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e0a11da97b..a43a137273 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -548,7 +548,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 4920468f7a..f4ea1209d9 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1112,7 +1112,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.service = ApplicationService( id="unique_identifier", token="some_token", - hostname="example.com", sender="@asbot:example.com", namespaces={ ApplicationService.NS_USERS: [ @@ -1125,7 +1124,6 @@ class 
AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.another_service = ApplicationService( id="another__identifier", token="another_token", - hostname="example.com", sender="@as2bot:example.com", namespaces={ ApplicationService.NS_USERS: [ diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 9aebf1735a..afb08b2736 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -56,7 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -80,7 +79,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 1b7ee08ab2..9d5cb60d16 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -71,7 +71,6 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, # Note: this user does not have to match the regex above diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f1964eb6a..5b60cf5285 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -134,7 +134,6 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/test_mau.py b/tests/test_mau.py index 5bbc361aa2..f14fcb7db9 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -105,7 +105,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token, - hostname=self.hs.hostname, id="SomeASID", sender="@as_sender:test", namespaces={"users": [{"regex": "@as_*", "exclusive": True}]}, @@ -251,7 +250,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token_1, - hostname=self.hs.hostname, id="SomeASID", sender="@as_sender_1:test", namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]}, @@ -262,7 +260,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token_2, - hostname=self.hs.hostname, id="AnotherASID", sender="@as_sender_2:test", namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]}, From 782cb7420a88fe29241dcecdfee91e25940b2ac7 Mon Sep 17 00:00:00 2001 From: Michael Telatynski <7t3chguy@gmail.com> Date: Wed, 1 Jun 2022 15:57:09 +0100 Subject: [PATCH 67/74] Fix complement tests using the wrong path (#12933) --- .github/workflows/tests.yml | 2 +- changelog.d/12933.misc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12933.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3693cf06c3..83ab727378 100644 --- a/.github/workflows/tests.yml +++ 
b/.github/workflows/tests.yml @@ -372,7 +372,7 @@ jobs: # Attempt to check out the same branch of Complement as the PR. If it # doesn't exist, fallback to HEAD. - name: Checkout complement - run: .ci/scripts/checkout_complement.sh + run: synapse/.ci/scripts/checkout_complement.sh - run: | set -o pipefail diff --git a/changelog.d/12933.misc b/changelog.d/12933.misc new file mode 100644 index 0000000000..e29bf02407 --- /dev/null +++ b/changelog.d/12933.misc @@ -0,0 +1 @@ +Test Synapse against Complement with workers. From 888a29f4127723a8d048ce47cff37ee8a7a6f1b9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Jun 2022 16:02:53 +0100 Subject: [PATCH 68/74] Wait for lazy join to complete when getting current state (#12872) --- changelog.d/12872.misc | 1 + synapse/events/third_party_rules.py | 3 +- synapse/federation/federation_server.py | 4 +- synapse/handlers/device.py | 2 +- synapse/handlers/directory.py | 7 +- synapse/handlers/federation.py | 7 +- synapse/handlers/message.py | 2 +- synapse/handlers/presence.py | 6 +- synapse/handlers/register.py | 3 +- synapse/handlers/room.py | 13 +- synapse/handlers/room_list.py | 3 +- synapse/handlers/room_member.py | 5 +- synapse/handlers/room_summary.py | 11 +- synapse/handlers/stats.py | 6 +- synapse/handlers/sync.py | 13 +- synapse/handlers/user_directory.py | 6 +- synapse/module_api/__init__.py | 19 ++- synapse/push/mailer.py | 4 +- synapse/rest/admin/rooms.py | 3 +- synapse/storage/_base.py | 2 +- synapse/storage/controllers/__init__.py | 4 +- synapse/storage/controllers/persist_events.py | 4 +- synapse/storage/controllers/state.py | 112 +++++++++++++++++- synapse/storage/databases/main/room.py | 18 +++ synapse/storage/databases/main/state.py | 38 ++---- .../storage/databases/main/state_deltas.py | 4 +- .../storage/databases/main/user_directory.py | 4 +- .../util/partial_state_events_tracker.py | 60 ++++++++++ tests/handlers/test_federation.py | 6 +- tests/handlers/test_federation_event.py | 4 +- tests/handlers/test_typing.py | 2 +- tests/rest/client/test_upgrade_room.py | 8 +- .../util/test_partial_state_events_tracker.py | 59 ++++++++- 33 files changed, 361 insertions(+), 82 deletions(-) create mode 100644 changelog.d/12872.misc diff --git a/changelog.d/12872.misc b/changelog.d/12872.misc new file mode 100644 index 0000000000..f60a756f21 --- /dev/null +++ b/changelog.d/12872.misc @@ -0,0 +1 @@ +Faster room joins: when querying the current state of the room, wait for state to be populated. diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 9f4ff9799c..35f3f3690f 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -152,6 +152,7 @@ class ThirdPartyEventRules: self.third_party_rules = None self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] @@ -463,7 +464,7 @@ class ThirdPartyEventRules: Returns: A dict mapping (event type, state key) to state event. 
""" - state_ids = await self.store.get_filtered_current_state_ids(room_id) + state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) room_state_events = await self.store.get_events(state_ids.values()) state_events = {} diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 12591dc8db..f4af121c4d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -118,6 +118,8 @@ class FederationServer(FederationBase): self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() + self._state_storage_controller = hs.get_storage_controllers().state + self.device_handler = hs.get_device_handler() # Ensure the following handlers are loaded since they register callbacks @@ -1221,7 +1223,7 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = await self.store.get_current_state_ids(room_id) + state_ids = await self._state_storage_controller.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 72faf2ee38..a0cbeedc30 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -166,7 +166,7 @@ class DeviceWorkerHandler: possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = await self.store.get_current_state_ids(room_id) + current_state_ids = await self._state_storage.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4aa33df884..44e84698c4 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -45,6 +45,7 @@ class DirectoryHandler: self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases @@ -463,7 +464,11 @@ class DirectoryHandler: making_public = visibility == "public" if making_public: room_aliases = await self.store.get_aliases_for_room(room_id) - canonical_alias = await self.store.get_canonical_alias_for_room(room_id) + canonical_alias = ( + await self._storage_controllers.state.get_canonical_alias_for_room( + room_id + ) + ) if canonical_alias: room_aliases.append(canonical_alias) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 659f279441..b212ee2172 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -750,7 +750,9 @@ class FederationHandler: # Note that this requires the /send_join request to come back to the # same server. 
if room_version.msc3083_join_rules: - state_ids = await self.store.get_current_state_ids(room_id) + state_ids = await self._state_storage_controller.get_current_state_ids( + room_id + ) if await self._event_auth_handler.has_restricted_join_rules( state_ids, room_version ): @@ -1552,6 +1554,9 @@ class FederationHandler: success = await self.store.clear_partial_state_room(room_id) if success: logger.info("State resync complete for %s", room_id) + self._storage_controllers.state.notify_room_un_partial_stated( + room_id + ) # TODO(faster_joins) update room stats and user directory? return diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ac911a2ddc..081625f0bd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -217,7 +217,7 @@ class MessageHandler: ) if membership == Membership.JOIN: - state_ids = await self.store.get_filtered_current_state_ids( + state_ids = await self._state_storage_controller.get_current_state_ids( room_id, state_filter=state_filter ) room_state = await self.store.get_events(state_ids.values()) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index bf112b9e1e..895ea63ed3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -1348,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 05bb1e0225..338204287f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -87,6 +87,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() @@ -528,7 +529,7 @@ class RegistrationHandler: if requires_invite: # If the server is in the room, check if the room is public. 
- state = await self.store.get_filtered_current_state_ids( + state = await self._storage_controllers.state.get_current_state_ids( room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e1341dd9bb..e2b0e519d4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -107,6 +107,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -480,8 +481,10 @@ class RoomCreationHandler: if room_type == RoomTypes.SPACE: types_to_copy.append((EventTypes.SpaceChild, None)) - old_room_state_ids = await self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types(types_to_copy) + old_room_state_ids = ( + await self._storage_controllers.state.get_current_state_ids( + old_room_id, StateFilter.from_types(types_to_copy) + ) ) # map from event_id to BaseEvent old_room_state_events = await self.store.get_events(old_room_state_ids.values()) @@ -558,8 +561,10 @@ class RoomCreationHandler: ) # Transfer membership events - old_room_member_state_ids = await self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) + old_room_member_state_ids = ( + await self._storage_controllers.state.get_current_state_ids( + old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) # map from event_id to BaseEvent diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index f3577b5d5a..183d4ae3c4 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ @@ -274,7 +275,7 @@ class RoomListHandler: if aliases: result["aliases"] = aliases - current_state_ids = await self.store.get_current_state_ids( + current_state_ids = await self._storage_controllers.state.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 00662dc961..70c674ff8e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -68,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config @@ -994,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. room_version = await self.store.get_room_version(room_id) - current_state_ids = await self.store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) # If restricted join rules are not being used, a local join can always # be used. 
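The hunks in this patch repeat one mechanical migration: every caller that reads a room's current state switches from the main datastore to the new state storage controller, which can first wait for a lazy ("partial state") join to complete. As a rough sketch of the resulting calling convention — `get_room_name` is a hypothetical helper, not part of this patch — a migrated call site looks like this:

    from typing import Optional

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    async def get_room_name(storage_controllers, room_id: str) -> Optional[str]:
        # The controller awaits full room state when the filter requires it
        # (see must_await_full_state, introduced below), then reads event IDs
        # from the current_state_events table.
        state_ids = await storage_controllers.state.get_current_state_ids(
            room_id, StateFilter.from_types([(EventTypes.Name, "")])
        )
        name_event_id = state_ids.get((EventTypes.Name, ""))
        if name_event_id is None:
            return None
        event = await storage_controllers.main.get_event(name_event_id, allow_none=True)
        return event.content.get("name") if event else None
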
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 75aee6a111..13098f56ed 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -90,6 +90,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() @@ -537,7 +538,7 @@ class RoomSummaryHandler: Returns: True if the room is accessible to the requesting user or server. """ - state_ids = await self._store.get_current_state_ids(room_id) + state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) # If there's no state for the room, it isn't known. if not state_ids: @@ -702,7 +703,9 @@ class RoomSummaryHandler: # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - current_state_ids = await self._store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) create_event = await self._store.get_event( current_state_ids[(EventTypes.Create, "")] ) @@ -760,7 +763,9 @@ class RoomSummaryHandler: """ # look for child rooms/spaces. - current_state_ids = await self._store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) events = await self._store.get_events_as_list( [ diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 436cd971ce..f45e06eb0e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -40,6 +40,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() @@ -105,7 +106,10 @@ class StatsHandler: logger.debug( "Processing room stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self.pos, room_max_stream_ordering ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a1d41358d9..b4ead79f97 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -506,8 +506,10 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): - current_state_ids_map = await self.store.get_current_state_ids( - room_id + current_state_ids_map = ( + await self._state_storage_controller.get_current_state_ids( + room_id + ) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -574,8 +576,11 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids_map = await self.store.get_current_state_ids( - room_id + # FIXME(faster_joins): We use the partial state here as + # we don't want to block `/sync` on finishing a lazy join. + # Is this the correct way of doing it? 
+ current_state_ids_map = ( + await self.store.get_partial_current_state_ids(room_id) ) current_state_ids = frozenset(current_state_ids_map.values()) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 74f7fdfe6c..8c3c52e1ca 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler): super().__init__(hs) self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug( "Processing user stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self.pos, room_max_stream_ordering ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b7451fc870..a8ad575fcd 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -194,6 +194,7 @@ class ModuleApi: self._store: Union[ DataStore, "GenericWorkerSlavedStore" ] = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -911,7 +912,7 @@ class ModuleApi: The filtered state events in the room. """ state_ids = yield defer.ensureDeferred( - self._store.get_filtered_current_state_ids( + self._storage_controllers.state.get_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) ) @@ -1289,20 +1290,16 @@ class ModuleApi: # regardless of their state key ] """ + state_filter = None if event_filter: # If a filter was provided, turn it into a StateFilter and retrieve a filtered # view of the state. state_filter = StateFilter.from_types(event_filter) - state_ids = await self._store.get_filtered_current_state_ids( - room_id, - state_filter, - ) - else: - # If no filter was provided, get the whole state. We could also reuse the call - # to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`, - # but get_filtered_current_state_ids isn't cached and `get_current_state_ids` - # is, so using the latter when we can is better for perf. 
- state_ids = await self._store.get_current_state_ids(room_id) + + state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id, + state_filter, + ) state_events = await self._store.get_events(state_ids.values()) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 63aefd07f5..015c19b2d9 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -255,7 +255,9 @@ class Mailer: user_display_name = user_id async def _fetch_room_state(room_id: str) -> None: - room_state = await self.store.get_current_state_ids(room_id) + room_state = await self._state_storage_controller.get_current_state_ids( + room_id + ) state_by_room[room_id] = room_state # Run at most 3 of these at once: sync does 10 at a time but email diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 356d6f74d7..1cacd1a4f0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -418,6 +418,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() @@ -430,7 +431,7 @@ class RoomStateRestServlet(RestServlet): if not ret: raise NotFoundError("Room not found") - event_ids = await self.store.get_current_state_ids(room_id) + event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() room_state = self._event_serializer.serialize_events(events.values(), now) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8df80664a2..57bd74700e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -77,7 +77,7 @@ class SQLBaseStore(metaclass=ABCMeta): # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) - self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) + self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py index 992261d07b..55649719f6 100644 --- a/synapse/storage/controllers/__init__.py +++ b/synapse/storage/controllers/__init__.py @@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import ( EventsPersistenceStorageController, ) from synapse.storage.controllers.purge_events import PurgeEventsStorageController -from synapse.storage.controllers.state import StateGroupStorageController +from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore @@ -39,7 +39,7 @@ class StorageControllers: self.main = stores.main self.purge_events = PurgeEventsStorageController(hs, stores) - self.state = StateGroupStorageController(hs, stores) + self.state = StateStorageController(hs, stores) self.persistence = None if stores.persist_events: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index ef8c135b12..4caaa81808 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -994,7 +994,7 @@ class EventsPersistenceStorageController: Assumes that we are only persisting events for one room at a time. 
""" - existing_state = await self.main_store.get_current_state_ids(room_id) + existing_state = await self.main_store.get_partial_current_state_ids(room_id) to_delete = [key for key in existing_state if key not in current_state] @@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController: # The server will leave the room, so we go and find out which remote # users will still be joined when we leave. if current_state is None: - current_state = await self.main_store.get_current_state_ids(room_id) + current_state = await self.main_store.get_partial_current_state_ids(room_id) current_state = dict(current_state) for key in delta.to_delete: current_state.pop(key, None) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0f09953086..9952b00493 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -14,7 +14,9 @@ import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, + Callable, Collection, Dict, Iterable, @@ -24,9 +26,13 @@ from typing import ( Tuple, ) +from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.storage.state import StateFilter -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: @@ -36,17 +42,27 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class StateGroupStorageController: - """High level interface to fetching state for event.""" +class StateStorageController: + """High level interface to fetching state for an event, or the current state + in a room. + """ def __init__(self, hs: "HomeServer", stores: "Databases"): self._is_mine_id = hs.is_mine_id self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main) def notify_event_un_partial_stated(self, event_id: str) -> None: self._partial_state_events_tracker.notify_un_partial_stated(event_id) + def notify_room_un_partial_stated(self, room_id: str) -> None: + """Notify that the room no longer has any partial state. + + Must be called after `DataStore.clear_partial_state_room` + """ + self._partial_state_room_tracker.notify_un_partial_stated(room_id) + async def get_state_group_delta( self, state_group: int ) -> Tuple[Optional[int], Optional[StateMap[str]]]: @@ -349,3 +365,93 @@ class StateGroupStorageController: return await self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) + + async def get_current_state_ids( + self, + room_id: str, + state_filter: Optional[StateFilter] = None, + on_invalidate: Optional[Callable[[], None]] = None, + ) -> StateMap[str]: + """Get the current state event ids for a room based on the + current_state_events table. + + If a state filter is given (that is not `StateFilter.all()`) the query + result is *not* cached. + + Args: + room_id: The room to get the state IDs of. state_filter: The state + filter used to fetch state from the + database. + on_invalidate: Callback for when the `get_current_state_ids` cache + for the room gets invalidated. + + Returns: + The current state of the room. 
+ """ + if not state_filter or state_filter.must_await_full_state(self._is_mine_id): + await self._partial_state_room_tracker.await_full_state(room_id) + + if state_filter and not state_filter.is_full(): + return await self.stores.main.get_partial_filtered_current_state_ids( + room_id, state_filter + ) + else: + return await self.stores.main.get_partial_current_state_ids( + room_id, on_invalidate=on_invalidate + ) + + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: + """Get canonical alias for room, if any + + Args: + room_id: The room ID + + Returns: + The canonical alias, if any + """ + + state = await self.get_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) + + event_id = state.get((EventTypes.CanonicalAlias, "")) + if not event_id: + return None + + event = await self.stores.main.get_event(event_id, allow_none=True) + if not event: + return None + + return event.content.get("canonical_alias") + + async def get_current_state_deltas( + self, prev_stream_id: int, max_stream_id: int + ) -> Tuple[int, List[Dict[str, Any]]]: + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id: point to get changes since (exclusive) + max_stream_id: the point that we know has been correctly persisted + - ie, an upper limit to return changes from. + + Returns: + A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. + """ + # FIXME(faster_joins): what do we do here? + + return await self.stores.main.get_partial_current_state_deltas( + prev_stream_id, max_stream_id + ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index cfd8ce1624..68d4fc2e64 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): keyvalues={"room_id": room_id}, ) + async def is_partial_state_room(self, room_id: str) -> bool: + """Checks if this room has partial state. + + Returns true if this is a "partial-state" room, which means that the state + at events in the room, and `current_state_events`, may not yet be + complete. 
+ """ + + entry = await self.db_pool.simple_select_one_onecol( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcol="room_id", + allow_none=True, + desc="is_partial_state_room", + ) + + return entry is not None + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 3f2be3854b..bdd00273cd 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -242,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Raises: NotFoundError if the room is unknown """ - state_ids = await self.get_current_state_ids(room_id) + state_ids = await self.get_partial_current_state_ids(room_id) if not state_ids: raise NotFoundError(f"Current state for room {room_id} is empty") @@ -258,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=100000, iterable=True) - async def get_current_state_ids(self, room_id: str) -> StateMap[str]: + async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. + This may be the partial state if we're lazy joining the room. + Args: room_id: The room to get the state IDs of. @@ -280,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return await self.db_pool.runInteraction( - "get_current_state_ids", _get_current_state_ids_txn + "get_partial_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - async def get_filtered_current_state_ids( + async def get_partial_filtered_current_state_ids( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + This may be the partial state if we're lazy joining the room. 
+ Args: room_id state_filter: The state filter used to fetch state @@ -306,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not where_clause: # We delegate to the cached version - return await self.get_current_state_ids(room_id) + return await self.get_partial_current_state_ids(room_id) def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, @@ -334,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: - """Get canonical alias for room, if any - - Args: - room_id: The room ID - - Returns: - The canonical alias, if any - """ - - state = await self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return None - - event = await self.get_event(event_id, allow_none=True) - if not event: - return None - - return event.content.get("canonical_alias") - @cached(max_entries=50000) async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 188afec332..445213e12a 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore): # attribute. TODO: can we get static analysis to enforce this? _curr_state_delta_stream_cache: StreamChangeCache - async def get_current_state_deltas( + async def get_partial_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id @@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore): - prev_event_id (str|None): previous event_id for this state key. None if it's new state. + This may be the partial state if we're lazy joining the room. + Args: prev_stream_id: point to get changes since (exclusive) max_stream_id: the point that we know has been correctly persisted diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 2282242e9d..ddb25b5cea 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] + # Getting the partial state is fine, as we're not looking at membership + # events. 
+ current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py index a61a951ef0..211437cfaa 100644 --- a/synapse/storage/util/partial_state_events_tracker.py +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util import unwrapFirstError logger = logging.getLogger(__name__) @@ -118,3 +119,62 @@ class PartialStateEventsTracker: observer_set.discard(observer) if not observer_set: del self._observers[event_id] + + +class PartialCurrentStateTracker: + """Keeps track of which rooms have partial state, after partial-state joins""" + + def __init__(self, store: RoomWorkerStore): + self._store = store + + # a map from room id to a set of Deferreds which are waiting for that room to be + # un-partial-stated. + self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set) + + def notify_un_partial_stated(self, room_id: str) -> None: + """Notify that we now have full current state for a given room + + Unblocks any callers to await_full_state() for that room. + + Args: + room_id: the room that now has full current state. + """ + observers = self._observers.pop(room_id, None) + if not observers: + return + logger.info( + "Notifying %i things waiting for un-partial-stating of room %s", + len(observers), + room_id, + ) + with PreserveLoggingContext(): + for o in observers: + o.callback(None) + + async def await_full_state(self, room_id: str) -> None: + # We add the deferred immediately so that the DB call to check for + # partial state doesn't race when we unpartial the room. + d: Deferred[None] = Deferred() + self._observers.setdefault(room_id, set()).add(d) + + try: + # Check if the room has partial current state or not. + has_partial_state = await self._store.is_partial_state_room(room_id) + if not has_partial_state: + return + + logger.info( + "Awaiting un-partial-stating of room %s", + room_id, + ) + + await make_deferred_yieldable(d) + + logger.info("Room has un-partial-stated") + finally: + # Remove the added observer, and remove the room entry if its empty. 
+ ds = self._observers.get(room_id) + if ds is not None: + ds.discard(d) + if not ds: + self._observers.pop(room_id, None) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 500c9ccfbc..e0eda545b9 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) current_state = self.get_success( self.store.get_events_as_list( - (self.get_success(self.store.get_current_state_ids(room_id))).values() + ( + self.get_success(self.store.get_partial_current_state_ids(room_id)) + ).values() ) ) @@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): self.get_success(d) # sanity-check: the room should show that the new user is a member - r = self.get_success(self.store.get_current_state_ids(room_id)) + r = self.get_success(self.store.get_partial_current_state_ids(room_id)) self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) return join_event diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 1d5b2492c0..1a36c25c41 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") ) - initial_state_map = self.get_success(main_store.get_current_state_ids(room_id)) + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) auth_event_ids = [ initial_state_map[("m.room.create", "")], diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 057256cecd..14a0ee4922 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -146,7 +146,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.datastore.get_current_state_deltas = Mock(return_value=(0, None)) + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = ( diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a21cbe9fa8..98c1039d33 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_space_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_space_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_space_id) + ) # Ensure the new room is still a space. create_event = self.get_success( @@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_room_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_room_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_room_id) + ) # Ensure the new room is the same type as the old room. 
create_event = self.get_success( diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 303e190b6c..cae14151c0 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -17,8 +17,12 @@ from unittest import mock from twisted.internet.defer import CancelledError, ensureDeferred -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) +from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("event1") self.successResultOf(d2) + + +class PartialCurrentStateTrackerTestCase(TestCase): + def setUp(self) -> None: + self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) + + self.tracker = PartialCurrentStateTracker(self.mock_store) + + def test_does_not_block_for_full_state_rooms(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(False) + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_blocks_for_partial_room_state(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d = ensureDeferred(self.tracker.await_full_state("room_id")) + + # there should be no result yet + self.assertNoResult(d) + + # notifying that the room has been de-partial-stated should unblock + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d) + + def test_un_partial_state_race(self): + # We should correctly handle race between awaiting the state and us + # un-partialling the state + async def is_partial_state_room(events): + self.tracker.notify_un_partial_stated("room_id") + return True + + self.mock_store.is_partial_state_room.side_effect = is_partial_state_room + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_cancellation(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d1 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d1) + + d2 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d2) + + d1.cancel() + self.assertFailure(d1, CancelledError) + + # d2 should still be waiting! + self.assertNoResult(d2) + + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d2) From 01df5bacac3aa0e8356fed889ea0b69c4c044535 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Jun 2022 12:09:12 -0400 Subject: [PATCH 69/74] Improve URL previews for some pages (#12951) * Skip `og` and `meta` tags where the value is empty. * Fallback to the favicon if there are no other images. * Ignore tags meant for navigation. --- changelog.d/12951.feature | 1 + synapse/rest/media/v1/preview_html.py | 52 ++++++++++++++++-------- tests/rest/media/v1/test_html_preview.py | 37 ++++++++++++++++- 3 files changed, 72 insertions(+), 18 deletions(-) create mode 100644 changelog.d/12951.feature diff --git a/changelog.d/12951.feature b/changelog.d/12951.feature new file mode 100644 index 0000000000..f885be9fe4 --- /dev/null +++ b/changelog.d/12951.feature @@ -0,0 +1 @@ +Improve URL previews for pages with empty elements. 
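The parser changes below lean on XPath predicates to do the filtering in a single query: `[@content]` requires the attribute to be present, and `[not(@content='')]` rejects empty values. A small self-contained illustration of the idiom (plain lxml, not Synapse code):

    from lxml import etree

    html = (
        '<html><head>'
        '<meta property="og:title" content=""/>'
        '<meta property="og:title" content="Real title"/>'
        '</head></html>'
    )
    tree = etree.fromstring(html, etree.HTMLParser())

    # Keep only og: meta tags whose content attribute exists and is non-empty.
    tags = tree.xpath(
        "//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]"
    )
    print([tag.attrib["content"] for tag in tags])  # prints: ['Real title']
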
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 13ec7ab533..ed8f21a483 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -30,6 +30,9 @@ _xml_encoding_match = re.compile( ) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) +# Certain elements aren't meant for display. +ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} + def _normalise_encoding(encoding: str) -> Optional[str]: """Use the Python codec's name as the normalised entry.""" @@ -174,13 +177,15 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", og: Dict[str, Optional[str]] = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if "content" in tag.attrib: - # if we've got more than 50 tags, someone is taking the piss - if len(og) >= 50: - logger.warning("Skipping OG for page with too many 'og:' tags") - return {} - og[tag.attrib["property"]] = tag.attrib["content"] + for tag in tree.xpath( + "//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]" + ): + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("Skipping OG for page with too many 'og:' tags") + return {} + + og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: @@ -192,21 +197,23 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> if "og:title" not in og: - # do some basic spidering of the HTML - title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - if title and title[0].text is not None: - og["og:title"] = title[0].text.strip() + # Attempt to find a title from the title tag, or the biggest header on the page. + title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") + if title: + og["og:title"] = title[0].strip() else: og["og:title"] = None if "og:image" not in og: - # TODO: extract a favicon failing all else meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" ) + # If a meta image is found, use it. if meta_image: og["og:image"] = meta_image[0] else: + # Try to find images which are larger than 10px by 10px. + # # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") images = sorted( @@ -215,17 +222,24 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: -1 * float(i.attrib["width"]) * float(i.attrib["height"]) ), ) + # If no images were found, try to find *any* images. if not images: - images = tree.xpath("//img[@src]") + images = tree.xpath("//img[@src][1]") if images: og["og:image"] = images[0].attrib["src"] + # Finally, fallback to the favicon if nothing else. + else: + favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") + if favicons: + og["og:image"] = favicons[0] + if "og:description" not in og: + # Check the first meta description tag for content. 
meta_description = tree.xpath( - "//*/meta" - "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content" + "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" ) + # If a meta description is found with content, use it. if meta_description: og["og:description"] = meta_description[0] else: @@ -306,6 +320,10 @@ def _iterate_over_text( if isinstance(el, str): yield el elif el.tag not in tags_to_ignore: + # If the element isn't meant for display, ignore it. + if el.get("role") in ARIA_ROLES_TO_IGNORE: + continue + # el.text is the text before the first child, so we can immediately # return it if the text exists. if el.text: diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index 62e308814d..ea9e5889bf 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -145,7 +145,7 @@ class SummarizeTestCase(unittest.TestCase): ) -class CalcOgTestCase(unittest.TestCase): +class OpenGraphFromHtmlTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" @@ -235,6 +235,21 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + # Another variant is a title with no content. + html = b""" + + + +
+        <html>
+        <head><title></title></head>
+        <body>
+        <h1>Title</h1>
+        </body>
+        </html>
+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + + self.assertEqual(og, {"og:title": "Title", "og:description": "Title"}) + def test_h1_as_title(self) -> None: html = b""" @@ -250,6 +265,26 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) + def test_empty_description(self) -> None: + """Description tags with empty content should be ignored.""" + html = b""" + + + + + + + +
+        <html>
+        <head>
+        <meta name="description" content=""/>
+        <meta name="description" content="Finally!"/>
+        </head>
+        <body>
+        <h1>Title</h1>
+        </body>
+        </html>
+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + + self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"}) + def test_missing_title_and_broken_h1(self) -> None: html = b""" From 6b46c3eb3d526d903e1e4833b2e8ae9b73de8502 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Jun 2022 12:13:35 -0400 Subject: [PATCH 70/74] Remove groups code from synapse_port_db. (#12899) --- changelog.d/12899.removal | 1 + synapse/_scripts/synapse_port_db.py | 23 ++++++++++++------- .../storage/databases/main/group_server.py | 9 ++------ 3 files changed, 18 insertions(+), 15 deletions(-) create mode 100644 changelog.d/12899.removal diff --git a/changelog.d/12899.removal b/changelog.d/12899.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12899.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index d7dfa92bd1..4939573f30 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -102,14 +102,6 @@ BOOLEAN_COLUMNS = { "devices": ["hidden"], "device_lists_outbound_pokes": ["sent"], "users_who_share_rooms": ["share_private"], - "groups": ["is_public"], - "group_rooms": ["is_public"], - "group_users": ["is_public", "is_admin"], - "group_summary_rooms": ["is_public"], - "group_room_categories": ["is_public"], - "group_summary_users": ["is_public"], - "group_roles": ["is_public"], - "local_group_membership": ["is_publicised", "is_admin"], "e2e_room_keys": ["is_verified"], "account_validity": ["email_sent"], "redactions": ["have_censored"], @@ -175,6 +167,21 @@ IGNORED_TABLES = { "ui_auth_sessions", "ui_auth_sessions_credentials", "ui_auth_sessions_ips", + # Groups/communities is no longer supported. + "group_attestations_remote", + "group_attestations_renewals", + "group_invites", + "group_roles", + "group_room_categories", + "group_rooms", + "group_summary_roles", + "group_summary_room_categories", + "group_summary_rooms", + "group_summary_users", + "group_users", + "groups", + "local_group_membership", + "local_group_updates", } diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index da21a50144..c15a7136b6 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -29,11 +29,6 @@ class GroupServerStore(SQLBaseStore): db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): - database.updates.register_background_index_update( - update_name="local_group_updates_index", - index_name="local_group_updates_stream_id_index", - table="local_group_updates", - columns=("stream_id",), - unique=True, - ) + # Register a legacy groups background update as a no-op. 
+ database.updates.register_noop_background_update("local_group_updates_index") super().__init__(database, db_conn, hs) From e3163e2e11cf8bffa4cb3e58ac0b86a83eca314c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jun 2022 11:24:12 +0300 Subject: [PATCH 71/74] Reduce the amount of state we pull from the DB (#12811) --- changelog.d/12811.misc | 1 + synapse/api/auth.py | 45 ++++++----- synapse/federation/federation_base.py | 1 + synapse/federation/federation_server.py | 12 +-- synapse/handlers/directory.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 18 +++-- synapse/handlers/initial_sync.py | 6 +- synapse/handlers/message.py | 4 +- synapse/handlers/room.py | 5 +- synapse/handlers/room_member.py | 16 +++- synapse/handlers/search.py | 2 +- synapse/notifier.py | 2 +- synapse/rest/admin/rooms.py | 34 ++++++--- synapse/rest/client/room.py | 7 +- .../resource_limits_server_notices.py | 7 +- synapse/state/__init__.py | 75 +------------------ synapse/storage/controllers/state.py | 27 +++++++ tests/federation/test_federation_server.py | 6 +- tests/handlers/test_directory.py | 3 +- tests/storage/test_events.py | 17 +++-- tests/storage/test_purge.py | 5 +- tests/storage/test_room.py | 12 ++- 23 files changed, 162 insertions(+), 147 deletions(-) create mode 100644 changelog.d/12811.misc diff --git a/changelog.d/12811.misc b/changelog.d/12811.misc new file mode 100644 index 0000000000..d57e1aca6b --- /dev/null +++ b/changelog.d/12811.misc @@ -0,0 +1 @@ +Reduce the amount of state we pull from the DB. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 931750668e..5a410f805a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -29,12 +29,11 @@ from synapse.api.errors import ( MissingClientTokenError, ) from synapse.appservice import ApplicationService -from synapse.events import EventBase from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging.opentracing import active_span, force_tracing, start_active_span from synapse.storage.databases.main.registration import TokenLookupResult -from synapse.types import Requester, StateMap, UserID, create_requester +from synapse.types import Requester, UserID, create_requester from synapse.util.caches.lrucache import LruCache from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry @@ -61,8 +60,8 @@ class Auth: self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state = hs.get_state_handler() self._account_validity_handler = hs.get_account_validity_handler() + self._storage_controllers = hs.get_storage_controllers() self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( 10000, "token_cache" @@ -79,9 +78,8 @@ class Auth: self, room_id: str, user_id: str, - current_state: Optional[StateMap[EventBase]] = None, allow_departed_users: bool = False, - ) -> EventBase: + ) -> Tuple[str, Optional[str]]: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -99,29 +97,28 @@ class Auth: Raises: AuthError if the user is/was not in the room. Returns: - Membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. + The current membership of the user in the room and the + membership event ID of the user. 
""" - if current_state: - member = current_state.get((EventTypes.Member, user_id), None) - else: - member = await self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) - if member: - membership = member.membership + ( + membership, + member_event_id, + ) = await self.store.get_local_current_membership_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + if membership: if membership == Membership.JOIN: - return member + return membership, member_event_id # XXX this looks totally bogus. Why do we not allow users who have been banned, # or those who were members previously and have been re-invited? if allow_departed_users and membership == Membership.LEAVE: forgot = await self.store.did_forget(user_id, room_id) if not forgot: - return member + return membership, member_event_id raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) @@ -602,8 +599,11 @@ class Auth: # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the # m.room.canonical_alias events - power_level_event = await self.state.get_current_state( - room_id, EventTypes.PowerLevels, "" + + power_level_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.PowerLevels, "" + ) ) auth_events = {} @@ -693,12 +693,11 @@ class Auth: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = await self.check_user_in_room( + return await self.check_user_in_room( room_id, user_id, allow_departed_users=allow_departed_users ) - return member_event.membership, member_event.event_id except AuthError: - visibility = await self.state.get_current_state( + visibility = await self._storage_controllers.state.get_current_state_event( room_id, EventTypes.RoomHistoryVisibility, "" ) if ( diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a6232e048b..2522bf78fc 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -53,6 +53,7 @@ class FederationBase: self.spam_checker = hs.get_spam_checker() self.store = hs.get_datastores().main self._clock = hs.get_clock() + self._storage_controllers = hs.get_storage_controllers() async def _check_sigs_and_hash( self, room_version: RoomVersion, pdu: EventBase diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index f4af121c4d..3e1518f1f6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1223,14 +1223,10 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = await self._state_storage_controller.get_current_state_ids(room_id) - acl_event_id = state_ids.get((EventTypes.ServerACL, "")) - - if not acl_event_id: - return - - acl_event = await self.store.get_event(acl_event_id) - if server_matches_acl_event(server_name, acl_event): + acl_event = await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.ServerACL, "" + ) + if not acl_event or server_matches_acl_event(server_name, acl_event): return raise AuthError(code=403, msg="Server is banned from room") diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 44e84698c4..1459a046de 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -320,7 +320,7 @@ class 
DirectoryHandler: Raises: ShadowBanError if the requester has been shadow-banned. """ - alias_event = await self.state.get_current_state( + alias_event = await self._storage_controllers.state.get_current_state_event( room_id, EventTypes.CanonicalAlias, "" ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b212ee2172..6a143440d3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -371,7 +371,7 @@ class FederationHandler: # First we try hosts that are already in the room # TODO: HEURISTIC ALERT. - curr_state = await self.state_handler.get_current_state(room_id) + curr_state = await self._storage_controllers.state.get_current_state(room_id) curr_domains = get_domains_from_state(curr_state) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 549b066dd9..87a0608359 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1584,9 +1584,11 @@ class FederationEventHandler: if guest_access == GuestAccess.CAN_JOIN: return - current_state_map = await self._state_handler.get_current_state(event.room_id) - current_state = list(current_state_map.values()) - await self._get_room_member_handler().kick_guest_users(current_state) + current_state = await self._storage_controllers.state.get_current_state( + event.room_id + ) + current_state_list = list(current_state.values()) + await self._get_room_member_handler().kick_guest_users(current_state_list) async def _check_for_soft_fail( self, @@ -1614,6 +1616,9 @@ class FederationEventHandler: room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + # The event types we want to pull from the "current" state. + auth_types = auth_types_for_event(room_version_obj, event) + # Calculate the "current state". 
if state_ids is not None: # If we're explicitly given the state then we won't have all the @@ -1643,8 +1648,10 @@ class FederationEventHandler: ) ) else: - current_state_ids = await self._state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids + current_state_ids = ( + await self._state_storage_controller.get_current_state_ids( + event.room_id, StateFilter.from_types(auth_types) + ) ) logger.debug( @@ -1654,7 +1661,6 @@ class FederationEventHandler: ) # Now check if event pass auth against said current state - auth_types = auth_types_for_event(room_version_obj, event) current_state_ids_list = [ e for k, e in current_state_ids.items() if k in auth_types ] diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d2b489e816..85b472f250 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -190,7 +190,7 @@ class InitialSyncHandler: if event.membership == Membership.JOIN: room_end_token = now_token.room_key deferred_room_state = run_in_background( - self.state_handler.get_current_state, event.room_id + self._state_storage_controller.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: room_end_token = RoomStreamToken( @@ -407,7 +407,9 @@ class InitialSyncHandler: membership: str, is_peeking: bool, ) -> JsonDict: - current_state = await self.state.get_current_state(room_id=room_id) + current_state = await self._storage_controllers.state.get_current_state( + room_id=room_id + ) # TODO: These concurrently time_now = self.clock.time_msec() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 081625f0bd..f455158a2c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -125,7 +125,9 @@ class MessageHandler: ) if membership == Membership.JOIN: - data = await self.state.get_current_state(room_id, event_type, state_key) + data = await self._storage_controllers.state.get_current_state_event( + room_id, event_type, state_key + ) elif membership == Membership.LEAVE: key = (event_type, state_key) # If the membership is not JOIN, then the event ID should exist. 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e2b0e519d4..520663f172 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1333,6 +1333,7 @@ class TimestampLookupHandler:
         self.store = hs.get_datastores().main
         self.state_handler = hs.get_state_handler()
         self.federation_client = hs.get_federation_client()
+        self._storage_controllers = hs.get_storage_controllers()
 
     async def get_event_for_timestamp(
         self,
@@ -1406,7 +1407,9 @@
             )
 
         # Find other homeservers from the given state in the room
-        curr_state = await self.state_handler.get_current_state(room_id)
+        curr_state = await self._storage_controllers.state.get_current_state(
+            room_id
+        )
         curr_domains = get_domains_from_state(curr_state)
         likely_domains = [
             domain for domain, depth in curr_domains if domain != self.server_name
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 70c674ff8e..d1199a0644 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1401,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         txn_id: Optional[str],
         id_access_token: Optional[str] = None,
     ) -> int:
-        room_state = await self.state_handler.get_current_state(room_id)
+        room_state = await self._storage_controllers.state.get_current_state(
+            room_id,
+            StateFilter.from_types(
+                [
+                    (EventTypes.Member, user.to_string()),
+                    (EventTypes.CanonicalAlias, ""),
+                    (EventTypes.Name, ""),
+                    (EventTypes.Create, ""),
+                    (EventTypes.JoinRules, ""),
+                    (EventTypes.RoomAvatar, ""),
+                ]
+            ),
+        )
 
         inviter_display_name = ""
         inviter_avatar_url = ""
@@ -1797,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
     async def forget(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
 
-        member = await self.state_handler.get_current_state(
+        member = await self._storage_controllers.state.get_current_state_event(
             room_id=room_id, event_type=EventTypes.Member, state_key=user_id
         )
         membership = member.membership if member else None
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 659f99f7e2..bcab98c6d5 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -348,7 +348,7 @@ class SearchHandler:
         state_results = {}
         if include_state:
             for room_id in {e.room_id for e in search_result.allowed_events}:
-                state = await self.state_handler.get_current_state(room_id)
+                state = await self._storage_controllers.state.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
         aggregations = await self._relations_handler.get_bundled_aggregations(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1100434b3f..54b0ec4b97 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -681,7 +681,7 @@ class Notifier:
         return joined_room_ids, True
 
     async def _is_world_readable(self, room_id: str) -> bool:
-        state = await self.state_handler.get_current_state(
+        state = await self._storage_controllers.state.get_current_state_event(
             room_id, EventTypes.RoomHistoryVisibility, ""
         )
         if state and "history_visibility" in state.content:
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 1cacd1a4f0..9d953d58de 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
     assert_user_is_admin,
 )
 from synapse.storage.databases.main.room import RoomSortOrder
+from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, RoomID, UserID, create_requester
 from synapse.util import json_decoder
@@ -448,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         super().__init__(hs)
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
-        self.state_handler = hs.get_state_handler()
+        self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.is_mine = hs.is_mine
 
     async def on_POST(
@@ -490,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
             )
 
         # send invite if room has "JoinRules.INVITE"
-        room_state = await self.state_handler.get_current_state(room_id)
-        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
+        join_rules_event = (
+            await self._storage_controllers.state.get_current_state_event(
+                room_id, EventTypes.JoinRules, ""
+            )
+        )
         if join_rules_event:
             if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
                 # update_membership with an action of "invite" can raise a
@@ -536,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         super().__init__(hs)
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
+        self._state_storage_controller = hs.get_storage_controllers().state
        self.event_creation_handler = hs.get_event_creation_handler()
         self.state_handler = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
@@ -553,12 +559,22 @@
         user_to_add = content.get("user_id", requester.user.to_string())
 
         # Figure out which local users currently have power in the room, if any.
-        room_state = await self.state_handler.get_current_state(room_id)
-        if not room_state:
+        filtered_room_state = await self._state_storage_controller.get_current_state(
+            room_id,
+            StateFilter.from_types(
+                [
+                    (EventTypes.Create, ""),
+                    (EventTypes.PowerLevels, ""),
+                    (EventTypes.JoinRules, ""),
+                    (EventTypes.Member, user_to_add),
+                ]
+            ),
+        )
+        if not filtered_room_state:
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
 
-        create_event = room_state[(EventTypes.Create, "")]
-        power_levels = room_state.get((EventTypes.PowerLevels, ""))
+        create_event = filtered_room_state[(EventTypes.Create, "")]
+        power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
 
         if power_levels is not None:
             # We pick the local user with the highest power.
@@ -634,7 +650,7 @@
 
         # Now we check if the user we're granting admin rights to is already in
         # the room. If not and it's not a public room we invite them.
-        member_event = room_state.get((EventTypes.Member, user_to_add))
+        member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
         is_joined = False
         if member_event:
             is_joined = member_event.content["membership"] in (
@@ -645,7 +661,7 @@
         if is_joined:
             return HTTPStatus.OK, {}
 
-        join_rules = room_state.get((EventTypes.JoinRules, ""))
+        join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
         is_public = False
         if join_rules:
             is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 7a5ce8ad0e..a26e976492 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet):
         self.clock = hs.get_clock()
         self._store = hs.get_datastores().main
         self._state = hs.get_state_handler()
+        self._storage_controllers = hs.get_storage_controllers()
         self.event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
         self._relations_handler = hs.get_relations_handler()
@@ -673,8 +674,10 @@
         if include_unredacted_content and not await self.auth.is_server_admin(
             requester.user
         ):
-            power_level_event = await self._state.get_current_state(
-                room_id, EventTypes.PowerLevels, ""
+            power_level_event = (
+                await self._storage_controllers.state.get_current_state_event(
+                    room_id, EventTypes.PowerLevels, ""
+                )
             )
 
             auth_events = {}
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index b5f3a0c74e..6863020778 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -36,6 +36,7 @@ class ResourceLimitsServerNotices:
     def __init__(self, hs: "HomeServer"):
         self._server_notices_manager = hs.get_server_notices_manager()
         self._store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self._auth = hs.get_auth()
         self._config = hs.config
         self._resouce_limited = False
@@ -178,8 +179,10 @@
         currently_blocked = False
         pinned_state_event = None
         try:
-            pinned_state_event = await self._state.get_current_state(
-                room_id, event_type=EventTypes.Pinned
+            pinned_state_event = (
+                await self._storage_controllers.state.get_current_state_event(
+                    room_id, event_type=EventTypes.Pinned, state_key=""
+                )
             )
         except AuthError:
             # The user has yet to join the server notices room
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index bf09f5128a..ab68e2b6a4 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,13 +32,11 @@ from typing import (
     Set,
     Tuple,
     Union,
-    overload,
 )
 
 import attr
 from frozendict import frozendict
 from prometheus_client import Counter, Histogram
-from typing_extensions import Literal
 
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -132,85 +130,20 @@ class StateHandler:
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._storage_controllers = hs.get_storage_controllers()
 
-    @overload
-    async def get_current_state(
-        self,
-        room_id: str,
-        event_type: Literal[None] = None,
-        state_key: str = "",
-        latest_event_ids: Optional[List[str]] = None,
-    ) -> StateMap[EventBase]:
-        ...
-
-    @overload
-    async def get_current_state(
-        self,
-        room_id: str,
-        event_type: str,
-        state_key: str = "",
-        latest_event_ids: Optional[List[str]] = None,
-    ) -> Optional[EventBase]:
-        ...
-
-    async def get_current_state(
-        self,
-        room_id: str,
-        event_type: Optional[str] = None,
-        state_key: str = "",
-        latest_event_ids: Optional[List[str]] = None,
-    ) -> Union[Optional[EventBase], StateMap[EventBase]]:
-        """Retrieves the current state for the room. This is done by
-        calling `get_latest_events_in_room` to get the leading edges of the
-        event graph and then resolving any of the state conflicts.
-
-        This is equivalent to getting the state of an event that were to send
-        next before receiving any new events.
-
-        Returns:
-            If `event_type` is specified, then the method returns only the one
-            event (or None) with that `event_type` and `state_key`.
-
-            Otherwise, a map from (type, state_key) to event.
-        """
-        if not latest_event_ids:
-            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
-        assert latest_event_ids is not None
-
-        logger.debug("calling resolve_state_groups from get_current_state")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        state = ret.state
-
-        if event_type:
-            event_id = state.get((event_type, state_key))
-            event = None
-            if event_id:
-                event = await self.store.get_event(event_id, allow_none=True)
-            return event
-
-        state_map = await self.store.get_events(
-            list(state.values()), get_prev_content=False
-        )
-        return {
-            key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
-        }
-
     async def get_current_state_ids(
-        self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
+        self,
+        room_id: str,
+        latest_event_ids: Collection[str],
     ) -> StateMap[str]:
         """Get the current state, or the state at a set of events, for a room
 
         Args:
             room_id:
-            latest_event_ids: if given, the forward extremities to resolve. If
-                None, we look them up from the database (via a cache).
+            latest_event_ids: The forward extremities to resolve.
 
         Returns:
             the state dict, mapping from (event_type, state_key) -> event_id
         """
-        if not latest_event_ids:
-            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
-        assert latest_event_ids is not None
-
         logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
         return ret.state
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 9952b00493..63a78ebc87 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -455,3 +455,30 @@ class StateStorageController:
         return await self.stores.main.get_partial_current_state_deltas(
             prev_stream_id, max_stream_id
         )
+
+    async def get_current_state(
+        self, room_id: str, state_filter: Optional[StateFilter] = None
+    ) -> StateMap[EventBase]:
+        """Same as `get_current_state_ids` but also fetches the events"""
+        state_map_ids = await self.get_current_state_ids(room_id, state_filter)
+
+        event_map = await self.stores.main.get_events(list(state_map_ids.values()))
+
+        state_map = {}
+        for key, event_id in state_map_ids.items():
+            event = event_map.get(event_id)
+            if event:
+                state_map[key] = event
+
+        return state_map
+
+    async def get_current_state_event(
+        self, room_id: str, event_type: str, state_key: str
+    ) -> Optional[EventBase]:
+        """Get the current state event for the given type/state_key."""
+
+        key = (event_type, state_key)
+        state_map = await self.get_current_state(
+            room_id, StateFilter.from_types((key,))
+        )
+        return state_map.get(key)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b19365b81a..413b3c9426 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
         super().prepare(reactor, clock, hs)
 
+        self._storage_controllers = hs.get_storage_controllers()
+
         # create the room
         creator_user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
@@ -207,7 +209,7 @@
 
         # the room should show that the new user is a member
         r = self.get_success(
-            self.hs.get_state_handler().get_current_state(self._room_id)
+            self._storage_controllers.state.get_current_state(self._room_id)
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
@@ -258,7 +260,7 @@
 
         # the room should show that the new user is a member
         r = self.get_success(
-            self.hs.get_state_handler().get_current_state(self._room_id)
+            self._storage_controllers.state.get_current_state(self._room_id)
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 11ad44223d..53d49ca896 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
+        self._storage_controllers = hs.get_storage_controllers()
 
         # Create user
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -335,7 +336,7 @@
     def _get_canonical_alias(self):
         """Get the canonical alias state of the room."""
         return self.get_success(
-            self.state_handler.get_current_state(
+            self._storage_controllers.state.get_current_state_event(
                 self.room_id, EventTypes.CanonicalAlias, ""
             )
         )
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index a76718e8f9..2ff88e64a5 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -32,6 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
         self.state = self.hs.get_state_handler()
         self._persistence = self.hs.get_storage_controllers().persistence
+        self._state_storage_controller = self.hs.get_storage_controllers().state
         self.store = self.hs.get_datastores().main
 
         self.register_user("user", "pass")
@@ -104,7 +105,7 @@
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -137,7 +138,9 @@
         # setting. The state resolution across the old and new event will then
         # include it, and so the resolved state won't match the new state.
         state_before_gap = dict(
-            self.get_success(self.state.get_current_state_ids(self.room_id))
+            self.get_success(
+                self._state_storage_controller.get_current_state_ids(self.room_id)
+            )
         )
         state_before_gap.pop(("m.room.history_visibility", ""))
 
@@ -181,7 +184,7 @@
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -213,7 +216,7 @@
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -255,7 +258,7 @@
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -299,7 +302,7 @@
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
        )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -335,7 +338,7 @@
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 92cd0dfc05..8dfaa0559b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase):
         first = self.helper.send(self.room_id, body="test1")
 
         # Get the current room state.
-        state_handler = self.hs.get_state_handler()
         create_event = self.get_success(
-            state_handler.get_current_state(self.room_id, "m.room.create", "")
+            self._storage_controllers.state.get_current_state_event(
+                self.room_id, "m.room.create", ""
+            )
         )
         self.assertIsNotNone(create_event)
 
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d497a19f63..3c79dabc9f 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastores().main
-        self._storage = hs.get_storage_controllers()
+        self._storage_controllers = hs.get_storage_controllers()
         self.event_factory = hs.get_event_factory()
 
         self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@
 
     def inject_room_event(self, **kwargs):
         self.get_success(
-            self._storage.persistence.persist_event(
+            self._storage_controllers.persistence.persist_event(
                 self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
             )
         )
@@ -101,7 +101,9 @@
         )
 
         state = self.get_success(
-            self.store.get_current_state(room_id=self.room.to_string())
+            self._storage_controllers.state.get_current_state(
+                room_id=self.room.to_string()
+            )
         )
 
         self.assertEqual(1, len(state))
@@ -118,7 +120,9 @@
         )
 
         state = self.get_success(
-            self.store.get_current_state(room_id=self.room.to_string())
+            self._storage_controllers.state.get_current_state(
+                room_id=self.room.to_string()
+            )
         )
 
         self.assertEqual(1, len(state))
 
From fcd8703508ce5bfe481fc2f1510b05731477ce32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Christian=20Gr=C3=BCnhage?=
Date: Mon, 6 Jun 2022 13:10:13 +0200
Subject: [PATCH 72/74] Allow updating passwords using the admin api without
 logging out devices (#12952)

---
 changelog.d/12952.feature        | 1 +
 docs/admin_api/user_admin_api.md | 4 +++-
 synapse/rest/admin/users.py      | 8 +++++++-
 3 files changed, 11 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/12952.feature

diff --git a/changelog.d/12952.feature b/changelog.d/12952.feature
new file mode 100644
index 0000000000..7329bcc3d4
--- /dev/null
+++ b/changelog.d/12952.feature
@@ -0,0 +1 @@
+Allow updating a user's password using the admin API without logging out their devices. Contributed by @jcgruenhage.
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index c8794299e7..62f89e8cba 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -115,7 +115,9 @@ URL parameters:
 Body parameters:
 
 - `password` - string, optional. If provided, the user's password is updated and all
-  devices are logged out.
+  devices are logged out, unless `logout_devices` is set to `false`.
+- `logout_devices` - bool, optional, defaults to `true`. If set to false, devices aren't
+  logged out even when `password` is provided.
 - `displayname` - string, optional, defaults to the value of `user_id`.
 - `threepids` - array, optional, allows setting the third-party IDs (email, msisdn)
   - `medium` - string. Kind of third-party ID, either `email` or `msisdn`.
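As a usage sketch, the documented parameter can be exercised against the v2 user admin API (`PUT /_synapse/admin/v2/users/<user_id>`) like so; the homeserver URL, user ID, password, and token below are placeholder assumptions:

```python
import requests

# Placeholder values: substitute your homeserver, target user and admin token.
HOMESERVER = "http://localhost:8008"
USER_ID = "@alice:example.com"
ADMIN_TOKEN = "<admin access token>"

# Update the user's password while keeping their existing devices logged in.
resp = requests.put(
    f"{HOMESERVER}/_synapse/admin/v2/users/{USER_ID}",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={
        "password": "correct-horse-battery-staple",
        "logout_devices": False,
    },
)
resp.raise_for_status()
```

Omitting `logout_devices` (or setting it to `true`) preserves the old behaviour of invalidating every device, as the servlet change below shows.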
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 8e29ada8a0..f0614a2897 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -226,6 +226,13 @@ class UserRestServletV2(RestServlet):
             if not isinstance(password, str) or len(password) > 512:
                 raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
 
+            logout_devices = body.get("logout_devices", True)
+            if not isinstance(logout_devices, bool):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "'logout_devices' parameter is not of type boolean",
+                )
+
         deactivate = body.get("deactivated", False)
         if not isinstance(deactivate, bool):
             raise SynapseError(
@@ -305,7 +312,6 @@
                 await self.store.set_server_admin(target_user, set_admin_to)
 
             if password is not None:
-                logout_devices = True
                 new_password_hash = await self.auth_handler.hash(password)
                 await self.set_password_handler.set_password(
 
From 1acc897c317f2ed66c28a0cc27b6c584b8afdd6a Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 6 Jun 2022 07:18:04 -0400
Subject: [PATCH 73/74] Implement MSC3816, consider the root event for thread
 participation. (#12766)

As opposed to only considering a user to have "participated" if they
replied to the thread.
---
 changelog.d/12766.bugfix            |  1 +
 synapse/handlers/relations.py       | 58 +++++++++++++-------
 tests/rest/client/test_relations.py | 85 ++++++++++++++++++++---------
 3 files changed, 97 insertions(+), 47 deletions(-)
 create mode 100644 changelog.d/12766.bugfix

diff --git a/changelog.d/12766.bugfix b/changelog.d/12766.bugfix
new file mode 100644
index 0000000000..912c3deb70
--- /dev/null
+++ b/changelog.d/12766.bugfix
@@ -0,0 +1 @@
+Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it.
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 9a1cc11bb3..0b63cd2186 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,16 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Collection,
-    Dict,
-    FrozenSet,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
 
 import attr
 
@@ -256,13 +247,19 @@ class RelationsHandler:
 
         return filtered_results
 
-    async def get_threads_for_events(
-        self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
+    async def _get_threads_for_events(
+        self,
+        events_by_id: Dict[str, EventBase],
+        relations_by_id: Dict[str, str],
+        user_id: str,
+        ignored_users: FrozenSet[str],
     ) -> Dict[str, _ThreadAggregation]:
         """Get the bundled aggregations for threads for the requested events.
 
         Args:
-            event_ids: Events to get aggregations for threads.
+            events_by_id: A map of event_id to events to get aggregations for threads.
+            relations_by_id: A map of event_id to the relation type, if one exists
+                for that event.
             user_id: The user requesting the bundled aggregations.
             ignored_users: The users ignored by the requesting user.
 
@@ -273,16 +270,34 @@
         """
         user = UserID.from_string(user_id)
 
+        # It is not valid to start a thread on an event which itself relates to another event.
+        event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
+
         # Fetch thread summaries.
         summaries = await self._main_store.get_thread_summaries(event_ids)
 
-        # Only fetch participated for a limited selection based on what had
-        # summaries.
+        # Limit fetching whether the requester has participated in a thread to
+        # events which are thread roots.
         thread_event_ids = [
             event_id for event_id, summary in summaries.items() if summary
         ]
-        participated = await self._main_store.get_threads_participated(
-            thread_event_ids, user_id
+
+        # Pre-seed thread participation with whether the requester sent the event.
+        participated = {
+            event_id: events_by_id[event_id].sender == user_id
+            for event_id in thread_event_ids
+        }
+        # For events the requester did not send, check the database for whether
+        # the requester sent a threaded reply.
+        participated.update(
+            await self._main_store.get_threads_participated(
+                [
+                    event_id
+                    for event_id in thread_event_ids
+                    if not participated[event_id]
+                ],
+                user_id,
+            )
         )
 
         # Then subtract off the results for any ignored users.
@@ -343,7 +358,8 @@
                 count=thread_count,
                 # If there's a thread summary it must also exist in the
                 # participated dictionary.
-                current_user_participated=participated[event_id],
+                current_user_participated=events_by_id[event_id].sender == user_id
+                or participated[event_id],
             )
 
         return results
@@ -401,9 +417,9 @@
         # events to be fetched. Thus, we check those first!
 
         # Fetch thread summaries (but only for the directly requested events).
-        threads = await self.get_threads_for_events(
-            # It is not valid to start a thread on an event which itself relates to another event.
-            [eid for eid in events_by_id.keys() if eid not in relations_by_id],
+        threads = await self._get_threads_for_events(
+            events_by_id,
+            relations_by_id,
             user_id,
             ignored_users,
         )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index bc9cc51b92..62e4db23ef 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         relation_type: str,
         assertion_callable: Callable[[JsonDict], None],
         expected_db_txn_for_event: int,
+        access_token: Optional[str] = None,
     ) -> None:
         """
         Makes requests to various endpoints which should include bundled aggregations
@@ -907,7 +908,9 @@
                 for relation-specific assertions.
             expected_db_txn_for_event: The number of database transactions which
                 are expected for a call to /event/.
+            access_token: The access token to use, defaults to self.user_token.
""" + access_token = access_token or self.user_token def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" @@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) @@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) @@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) @@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # Request sync. filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token + "GET", f"/sync?filter={filter}", access_token=access_token ) self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] @@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "/search", # Search term matches the parent message. content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) chunk = [ @@ -1037,30 +1040,60 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): """ Test that threads get correctly bundled. """ - self._send_relation(RelationTypes.THREAD, "m.room.test") - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + # The root message is from "user", send replies as "user2". + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) thread_2 = channel.json_body["event_id"] - def assert_thread(bundled_aggregations: JsonDict) -> None: - self.assertEqual(2, bundled_aggregations.get("count")) - self.assertTrue(bundled_aggregations.get("current_user_participated")) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } + # This needs two assertion functions which are identical except for whether + # the current_user_participated flag is True, create a factory for the + # two versions. + def _gen_assert(participated: bool) -> Callable[[JsonDict], None]: + def assert_thread(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertEqual( + participated, bundled_aggregations.get("current_user_participated") + ) + # The latest thread event has some fields that don't matter. 
+                self.assert_dict(
+                    {
+                        "content": {
+                            "m.relates_to": {
+                                "event_id": self.parent_id,
+                                "rel_type": RelationTypes.THREAD,
+                            }
+                        },
+                        "event_id": thread_2,
+                        "sender": self.user2_id,
+                        "type": "m.room.test",
                     },
-                    "event_id": thread_2,
-                    "sender": self.user_id,
-                    "type": "m.room.test",
-                },
-                bundled_aggregations.get("latest_event"),
-            )
+                    bundled_aggregations.get("latest_event"),
+                )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+            return assert_thread
+
+        # The "user" sent the root event and is making queries for the bundled
+        # aggregations: they have participated.
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+        # The "user2" sent replies in the thread and is making queries for the
+        # bundled aggregations: they have participated.
+        #
+        # Note that this re-uses some cached values, so the total number of
+        # queries is much smaller.
+        self._test_bundled_aggregations(
+            RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+        )
+
+        # A user with no interactions with the thread: they have not participated.
+        user3_id, user3_token = self._create_user("charlie")
+        self.helper.join(self.room, user=user3_id, tok=user3_token)
+        self._test_bundled_aggregations(
+            RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+        )
 
     def test_thread_with_bundled_aggregations_for_latest(self) -> None:
         """
@@ -1106,7 +1139,7 @@
                 bundled_aggregations["latest_event"].get("unsigned"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
 
     def test_nested_thread(self) -> None:
         """
 
From 148fe58a247d61ffb76c566ba397285480d93f74 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 6 Jun 2022 07:46:04 -0400
Subject: [PATCH 74/74] Do not break URL previews if an image is unreachable.
 (#12950)

Avoid breaking a URL preview completely if the chosen image 404s or is
unreachable for some other reason (e.g. DNS).
---
 changelog.d/12950.bugfix                      |  1 +
 synapse/rest/media/v1/preview_url_resource.py | 23 ++++++++----
 tests/rest/media/v1/test_url_preview.py       | 35 +++++++++++++++++++
 3 files changed, 53 insertions(+), 6 deletions(-)
 create mode 100644 changelog.d/12950.bugfix

diff --git a/changelog.d/12950.bugfix b/changelog.d/12950.bugfix
new file mode 100644
index 0000000000..e835d9aa72
--- /dev/null
+++ b/changelog.d/12950.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where a URL preview would break if the image failed to download.
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 2b2db63bf7..54a849eac9 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -586,12 +586,16 @@ class PreviewUrlResource(DirectServeJsonResource):
             og: The Open Graph dictionary. This is modified with image information.
         """
         # If there's no image or it is blank, there's nothing to do.
-        if "og:image" not in og or not og["og:image"]:
+        if "og:image" not in og:
+            return
+
+        # Remove the raw image URL; this will be replaced with an MXC URL, if successful.
+        image_url = og.pop("og:image")
+        if not image_url:
             return
 
         # The image URL from the HTML might be relative to the previewed page,
         # convert it to an URL which can be requested directly.
-        image_url = og["og:image"]
         url_parts = urlparse(image_url)
         if url_parts.scheme != "data":
             image_url = urljoin(media_info.uri, image_url)
@@ -599,7 +603,16 @@
         # FIXME: it might be cleaner to use the same flow as the main /preview_url
         # request itself and benefit from the same caching etc. But for now we
         # just rely on the caching on the master request to speed things up.
-        image_info = await self._handle_url(image_url, user, allow_data_urls=True)
+        try:
+            image_info = await self._handle_url(image_url, user, allow_data_urls=True)
+        except Exception as e:
+            # Pre-caching the image failed; don't block the entire URL preview.
+            logger.warning(
+                "Pre-caching image failed during URL preview: %s errored with %s",
+                image_url,
+                e,
+            )
+            return
 
         if _is_media(image_info.media_type):
             # TODO: make sure we don't choke on white-on-transparent images
@@ -611,13 +624,11 @@
                 og["og:image:width"] = dims["width"]
                 og["og:image:height"] = dims["height"]
             else:
-                logger.warning("Couldn't get dims for %s", og["og:image"])
+                logger.warning("Couldn't get dims for %s", image_url)
 
             og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
             og["og:image:type"] = image_info.media_type
             og["matrix:image:size"] = image_info.media_length
-        else:
-            del og["og:image"]
 
     async def _handle_oembed_response(
         self, url: str, media_info: MediaInfo, expiration_ms: int
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 3b24d0ace6..2c321f8d04 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -656,6 +656,41 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             server.data,
         )
 
+    def test_nonexistent_image(self) -> None:
+        """If the preview image doesn't exist, ensure some data is returned."""
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+        end_content = (
+            b""""""
+        )
+
+        channel = self.make_request(
+            "GET",
+            "preview_url?url=http://matrix.org",
+            shorthand=False,
+            await_result=False,
+        )
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            (
+                b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+            )
+            % (len(end_content),)
+            + end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+
+        # The image should not be in the result.
+        self.assertNotIn("og:image", channel.json_body)
+
     def test_data_url(self) -> None:
         """
         Requesting to preview a data URL is not supported.
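The fix in this final patch follows a pattern worth spelling out: optional enrichment (here, the preview image) is fetched inside a try/except, logged on failure, and simply omitted, so the rest of the preview survives a 404 or DNS error. A self-contained sketch of that shape, with hypothetical `fetch_page`/`fetch_image` callables standing in for Synapse's actual media-fetching machinery:

```python
import logging
from typing import Any, Awaitable, Callable, Dict

logger = logging.getLogger(__name__)


async def build_preview(
    url: str,
    fetch_page: Callable[[str], Awaitable[Dict[str, Any]]],
    fetch_image: Callable[[str], Awaitable[str]],
) -> Dict[str, Any]:
    """Build an Open Graph preview; the image is strictly best-effort."""
    og = await fetch_page(url)

    # Drop the raw image URL up front; it is only re-added (as an MXC URL
    # in the real resource) if fetching succeeds.
    image_url = og.pop("og:image", None)
    if not image_url:
        return og

    try:
        og["og:image"] = await fetch_image(image_url)
    except Exception as e:
        # A broken image must not break the whole preview.
        logger.warning(
            "Pre-caching image failed during URL preview: %s errored with %s",
            image_url,
            e,
        )

    return og
```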