Merge pull request #5209 from matrix-org/erikj/reactions_base
Land basic reaction and edit support.pull/5211/head
						commit
						57ba3451b6
					
				|  | @ -0,0 +1 @@ | |||
| Add experimental support for relations (aka reactions and edits). | ||||
|  | @ -119,3 +119,11 @@ class UserTypes(object): | |||
|     """ | ||||
|     SUPPORT = "support" | ||||
|     ALL_USER_TYPES = (SUPPORT,) | ||||
| 
 | ||||
| 
 | ||||
| class RelationTypes(object): | ||||
|     """The types of relations known to this server. | ||||
|     """ | ||||
|     ANNOTATION = "m.annotation" | ||||
|     REPLACES = "m.replaces" | ||||
|     REFERENCES = "m.references" | ||||
|  |  | |||
|  | @ -101,6 +101,11 @@ class ServerConfig(Config): | |||
|             "block_non_admin_invites", False, | ||||
|         ) | ||||
| 
 | ||||
|         # Whether to enable experimental MSC1849 (aka relations) support | ||||
|         self.experimental_msc1849_support_enabled = config.get( | ||||
|             "experimental_msc1849_support_enabled", False, | ||||
|         ) | ||||
| 
 | ||||
|         # Options to control access by tracking MAU | ||||
|         self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) | ||||
|         self.max_mau_value = 0 | ||||
|  |  | |||
|  | @ -21,7 +21,7 @@ from frozendict import frozendict | |||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.api.constants import EventTypes, RelationTypes | ||||
| from synapse.util.async_helpers import yieldable_gather_results | ||||
| 
 | ||||
| from . import EventBase | ||||
|  | @ -324,8 +324,12 @@ class EventClientSerializer(object): | |||
|     """ | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         pass | ||||
|         self.store = hs.get_datastore() | ||||
|         self.experimental_msc1849_support_enabled = ( | ||||
|             hs.config.experimental_msc1849_support_enabled | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def serialize_event(self, event, time_now, **kwargs): | ||||
|         """Serializes a single event. | ||||
| 
 | ||||
|  | @ -337,8 +341,52 @@ class EventClientSerializer(object): | |||
|         Returns: | ||||
|             Deferred[dict]: The serialized event | ||||
|         """ | ||||
|         event = serialize_event(event, time_now, **kwargs) | ||||
|         return defer.succeed(event) | ||||
|         # To handle the case of presence events and the like | ||||
|         if not isinstance(event, EventBase): | ||||
|             defer.returnValue(event) | ||||
| 
 | ||||
|         event_id = event.event_id | ||||
|         serialized_event = serialize_event(event, time_now, **kwargs) | ||||
| 
 | ||||
|         # If MSC1849 is enabled then we need to look if thre are any relations | ||||
|         # we need to bundle in with the event | ||||
|         if self.experimental_msc1849_support_enabled: | ||||
|             annotations = yield self.store.get_aggregation_groups_for_event( | ||||
|                 event_id, | ||||
|             ) | ||||
|             references = yield self.store.get_relations_for_event( | ||||
|                 event_id, RelationTypes.REFERENCES, direction="f", | ||||
|             ) | ||||
| 
 | ||||
|             if annotations.chunk: | ||||
|                 r = serialized_event["unsigned"].setdefault("m.relations", {}) | ||||
|                 r[RelationTypes.ANNOTATION] = annotations.to_dict() | ||||
| 
 | ||||
|             if references.chunk: | ||||
|                 r = serialized_event["unsigned"].setdefault("m.relations", {}) | ||||
|                 r[RelationTypes.REFERENCES] = references.to_dict() | ||||
| 
 | ||||
|             edit = None | ||||
|             if event.type == EventTypes.Message: | ||||
|                 edit = yield self.store.get_applicable_edit(event_id) | ||||
| 
 | ||||
|             if edit: | ||||
|                 # If there is an edit replace the content, preserving existing | ||||
|                 # relations. | ||||
| 
 | ||||
|                 relations = event.content.get("m.relates_to") | ||||
|                 serialized_event["content"] = edit.content.get("m.new_content", {}) | ||||
|                 if relations: | ||||
|                     serialized_event["content"]["m.relates_to"] = relations | ||||
|                 else: | ||||
|                     serialized_event["content"].pop("m.relates_to", None) | ||||
| 
 | ||||
|                 r = serialized_event["unsigned"].setdefault("m.relations", {}) | ||||
|                 r[RelationTypes.REPLACES] = { | ||||
|                     "event_id": edit.event_id, | ||||
|                 } | ||||
| 
 | ||||
|         defer.returnValue(serialized_event) | ||||
| 
 | ||||
|     def serialize_events(self, events, time_now, **kwargs): | ||||
|         """Serializes multiple events. | ||||
|  |  | |||
|  | @ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( | |||
| from synapse.storage.event_federation import EventFederationWorkerStore | ||||
| from synapse.storage.event_push_actions import EventPushActionsWorkerStore | ||||
| from synapse.storage.events_worker import EventsWorkerStore | ||||
| from synapse.storage.relations import RelationsWorkerStore | ||||
| from synapse.storage.roommember import RoomMemberWorkerStore | ||||
| from synapse.storage.signatures import SignatureWorkerStore | ||||
| from synapse.storage.state import StateGroupWorkerStore | ||||
|  | @ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore, | |||
|                        EventsWorkerStore, | ||||
|                        SignatureWorkerStore, | ||||
|                        UserErasureWorkerStore, | ||||
|                        RelationsWorkerStore, | ||||
|                        BaseSlavedStore): | ||||
| 
 | ||||
|     def __init__(self, db_conn, hs): | ||||
|  | @ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore, | |||
|             for row in rows: | ||||
|                 self.invalidate_caches_for_event( | ||||
|                     -token, row.event_id, row.room_id, row.type, row.state_key, | ||||
|                     row.redacts, | ||||
|                     row.redacts, row.relates_to, | ||||
|                     backfilled=True, | ||||
|                 ) | ||||
|         return super(SlavedEventStore, self).process_replication_rows( | ||||
|  | @ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore, | |||
|         if row.type == EventsStreamEventRow.TypeId: | ||||
|             self.invalidate_caches_for_event( | ||||
|                 token, data.event_id, data.room_id, data.type, data.state_key, | ||||
|                 data.redacts, | ||||
|                 data.redacts, data.relates_to, | ||||
|                 backfilled=False, | ||||
|             ) | ||||
|         elif row.type == EventsStreamCurrentStateRow.TypeId: | ||||
|  | @ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore, | |||
|             raise Exception("Unknown events stream row type %s" % (row.type, )) | ||||
| 
 | ||||
|     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, | ||||
|                                     etype, state_key, redacts, backfilled): | ||||
|                                     etype, state_key, redacts, relates_to, | ||||
|                                     backfilled): | ||||
|         self._invalidate_get_event_cache(event_id) | ||||
| 
 | ||||
|         self.get_latest_event_ids_in_room.invalidate((room_id,)) | ||||
|  | @ -136,3 +139,8 @@ class SlavedEventStore(EventFederationWorkerStore, | |||
|                 state_key, stream_ordering | ||||
|             ) | ||||
|             self.get_invited_rooms_for_user.invalidate((state_key,)) | ||||
| 
 | ||||
|         if relates_to: | ||||
|             self.get_relations_for_event.invalidate_many((relates_to,)) | ||||
|             self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) | ||||
|             self.get_applicable_edit.invalidate((relates_to,)) | ||||
|  |  | |||
|  | @ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", ( | |||
|     "type",  # str | ||||
|     "state_key",  # str, optional | ||||
|     "redacts",  # str, optional | ||||
|     "relates_to",  # str, optional | ||||
| )) | ||||
| PresenceStreamRow = namedtuple("PresenceStreamRow", ( | ||||
|     "user_id",  # str | ||||
|  |  | |||
|  | @ -80,11 +80,12 @@ class BaseEventsStreamRow(object): | |||
| class EventsStreamEventRow(BaseEventsStreamRow): | ||||
|     TypeId = "ev" | ||||
| 
 | ||||
|     event_id = attr.ib()   # str | ||||
|     room_id = attr.ib()    # str | ||||
|     type = attr.ib()       # str | ||||
|     state_key = attr.ib()  # str, optional | ||||
|     redacts = attr.ib()    # str, optional | ||||
|     event_id = attr.ib()    # str | ||||
|     room_id = attr.ib()     # str | ||||
|     type = attr.ib()        # str | ||||
|     state_key = attr.ib()   # str, optional | ||||
|     redacts = attr.ib()     # str, optional | ||||
|     relates_to = attr.ib()  # str, optional | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(slots=True, frozen=True) | ||||
|  |  | |||
|  | @ -44,6 +44,7 @@ from synapse.rest.client.v2_alpha import ( | |||
|     read_marker, | ||||
|     receipts, | ||||
|     register, | ||||
|     relations, | ||||
|     report_event, | ||||
|     room_keys, | ||||
|     room_upgrade_rest_servlet, | ||||
|  | @ -115,6 +116,7 @@ class ClientRestResource(JsonResource): | |||
|         room_upgrade_rest_servlet.register_servlets(hs, client_resource) | ||||
|         capabilities.register_servlets(hs, client_resource) | ||||
|         account_validity.register_servlets(hs, client_resource) | ||||
|         relations.register_servlets(hs, client_resource) | ||||
| 
 | ||||
|         # moving to /_synapse/admin | ||||
|         synapse.rest.admin.register_servlets_for_client_rest_resource( | ||||
|  |  | |||
|  | @ -0,0 +1,338 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2019 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| """This class implements the proposed relation APIs from MSC 1849. | ||||
| 
 | ||||
| Since the MSC has not been approved all APIs here are unstable and may change at | ||||
| any time to reflect changes in the MSC. | ||||
| """ | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes, RelationTypes | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.http.servlet import ( | ||||
|     RestServlet, | ||||
|     parse_integer, | ||||
|     parse_json_object_from_request, | ||||
|     parse_string, | ||||
| ) | ||||
| from synapse.rest.client.transactions import HttpTransactionCache | ||||
| from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken | ||||
| 
 | ||||
| from ._base import client_v2_patterns | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class RelationSendServlet(RestServlet): | ||||
|     """Helper API for sending events that have relation data. | ||||
| 
 | ||||
|     Example API shape to send a 👍 reaction to a room: | ||||
| 
 | ||||
|         POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D | ||||
|         {} | ||||
| 
 | ||||
|         { | ||||
|             "event_id": "$foobar" | ||||
|         } | ||||
|     """ | ||||
| 
 | ||||
|     PATTERN = ( | ||||
|         "/rooms/(?P<room_id>[^/]*)/send_relation" | ||||
|         "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(RelationSendServlet, self).__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.event_creation_handler = hs.get_event_creation_handler() | ||||
|         self.txns = HttpTransactionCache(hs) | ||||
| 
 | ||||
|     def register(self, http_server): | ||||
|         http_server.register_paths( | ||||
|             "POST", | ||||
|             client_v2_patterns(self.PATTERN + "$", releases=()), | ||||
|             self.on_PUT_or_POST, | ||||
|         ) | ||||
|         http_server.register_paths( | ||||
|             "PUT", | ||||
|             client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), | ||||
|             self.on_PUT, | ||||
|         ) | ||||
| 
 | ||||
|     def on_PUT(self, request, *args, **kwargs): | ||||
|         return self.txns.fetch_or_execute_request( | ||||
|             request, self.on_PUT_or_POST, request, *args, **kwargs | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_PUT_or_POST( | ||||
|         self, request, room_id, parent_id, relation_type, event_type, txn_id=None | ||||
|     ): | ||||
|         requester = yield self.auth.get_user_by_req(request, allow_guest=True) | ||||
| 
 | ||||
|         if event_type == EventTypes.Member: | ||||
|             # Add relations to a membership is meaningless, so we just deny it | ||||
|             # at the CS API rather than trying to handle it correctly. | ||||
|             raise SynapseError(400, "Cannot send member events with relations") | ||||
| 
 | ||||
|         content = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         aggregation_key = parse_string(request, "key", encoding="utf-8") | ||||
| 
 | ||||
|         content["m.relates_to"] = { | ||||
|             "event_id": parent_id, | ||||
|             "key": aggregation_key, | ||||
|             "rel_type": relation_type, | ||||
|         } | ||||
| 
 | ||||
|         event_dict = { | ||||
|             "type": event_type, | ||||
|             "content": content, | ||||
|             "room_id": room_id, | ||||
|             "sender": requester.user.to_string(), | ||||
|         } | ||||
| 
 | ||||
|         event = yield self.event_creation_handler.create_and_send_nonmember_event( | ||||
|             requester, event_dict=event_dict, txn_id=txn_id | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, {"event_id": event.event_id})) | ||||
| 
 | ||||
| 
 | ||||
| class RelationPaginationServlet(RestServlet): | ||||
|     """API to paginate relations on an event by topological ordering, optionally | ||||
|     filtered by relation type and event type. | ||||
|     """ | ||||
| 
 | ||||
|     PATTERNS = client_v2_patterns( | ||||
|         "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)" | ||||
|         "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", | ||||
|         releases=(), | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(RelationPaginationServlet, self).__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.store = hs.get_datastore() | ||||
|         self.clock = hs.get_clock() | ||||
|         self._event_serializer = hs.get_event_client_serializer() | ||||
|         self.event_handler = hs.get_event_handler() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): | ||||
|         requester = yield self.auth.get_user_by_req(request, allow_guest=True) | ||||
| 
 | ||||
|         yield self.auth.check_in_room_or_world_readable( | ||||
|             room_id, requester.user.to_string() | ||||
|         ) | ||||
| 
 | ||||
|         # This checks that a) the event exists and b) the user is allowed to | ||||
|         # view it. | ||||
|         yield self.event_handler.get_event(requester.user, room_id, parent_id) | ||||
| 
 | ||||
|         limit = parse_integer(request, "limit", default=5) | ||||
|         from_token = parse_string(request, "from") | ||||
|         to_token = parse_string(request, "to") | ||||
| 
 | ||||
|         if from_token: | ||||
|             from_token = RelationPaginationToken.from_string(from_token) | ||||
| 
 | ||||
|         if to_token: | ||||
|             to_token = RelationPaginationToken.from_string(to_token) | ||||
| 
 | ||||
|         result = yield self.store.get_relations_for_event( | ||||
|             event_id=parent_id, | ||||
|             relation_type=relation_type, | ||||
|             event_type=event_type, | ||||
|             limit=limit, | ||||
|             from_token=from_token, | ||||
|             to_token=to_token, | ||||
|         ) | ||||
| 
 | ||||
|         events = yield self.store.get_events_as_list( | ||||
|             [c["event_id"] for c in result.chunk] | ||||
|         ) | ||||
| 
 | ||||
|         now = self.clock.time_msec() | ||||
|         events = yield self._event_serializer.serialize_events(events, now) | ||||
| 
 | ||||
|         return_value = result.to_dict() | ||||
|         return_value["chunk"] = events | ||||
| 
 | ||||
|         defer.returnValue((200, return_value)) | ||||
| 
 | ||||
| 
 | ||||
| class RelationAggregationPaginationServlet(RestServlet): | ||||
|     """API to paginate aggregation groups of relations, e.g. paginate the | ||||
|     types and counts of the reactions on the events. | ||||
| 
 | ||||
|     Example request and response: | ||||
| 
 | ||||
|         GET /rooms/{room_id}/aggregations/{parent_id} | ||||
| 
 | ||||
|         { | ||||
|             chunk: [ | ||||
|                 { | ||||
|                     "type": "m.reaction", | ||||
|                     "key": "👍", | ||||
|                     "count": 3 | ||||
|                 } | ||||
|             ] | ||||
|         } | ||||
|     """ | ||||
| 
 | ||||
|     PATTERNS = client_v2_patterns( | ||||
|         "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" | ||||
|         "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", | ||||
|         releases=(), | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(RelationAggregationPaginationServlet, self).__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.store = hs.get_datastore() | ||||
|         self.event_handler = hs.get_event_handler() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): | ||||
|         requester = yield self.auth.get_user_by_req(request, allow_guest=True) | ||||
| 
 | ||||
|         yield self.auth.check_in_room_or_world_readable( | ||||
|             room_id, requester.user.to_string() | ||||
|         ) | ||||
| 
 | ||||
|         # This checks that a) the event exists and b) the user is allowed to | ||||
|         # view it. | ||||
|         yield self.event_handler.get_event(requester.user, room_id, parent_id) | ||||
| 
 | ||||
|         if relation_type not in (RelationTypes.ANNOTATION, None): | ||||
|             raise SynapseError(400, "Relation type must be 'annotation'") | ||||
| 
 | ||||
|         limit = parse_integer(request, "limit", default=5) | ||||
|         from_token = parse_string(request, "from") | ||||
|         to_token = parse_string(request, "to") | ||||
| 
 | ||||
|         if from_token: | ||||
|             from_token = AggregationPaginationToken.from_string(from_token) | ||||
| 
 | ||||
|         if to_token: | ||||
|             to_token = AggregationPaginationToken.from_string(to_token) | ||||
| 
 | ||||
|         res = yield self.store.get_aggregation_groups_for_event( | ||||
|             event_id=parent_id, | ||||
|             event_type=event_type, | ||||
|             limit=limit, | ||||
|             from_token=from_token, | ||||
|             to_token=to_token, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, res.to_dict())) | ||||
| 
 | ||||
| 
 | ||||
| class RelationAggregationGroupPaginationServlet(RestServlet): | ||||
|     """API to paginate within an aggregation group of relations, e.g. paginate | ||||
|     all the 👍 reactions on an event. | ||||
| 
 | ||||
|     Example request and response: | ||||
| 
 | ||||
|         GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 | ||||
| 
 | ||||
|         { | ||||
|             chunk: [ | ||||
|                 { | ||||
|                     "type": "m.reaction", | ||||
|                     "content": { | ||||
|                         "m.relates_to": { | ||||
|                             "rel_type": "m.annotation", | ||||
|                             "key": "👍" | ||||
|                         } | ||||
|                     } | ||||
|                 }, | ||||
|                 ... | ||||
|             ] | ||||
|         } | ||||
|     """ | ||||
| 
 | ||||
|     PATTERNS = client_v2_patterns( | ||||
|         "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" | ||||
|         "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$", | ||||
|         releases=(), | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(RelationAggregationGroupPaginationServlet, self).__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.store = hs.get_datastore() | ||||
|         self.clock = hs.get_clock() | ||||
|         self._event_serializer = hs.get_event_client_serializer() | ||||
|         self.event_handler = hs.get_event_handler() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): | ||||
|         requester = yield self.auth.get_user_by_req(request, allow_guest=True) | ||||
| 
 | ||||
|         yield self.auth.check_in_room_or_world_readable( | ||||
|             room_id, requester.user.to_string() | ||||
|         ) | ||||
| 
 | ||||
|         # This checks that a) the event exists and b) the user is allowed to | ||||
|         # view it. | ||||
|         yield self.event_handler.get_event(requester.user, room_id, parent_id) | ||||
| 
 | ||||
|         if relation_type != RelationTypes.ANNOTATION: | ||||
|             raise SynapseError(400, "Relation type must be 'annotation'") | ||||
| 
 | ||||
|         limit = parse_integer(request, "limit", default=5) | ||||
|         from_token = parse_string(request, "from") | ||||
|         to_token = parse_string(request, "to") | ||||
| 
 | ||||
|         if from_token: | ||||
|             from_token = RelationPaginationToken.from_string(from_token) | ||||
| 
 | ||||
|         if to_token: | ||||
|             to_token = RelationPaginationToken.from_string(to_token) | ||||
| 
 | ||||
|         result = yield self.store.get_relations_for_event( | ||||
|             event_id=parent_id, | ||||
|             relation_type=relation_type, | ||||
|             event_type=event_type, | ||||
|             aggregation_key=key, | ||||
|             limit=limit, | ||||
|             from_token=from_token, | ||||
|             to_token=to_token, | ||||
|         ) | ||||
| 
 | ||||
|         events = yield self.store.get_events_as_list( | ||||
|             [c["event_id"] for c in result.chunk] | ||||
|         ) | ||||
| 
 | ||||
|         now = self.clock.time_msec() | ||||
|         events = yield self._event_serializer.serialize_events(events, now) | ||||
| 
 | ||||
|         return_value = result.to_dict() | ||||
|         return_value["chunk"] = events | ||||
| 
 | ||||
|         defer.returnValue((200, return_value)) | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs, http_server): | ||||
|     RelationSendServlet(hs).register(http_server) | ||||
|     RelationPaginationServlet(hs).register(http_server) | ||||
|     RelationAggregationPaginationServlet(hs).register(http_server) | ||||
|     RelationAggregationGroupPaginationServlet(hs).register(http_server) | ||||
|  | @ -49,6 +49,7 @@ from .pusher import PusherStore | |||
| from .receipts import ReceiptsStore | ||||
| from .registration import RegistrationStore | ||||
| from .rejections import RejectionsStore | ||||
| from .relations import RelationsStore | ||||
| from .room import RoomStore | ||||
| from .roommember import RoomMemberStore | ||||
| from .search import SearchStore | ||||
|  | @ -99,6 +100,7 @@ class DataStore( | |||
|     GroupServerStore, | ||||
|     UserErasureStore, | ||||
|     MonthlyActiveUsersStore, | ||||
|     RelationsStore, | ||||
| ): | ||||
|     def __init__(self, db_conn, hs): | ||||
|         self.hs = hs | ||||
|  |  | |||
|  | @ -1325,6 +1325,9 @@ class EventsStore( | |||
|                     txn, event.room_id, event.redacts | ||||
|                 ) | ||||
| 
 | ||||
|                 # Remove from relations table. | ||||
|                 self._handle_redaction(txn, event.redacts) | ||||
| 
 | ||||
|         # Update the event_forward_extremities, event_backward_extremities and | ||||
|         # event_edges tables. | ||||
|         self._handle_mult_prev_events( | ||||
|  | @ -1351,6 +1354,8 @@ class EventsStore( | |||
|                 # Insert into the event_search table. | ||||
|                 self._store_guest_access_txn(txn, event) | ||||
| 
 | ||||
|             self._handle_event_relations(txn, event) | ||||
| 
 | ||||
|         # Insert into the room_memberships table. | ||||
|         self._store_room_members_txn( | ||||
|             txn, | ||||
|  | @ -1655,10 +1660,11 @@ class EventsStore( | |||
|         def get_all_new_forward_event_rows(txn): | ||||
|             sql = ( | ||||
|                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts" | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? < stream_ordering AND stream_ordering <= ?" | ||||
|                 " ORDER BY stream_ordering ASC" | ||||
|                 " LIMIT ?" | ||||
|  | @ -1673,11 +1679,12 @@ class EventsStore( | |||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts" | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " INNER JOIN ex_outlier_stream USING (event_id)" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? < event_stream_ordering" | ||||
|                 " AND event_stream_ordering <= ?" | ||||
|                 " ORDER BY event_stream_ordering DESC" | ||||
|  | @ -1698,10 +1705,11 @@ class EventsStore( | |||
|         def get_all_new_backfill_event_rows(txn): | ||||
|             sql = ( | ||||
|                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts" | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? > stream_ordering AND stream_ordering >= ?" | ||||
|                 " ORDER BY stream_ordering ASC" | ||||
|                 " LIMIT ?" | ||||
|  | @ -1716,11 +1724,12 @@ class EventsStore( | |||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts" | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " INNER JOIN ex_outlier_stream USING (event_id)" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? > event_stream_ordering" | ||||
|                 " AND event_stream_ordering >= ?" | ||||
|                 " ORDER BY event_stream_ordering DESC" | ||||
|  |  | |||
|  | @ -0,0 +1,434 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2019 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| import attr | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.constants import RelationTypes | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.storage._base import SQLBaseStore | ||||
| from synapse.storage.stream import generate_pagination_where_clause | ||||
| from synapse.util.caches.descriptors import cached, cachedInlineCallbacks | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| @attr.s | ||||
| class PaginationChunk(object): | ||||
|     """Returned by relation pagination APIs. | ||||
| 
 | ||||
|     Attributes: | ||||
|         chunk (list): The rows returned by pagination | ||||
|         next_batch (Any|None): Token to fetch next set of results with, if | ||||
|             None then there are no more results. | ||||
|         prev_batch (Any|None): Token to fetch previous set of results with, if | ||||
|             None then there are no previous results. | ||||
|     """ | ||||
| 
 | ||||
|     chunk = attr.ib() | ||||
|     next_batch = attr.ib(default=None) | ||||
|     prev_batch = attr.ib(default=None) | ||||
| 
 | ||||
|     def to_dict(self): | ||||
|         d = {"chunk": self.chunk} | ||||
| 
 | ||||
|         if self.next_batch: | ||||
|             d["next_batch"] = self.next_batch.to_string() | ||||
| 
 | ||||
|         if self.prev_batch: | ||||
|             d["prev_batch"] = self.prev_batch.to_string() | ||||
| 
 | ||||
|         return d | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(frozen=True, slots=True) | ||||
| class RelationPaginationToken(object): | ||||
|     """Pagination token for relation pagination API. | ||||
| 
 | ||||
|     As the results are order by topological ordering, we can use the | ||||
|     `topological_ordering` and `stream_ordering` fields of the events at the | ||||
|     boundaries of the chunk as pagination tokens. | ||||
| 
 | ||||
|     Attributes: | ||||
|         topological (int): The topological ordering of the boundary event | ||||
|         stream (int): The stream ordering of the boundary event. | ||||
|     """ | ||||
| 
 | ||||
|     topological = attr.ib() | ||||
|     stream = attr.ib() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_string(string): | ||||
|         try: | ||||
|             t, s = string.split("-") | ||||
|             return RelationPaginationToken(int(t), int(s)) | ||||
|         except ValueError: | ||||
|             raise SynapseError(400, "Invalid token") | ||||
| 
 | ||||
|     def to_string(self): | ||||
|         return "%d-%d" % (self.topological, self.stream) | ||||
| 
 | ||||
|     def as_tuple(self): | ||||
|         return attr.astuple(self) | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(frozen=True, slots=True) | ||||
| class AggregationPaginationToken(object): | ||||
|     """Pagination token for relation aggregation pagination API. | ||||
| 
 | ||||
|     As the results are order by count and then MAX(stream_ordering) of the | ||||
|     aggregation groups, we can just use them as our pagination token. | ||||
| 
 | ||||
|     Attributes: | ||||
|         count (int): The count of relations in the boundar group. | ||||
|         stream (int): The MAX stream ordering in the boundary group. | ||||
|     """ | ||||
| 
 | ||||
|     count = attr.ib() | ||||
|     stream = attr.ib() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_string(string): | ||||
|         try: | ||||
|             c, s = string.split("-") | ||||
|             return AggregationPaginationToken(int(c), int(s)) | ||||
|         except ValueError: | ||||
|             raise SynapseError(400, "Invalid token") | ||||
| 
 | ||||
|     def to_string(self): | ||||
|         return "%d-%d" % (self.count, self.stream) | ||||
| 
 | ||||
|     def as_tuple(self): | ||||
|         return attr.astuple(self) | ||||
| 
 | ||||
| 
 | ||||
| class RelationsWorkerStore(SQLBaseStore): | ||||
|     @cached(tree=True) | ||||
|     def get_relations_for_event( | ||||
|         self, | ||||
|         event_id, | ||||
|         relation_type=None, | ||||
|         event_type=None, | ||||
|         aggregation_key=None, | ||||
|         limit=5, | ||||
|         direction="b", | ||||
|         from_token=None, | ||||
|         to_token=None, | ||||
|     ): | ||||
|         """Get a list of relations for an event, ordered by topological ordering. | ||||
| 
 | ||||
|         Args: | ||||
|             event_id (str): Fetch events that relate to this event ID. | ||||
|             relation_type (str|None): Only fetch events with this relation | ||||
|                 type, if given. | ||||
|             event_type (str|None): Only fetch events with this event type, if | ||||
|                 given. | ||||
|             aggregation_key (str|None): Only fetch events with this aggregation | ||||
|                 key, if given. | ||||
|             limit (int): Only fetch the most recent `limit` events. | ||||
|             direction (str): Whether to fetch the most recent first (`"b"`) or | ||||
|                 the oldest first (`"f"`). | ||||
|             from_token (RelationPaginationToken|None): Fetch rows from the given | ||||
|                 token, or from the start if None. | ||||
|             to_token (RelationPaginationToken|None): Fetch rows up to the given | ||||
|                 token, or up to the end if None. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[PaginationChunk]: List of event IDs that match relations | ||||
|             requested. The rows are of the form `{"event_id": "..."}`. | ||||
|         """ | ||||
| 
 | ||||
|         where_clause = ["relates_to_id = ?"] | ||||
|         where_args = [event_id] | ||||
| 
 | ||||
|         if relation_type is not None: | ||||
|             where_clause.append("relation_type = ?") | ||||
|             where_args.append(relation_type) | ||||
| 
 | ||||
|         if event_type is not None: | ||||
|             where_clause.append("type = ?") | ||||
|             where_args.append(event_type) | ||||
| 
 | ||||
|         if aggregation_key: | ||||
|             where_clause.append("aggregation_key = ?") | ||||
|             where_args.append(aggregation_key) | ||||
| 
 | ||||
|         pagination_clause = generate_pagination_where_clause( | ||||
|             direction=direction, | ||||
|             column_names=("topological_ordering", "stream_ordering"), | ||||
|             from_token=attr.astuple(from_token) if from_token else None, | ||||
|             to_token=attr.astuple(to_token) if to_token else None, | ||||
|             engine=self.database_engine, | ||||
|         ) | ||||
| 
 | ||||
|         if pagination_clause: | ||||
|             where_clause.append(pagination_clause) | ||||
| 
 | ||||
|         if direction == "b": | ||||
|             order = "DESC" | ||||
|         else: | ||||
|             order = "ASC" | ||||
| 
 | ||||
|         sql = """ | ||||
|             SELECT event_id, topological_ordering, stream_ordering | ||||
|             FROM event_relations | ||||
|             INNER JOIN events USING (event_id) | ||||
|             WHERE %s | ||||
|             ORDER BY topological_ordering %s, stream_ordering %s | ||||
|             LIMIT ? | ||||
|         """ % ( | ||||
|             " AND ".join(where_clause), | ||||
|             order, | ||||
|             order, | ||||
|         ) | ||||
| 
 | ||||
|         def _get_recent_references_for_event_txn(txn): | ||||
|             txn.execute(sql, where_args + [limit + 1]) | ||||
| 
 | ||||
|             last_topo_id = None | ||||
|             last_stream_id = None | ||||
|             events = [] | ||||
|             for row in txn: | ||||
|                 events.append({"event_id": row[0]}) | ||||
|                 last_topo_id = row[1] | ||||
|                 last_stream_id = row[2] | ||||
| 
 | ||||
|             next_batch = None | ||||
|             if len(events) > limit and last_topo_id and last_stream_id: | ||||
|                 next_batch = RelationPaginationToken(last_topo_id, last_stream_id) | ||||
| 
 | ||||
|             return PaginationChunk( | ||||
|                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token | ||||
|             ) | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "get_recent_references_for_event", _get_recent_references_for_event_txn | ||||
|         ) | ||||
| 
 | ||||
|     @cached(tree=True) | ||||
|     def get_aggregation_groups_for_event( | ||||
|         self, | ||||
|         event_id, | ||||
|         event_type=None, | ||||
|         limit=5, | ||||
|         direction="b", | ||||
|         from_token=None, | ||||
|         to_token=None, | ||||
|     ): | ||||
|         """Get a list of annotations on the event, grouped by event type and | ||||
|         aggregation key, sorted by count. | ||||
| 
 | ||||
|         This is used e.g. to get the what and how many reactions have happend | ||||
|         on an event. | ||||
| 
 | ||||
|         Args: | ||||
|             event_id (str): Fetch events that relate to this event ID. | ||||
|             event_type (str|None): Only fetch events with this event type, if | ||||
|                 given. | ||||
|             limit (int): Only fetch the `limit` groups. | ||||
|             direction (str): Whether to fetch the highest count first (`"b"`) or | ||||
|                 the lowest count first (`"f"`). | ||||
|             from_token (AggregationPaginationToken|None): Fetch rows from the | ||||
|                 given token, or from the start if None. | ||||
|             to_token (AggregationPaginationToken|None): Fetch rows up to the | ||||
|                 given token, or up to the end if None. | ||||
| 
 | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[PaginationChunk]: List of groups of annotations that | ||||
|             match. Each row is a dict with `type`, `key` and `count` fields. | ||||
|         """ | ||||
| 
 | ||||
|         where_clause = ["relates_to_id = ?", "relation_type = ?"] | ||||
|         where_args = [event_id, RelationTypes.ANNOTATION] | ||||
| 
 | ||||
|         if event_type: | ||||
|             where_clause.append("type = ?") | ||||
|             where_args.append(event_type) | ||||
| 
 | ||||
|         having_clause = generate_pagination_where_clause( | ||||
|             direction=direction, | ||||
|             column_names=("COUNT(*)", "MAX(stream_ordering)"), | ||||
|             from_token=attr.astuple(from_token) if from_token else None, | ||||
|             to_token=attr.astuple(to_token) if to_token else None, | ||||
|             engine=self.database_engine, | ||||
|         ) | ||||
| 
 | ||||
|         if direction == "b": | ||||
|             order = "DESC" | ||||
|         else: | ||||
|             order = "ASC" | ||||
| 
 | ||||
|         if having_clause: | ||||
|             having_clause = "HAVING " + having_clause | ||||
|         else: | ||||
|             having_clause = "" | ||||
| 
 | ||||
|         sql = """ | ||||
|             SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering) | ||||
|             FROM event_relations | ||||
|             INNER JOIN events USING (event_id) | ||||
|             WHERE {where_clause} | ||||
|             GROUP BY relation_type, type, aggregation_key | ||||
|             {having_clause} | ||||
|             ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} | ||||
|             LIMIT ? | ||||
|         """.format( | ||||
|             where_clause=" AND ".join(where_clause), | ||||
|             order=order, | ||||
|             having_clause=having_clause, | ||||
|         ) | ||||
| 
 | ||||
|         def _get_aggregation_groups_for_event_txn(txn): | ||||
|             txn.execute(sql, where_args + [limit + 1]) | ||||
| 
 | ||||
|             next_batch = None | ||||
|             events = [] | ||||
|             for row in txn: | ||||
|                 events.append({"type": row[0], "key": row[1], "count": row[2]}) | ||||
|                 next_batch = AggregationPaginationToken(row[2], row[3]) | ||||
| 
 | ||||
|             if len(events) <= limit: | ||||
|                 next_batch = None | ||||
| 
 | ||||
|             return PaginationChunk( | ||||
|                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token | ||||
|             ) | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn | ||||
|         ) | ||||
| 
 | ||||
|     @cachedInlineCallbacks() | ||||
|     def get_applicable_edit(self, event_id): | ||||
|         """Get the most recent edit (if any) that has happened for the given | ||||
|         event. | ||||
| 
 | ||||
|         Correctly handles checking whether edits were allowed to happen. | ||||
| 
 | ||||
|         Args: | ||||
|             event_id (str): The original event ID | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[EventBase|None]: Returns the most recent edit, if any. | ||||
|         """ | ||||
| 
 | ||||
|         # We only allow edits for `m.room.message` events that have the same sender | ||||
|         # and event type. We can't assert these things during regular event auth so | ||||
|         # we have to do the checks post hoc. | ||||
| 
 | ||||
|         # Fetches latest edit that has the same type and sender as the | ||||
|         # original, and is an `m.room.message`. | ||||
|         sql = """ | ||||
|             SELECT edit.event_id FROM events AS edit | ||||
|             INNER JOIN event_relations USING (event_id) | ||||
|             INNER JOIN events AS original ON | ||||
|                 original.event_id = relates_to_id | ||||
|                 AND edit.type = original.type | ||||
|                 AND edit.sender = original.sender | ||||
|             WHERE | ||||
|                 relates_to_id = ? | ||||
|                 AND relation_type = ? | ||||
|                 AND edit.type = 'm.room.message' | ||||
|             ORDER by edit.origin_server_ts DESC, edit.event_id DESC | ||||
|             LIMIT 1 | ||||
|         """ | ||||
| 
 | ||||
|         def _get_applicable_edit_txn(txn): | ||||
|             txn.execute( | ||||
|                 sql, (event_id, RelationTypes.REPLACES,) | ||||
|             ) | ||||
|             row = txn.fetchone() | ||||
|             if row: | ||||
|                 return row[0] | ||||
| 
 | ||||
|         edit_id = yield self.runInteraction( | ||||
|             "get_applicable_edit", _get_applicable_edit_txn | ||||
|         ) | ||||
| 
 | ||||
|         if not edit_id: | ||||
|             return | ||||
| 
 | ||||
|         edit_event = yield self.get_event(edit_id, allow_none=True) | ||||
|         defer.returnValue(edit_event) | ||||
| 
 | ||||
| 
 | ||||
| class RelationsStore(RelationsWorkerStore): | ||||
|     def _handle_event_relations(self, txn, event): | ||||
|         """Handles inserting relation data during peristence of events | ||||
| 
 | ||||
|         Args: | ||||
|             txn | ||||
|             event (EventBase) | ||||
|         """ | ||||
|         relation = event.content.get("m.relates_to") | ||||
|         if not relation: | ||||
|             # No relations | ||||
|             return | ||||
| 
 | ||||
|         rel_type = relation.get("rel_type") | ||||
|         if rel_type not in ( | ||||
|             RelationTypes.ANNOTATION, | ||||
|             RelationTypes.REFERENCES, | ||||
|             RelationTypes.REPLACES, | ||||
|         ): | ||||
|             # Unknown relation type | ||||
|             return | ||||
| 
 | ||||
|         parent_id = relation.get("event_id") | ||||
|         if not parent_id: | ||||
|             # Invalid relation | ||||
|             return | ||||
| 
 | ||||
|         aggregation_key = relation.get("key") | ||||
| 
 | ||||
|         self._simple_insert_txn( | ||||
|             txn, | ||||
|             table="event_relations", | ||||
|             values={ | ||||
|                 "event_id": event.event_id, | ||||
|                 "relates_to_id": parent_id, | ||||
|                 "relation_type": rel_type, | ||||
|                 "aggregation_key": aggregation_key, | ||||
|             }, | ||||
|         ) | ||||
| 
 | ||||
|         txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) | ||||
|         txn.call_after( | ||||
|             self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) | ||||
|         ) | ||||
| 
 | ||||
|         if rel_type == RelationTypes.REPLACES: | ||||
|             txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) | ||||
| 
 | ||||
|     def _handle_redaction(self, txn, redacted_event_id): | ||||
|         """Handles receiving a redaction and checking whether we need to remove | ||||
|         any redacted relations from the database. | ||||
| 
 | ||||
|         Args: | ||||
|             txn | ||||
|             redacted_event_id (str): The event that was redacted. | ||||
|         """ | ||||
| 
 | ||||
|         self._simple_delete_txn( | ||||
|             txn, | ||||
|             table="event_relations", | ||||
|             keyvalues={ | ||||
|                 "event_id": redacted_event_id, | ||||
|             } | ||||
|         ) | ||||
|  | @ -0,0 +1,27 @@ | |||
| /* Copyright 2019 New Vector Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *    http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| -- Tracks related events, like reactions, replies, edits, etc. Note that things | ||||
| -- in this table are not necessarily "valid", e.g. it may contain edits from | ||||
| -- people who don't have power to edit other peoples events. | ||||
| CREATE TABLE IF NOT EXISTS event_relations ( | ||||
|     event_id TEXT NOT NULL, | ||||
|     relates_to_id TEXT NOT NULL, | ||||
|     relation_type TEXT NOT NULL, | ||||
|     aggregation_key TEXT | ||||
| ); | ||||
| 
 | ||||
| CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); | ||||
| CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); | ||||
|  | @ -0,0 +1,539 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2019 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import itertools | ||||
| import json | ||||
| 
 | ||||
| import six | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes, RelationTypes | ||||
| from synapse.rest import admin | ||||
| from synapse.rest.client.v1 import login, room | ||||
| from synapse.rest.client.v2_alpha import register, relations | ||||
| 
 | ||||
| from tests import unittest | ||||
| 
 | ||||
| 
 | ||||
| class RelationsTestCase(unittest.HomeserverTestCase): | ||||
|     servlets = [ | ||||
|         relations.register_servlets, | ||||
|         room.register_servlets, | ||||
|         login.register_servlets, | ||||
|         register.register_servlets, | ||||
|         admin.register_servlets_for_client_rest_resource, | ||||
|     ] | ||||
|     hijack_auth = False | ||||
| 
 | ||||
|     def make_homeserver(self, reactor, clock): | ||||
|         # We need to enable msc1849 support for aggregations | ||||
|         config = self.default_config() | ||||
|         config["experimental_msc1849_support_enabled"] = True | ||||
|         return self.setup_test_homeserver(config=config) | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|         self.user_id, self.user_token = self._create_user("alice") | ||||
|         self.user2_id, self.user2_token = self._create_user("bob") | ||||
| 
 | ||||
|         self.room = self.helper.create_room_as(self.user_id, tok=self.user_token) | ||||
|         self.helper.join(self.room, user=self.user2_id, tok=self.user2_token) | ||||
|         res = self.helper.send(self.room, body="Hi!", tok=self.user_token) | ||||
|         self.parent_id = res["event_id"] | ||||
| 
 | ||||
|     def test_send_relation(self): | ||||
|         """Tests that sending a relation using the new /send_relation works | ||||
|         creates the right shape of event. | ||||
|         """ | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         event_id = channel.json_body["event_id"] | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "GET", | ||||
|             "/rooms/%s/event/%s" % (self.room, event_id), | ||||
|             access_token=self.user_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         self.assert_dict( | ||||
|             { | ||||
|                 "type": "m.reaction", | ||||
|                 "sender": self.user_id, | ||||
|                 "content": { | ||||
|                     "m.relates_to": { | ||||
|                         "event_id": self.parent_id, | ||||
|                         "key": u"👍", | ||||
|                         "rel_type": RelationTypes.ANNOTATION, | ||||
|                     } | ||||
|                 }, | ||||
|             }, | ||||
|             channel.json_body, | ||||
|         ) | ||||
| 
 | ||||
|     def test_deny_membership(self): | ||||
|         """Test that we deny relations on membership events | ||||
|         """ | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) | ||||
|         self.assertEquals(400, channel.code, channel.json_body) | ||||
| 
 | ||||
|     def test_basic_paginate_relations(self): | ||||
|         """Tests that calling pagination API corectly the latest relations. | ||||
|         """ | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
|         annotation_id = channel.json_body["event_id"] | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "GET", | ||||
|             "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" | ||||
|             % (self.room, self.parent_id), | ||||
|             access_token=self.user_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         # We expect to get back a single pagination result, which is the full | ||||
|         # relation event we sent above. | ||||
|         self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) | ||||
|         self.assert_dict( | ||||
|             {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, | ||||
|             channel.json_body["chunk"][0], | ||||
|         ) | ||||
| 
 | ||||
|         # Make sure next_batch has something in it that looks like it could be a | ||||
|         # valid token. | ||||
|         self.assertIsInstance( | ||||
|             channel.json_body.get("next_batch"), six.string_types, channel.json_body | ||||
|         ) | ||||
| 
 | ||||
|     def test_repeated_paginate_relations(self): | ||||
|         """Test that if we paginate using a limit and tokens then we get the | ||||
|         expected events. | ||||
|         """ | ||||
| 
 | ||||
|         expected_event_ids = [] | ||||
|         for _ in range(10): | ||||
|             channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") | ||||
|             self.assertEquals(200, channel.code, channel.json_body) | ||||
|             expected_event_ids.append(channel.json_body["event_id"]) | ||||
| 
 | ||||
|         prev_token = None | ||||
|         found_event_ids = [] | ||||
|         for _ in range(20): | ||||
|             from_token = "" | ||||
|             if prev_token: | ||||
|                 from_token = "&from=" + prev_token | ||||
| 
 | ||||
|             request, channel = self.make_request( | ||||
|                 "GET", | ||||
|                 "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" | ||||
|                 % (self.room, self.parent_id, from_token), | ||||
|                 access_token=self.user_token, | ||||
|             ) | ||||
|             self.render(request) | ||||
|             self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|             found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) | ||||
|             next_batch = channel.json_body.get("next_batch") | ||||
| 
 | ||||
|             self.assertNotEquals(prev_token, next_batch) | ||||
|             prev_token = next_batch | ||||
| 
 | ||||
|             if not prev_token: | ||||
|                 break | ||||
| 
 | ||||
|         # We paginated backwards, so reverse | ||||
|         found_event_ids.reverse() | ||||
|         self.assertEquals(found_event_ids, expected_event_ids) | ||||
| 
 | ||||
|     def test_aggregation_pagination_groups(self): | ||||
|         """Test that we can paginate annotation groups correctly. | ||||
|         """ | ||||
| 
 | ||||
|         # We need to create ten separate users to send each reaction. | ||||
|         access_tokens = [self.user_token, self.user2_token] | ||||
|         idx = 0 | ||||
|         while len(access_tokens) < 10: | ||||
|             user_id, token = self._create_user("test" + str(idx)) | ||||
|             idx += 1 | ||||
| 
 | ||||
|             self.helper.join(self.room, user=user_id, tok=token) | ||||
|             access_tokens.append(token) | ||||
| 
 | ||||
|         idx = 0 | ||||
|         sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} | ||||
|         for key in itertools.chain.from_iterable( | ||||
|             itertools.repeat(key, num) for key, num in sent_groups.items() | ||||
|         ): | ||||
|             channel = self._send_relation( | ||||
|                 RelationTypes.ANNOTATION, | ||||
|                 "m.reaction", | ||||
|                 key=key, | ||||
|                 access_token=access_tokens[idx], | ||||
|             ) | ||||
|             self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|             idx += 1 | ||||
|             idx %= len(access_tokens) | ||||
| 
 | ||||
|         prev_token = None | ||||
|         found_groups = {} | ||||
|         for _ in range(20): | ||||
|             from_token = "" | ||||
|             if prev_token: | ||||
|                 from_token = "&from=" + prev_token | ||||
| 
 | ||||
|             request, channel = self.make_request( | ||||
|                 "GET", | ||||
|                 "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" | ||||
|                 % (self.room, self.parent_id, from_token), | ||||
|                 access_token=self.user_token, | ||||
|             ) | ||||
|             self.render(request) | ||||
|             self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|             self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) | ||||
| 
 | ||||
|             for groups in channel.json_body["chunk"]: | ||||
|                 # We only expect reactions | ||||
|                 self.assertEqual(groups["type"], "m.reaction", channel.json_body) | ||||
| 
 | ||||
|                 # We should only see each key once | ||||
|                 self.assertNotIn(groups["key"], found_groups, channel.json_body) | ||||
| 
 | ||||
|                 found_groups[groups["key"]] = groups["count"] | ||||
| 
 | ||||
|             next_batch = channel.json_body.get("next_batch") | ||||
| 
 | ||||
|             self.assertNotEquals(prev_token, next_batch) | ||||
|             prev_token = next_batch | ||||
| 
 | ||||
|             if not prev_token: | ||||
|                 break | ||||
| 
 | ||||
|         self.assertEquals(sent_groups, found_groups) | ||||
| 
 | ||||
|     def test_aggregation_pagination_within_group(self): | ||||
|         """Test that we can paginate within an annotation group. | ||||
|         """ | ||||
| 
 | ||||
|         expected_event_ids = [] | ||||
|         for _ in range(10): | ||||
|             channel = self._send_relation( | ||||
|                 RelationTypes.ANNOTATION, "m.reaction", key=u"👍" | ||||
|             ) | ||||
|             self.assertEquals(200, channel.code, channel.json_body) | ||||
|             expected_event_ids.append(channel.json_body["event_id"]) | ||||
| 
 | ||||
|         # Also send a different type of reaction so that we test we don't see it | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         prev_token = None | ||||
|         found_event_ids = [] | ||||
|         encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8")) | ||||
|         for _ in range(20): | ||||
|             from_token = "" | ||||
|             if prev_token: | ||||
|                 from_token = "&from=" + prev_token | ||||
| 
 | ||||
|             request, channel = self.make_request( | ||||
|                 "GET", | ||||
|                 "/_matrix/client/unstable/rooms/%s" | ||||
|                 "/aggregations/%s/%s/m.reaction/%s?limit=1%s" | ||||
|                 % ( | ||||
|                     self.room, | ||||
|                     self.parent_id, | ||||
|                     RelationTypes.ANNOTATION, | ||||
|                     encoded_key, | ||||
|                     from_token, | ||||
|                 ), | ||||
|                 access_token=self.user_token, | ||||
|             ) | ||||
|             self.render(request) | ||||
|             self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|             self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) | ||||
| 
 | ||||
|             found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) | ||||
| 
 | ||||
|             next_batch = channel.json_body.get("next_batch") | ||||
| 
 | ||||
|             self.assertNotEquals(prev_token, next_batch) | ||||
|             prev_token = next_batch | ||||
| 
 | ||||
|             if not prev_token: | ||||
|                 break | ||||
| 
 | ||||
|         # We paginated backwards, so reverse | ||||
|         found_event_ids.reverse() | ||||
|         self.assertEquals(found_event_ids, expected_event_ids) | ||||
| 
 | ||||
|     def test_aggregation(self): | ||||
|         """Test that annotations get correctly aggregated. | ||||
|         """ | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         channel = self._send_relation( | ||||
|             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token | ||||
|         ) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "GET", | ||||
|             "/_matrix/client/unstable/rooms/%s/aggregations/%s" | ||||
|             % (self.room, self.parent_id), | ||||
|             access_token=self.user_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         self.assertEquals( | ||||
|             channel.json_body, | ||||
|             { | ||||
|                 "chunk": [ | ||||
|                     {"type": "m.reaction", "key": "a", "count": 2}, | ||||
|                     {"type": "m.reaction", "key": "b", "count": 1}, | ||||
|                 ] | ||||
|             }, | ||||
|         ) | ||||
| 
 | ||||
|     def test_aggregation_redactions(self): | ||||
|         """Test that annotations get correctly aggregated after a redaction. | ||||
|         """ | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
|         to_redact_event_id = channel.json_body["event_id"] | ||||
| 
 | ||||
|         channel = self._send_relation( | ||||
|             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token | ||||
|         ) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         # Now lets redact one of the 'a' reactions | ||||
|         request, channel = self.make_request( | ||||
|             "POST", | ||||
|             "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), | ||||
|             access_token=self.user_token, | ||||
|             content={}, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "GET", | ||||
|             "/_matrix/client/unstable/rooms/%s/aggregations/%s" | ||||
|             % (self.room, self.parent_id), | ||||
|             access_token=self.user_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         self.assertEquals( | ||||
|             channel.json_body, | ||||
|             {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, | ||||
|         ) | ||||
| 
 | ||||
|     def test_aggregation_must_be_annotation(self): | ||||
|         """Test that aggregations must be annotations. | ||||
|         """ | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "GET", | ||||
|             "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.replace?limit=1" | ||||
|             % (self.room, self.parent_id), | ||||
|             access_token=self.user_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(400, channel.code, channel.json_body) | ||||
| 
 | ||||
|     def test_aggregation_get_event(self): | ||||
|         """Test that annotations and references get correctly bundled when | ||||
|         getting the parent event. | ||||
|         """ | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         channel = self._send_relation( | ||||
|             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token | ||||
|         ) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
|         reply_1 = channel.json_body["event_id"] | ||||
| 
 | ||||
|         channel = self._send_relation(RelationTypes.REFERENCES, "m.room.test") | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
|         reply_2 = channel.json_body["event_id"] | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "GET", | ||||
|             "/rooms/%s/event/%s" % (self.room, self.parent_id), | ||||
|             access_token=self.user_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         self.assertEquals( | ||||
|             channel.json_body["unsigned"].get("m.relations"), | ||||
|             { | ||||
|                 RelationTypes.ANNOTATION: { | ||||
|                     "chunk": [ | ||||
|                         {"type": "m.reaction", "key": "a", "count": 2}, | ||||
|                         {"type": "m.reaction", "key": "b", "count": 1}, | ||||
|                     ] | ||||
|                 }, | ||||
|                 RelationTypes.REFERENCES: { | ||||
|                     "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
| 
 | ||||
|     def test_edit(self): | ||||
|         """Test that a simple edit works. | ||||
|         """ | ||||
| 
 | ||||
|         new_body = {"msgtype": "m.text", "body": "I've been edited!"} | ||||
|         channel = self._send_relation( | ||||
|             RelationTypes.REPLACES, | ||||
|             "m.room.message", | ||||
|             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, | ||||
|         ) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         edit_event_id = channel.json_body["event_id"] | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "GET", | ||||
|             "/rooms/%s/event/%s" % (self.room, self.parent_id), | ||||
|             access_token=self.user_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         self.assertEquals(channel.json_body["content"], new_body) | ||||
| 
 | ||||
|         self.assertEquals( | ||||
|             channel.json_body["unsigned"].get("m.relations"), | ||||
|             {RelationTypes.REPLACES: {"event_id": edit_event_id}}, | ||||
|         ) | ||||
| 
 | ||||
|     def test_multi_edit(self): | ||||
|         """Test that multiple edits, including attempts by people who | ||||
|         shouldn't be allowed, are correctly handled. | ||||
|         """ | ||||
| 
 | ||||
|         channel = self._send_relation( | ||||
|             RelationTypes.REPLACES, | ||||
|             "m.room.message", | ||||
|             content={ | ||||
|                 "msgtype": "m.text", | ||||
|                 "body": "Wibble", | ||||
|                 "m.new_content": {"msgtype": "m.text", "body": "First edit"}, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         new_body = {"msgtype": "m.text", "body": "I've been edited!"} | ||||
|         channel = self._send_relation( | ||||
|             RelationTypes.REPLACES, | ||||
|             "m.room.message", | ||||
|             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, | ||||
|         ) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         edit_event_id = channel.json_body["event_id"] | ||||
| 
 | ||||
|         channel = self._send_relation( | ||||
|             RelationTypes.REPLACES, | ||||
|             "m.room.message.WRONG_TYPE", | ||||
|             content={ | ||||
|                 "msgtype": "m.text", | ||||
|                 "body": "Wibble", | ||||
|                 "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, | ||||
|             }, | ||||
|         ) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "GET", | ||||
|             "/rooms/%s/event/%s" % (self.room, self.parent_id), | ||||
|             access_token=self.user_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         self.assertEquals(200, channel.code, channel.json_body) | ||||
| 
 | ||||
|         self.assertEquals(channel.json_body["content"], new_body) | ||||
| 
 | ||||
|         self.assertEquals( | ||||
|             channel.json_body["unsigned"].get("m.relations"), | ||||
|             {RelationTypes.REPLACES: {"event_id": edit_event_id}}, | ||||
|         ) | ||||
| 
 | ||||
|     def _send_relation( | ||||
|         self, relation_type, event_type, key=None, content={}, access_token=None | ||||
|     ): | ||||
|         """Helper function to send a relation pointing at `self.parent_id` | ||||
| 
 | ||||
|         Args: | ||||
|             relation_type (str): One of `RelationTypes` | ||||
|             event_type (str): The type of the event to create | ||||
|             key (str|None): The aggregation key used for m.annotation relation | ||||
|                 type. | ||||
|             content(dict|None): The content of the created event. | ||||
|             access_token (str|None): The access token used to send the relation, | ||||
|                 defaults to `self.user_token` | ||||
| 
 | ||||
|         Returns: | ||||
|             FakeChannel | ||||
|         """ | ||||
|         if not access_token: | ||||
|             access_token = self.user_token | ||||
| 
 | ||||
|         query = "" | ||||
|         if key: | ||||
|             query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) | ||||
| 
 | ||||
|         request, channel = self.make_request( | ||||
|             "POST", | ||||
|             "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" | ||||
|             % (self.room, self.parent_id, relation_type, event_type, query), | ||||
|             json.dumps(content).encode("utf-8"), | ||||
|             access_token=access_token, | ||||
|         ) | ||||
|         self.render(request) | ||||
|         return channel | ||||
| 
 | ||||
|     def _create_user(self, localpart): | ||||
|         user_id = self.register_user(localpart, "abc123") | ||||
|         access_token = self.login(localpart, "abc123") | ||||
| 
 | ||||
|         return user_id, access_token | ||||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston