Store arbitrary relations from events. (#11391)

Instead of only known relation types. This also reworks the background
update for thread relations to crawl events and search for any relation
type, not just threaded relations.
pull/11415/head
Patrick Cloke 2021-11-22 12:01:47 -05:00 committed by GitHub
parent d9e9771d6b
commit 3d893b8cf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 193 additions and 45 deletions

View File

@ -0,0 +1 @@
Store and allow querying of arbitrary event relations.

View File

@ -1,6 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd # Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -1696,34 +1696,33 @@ class PersistEventsStore:
}, },
) )
def _handle_event_relations(self, txn, event): def _handle_event_relations(
"""Handles inserting relation data during peristence of events self, txn: LoggingTransaction, event: EventBase
) -> None:
"""Handles inserting relation data during persistence of events
Args: Args:
txn txn: The current database transaction.
event (EventBase) event: The event which might have relations.
""" """
relation = event.content.get("m.relates_to") relation = event.content.get("m.relates_to")
if not relation: if not relation:
# No relations # No relations
return return
# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type") rel_type = relation.get("rel_type")
if rel_type not in ( if not isinstance(rel_type, str):
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
RelationTypes.THREAD,
):
# Unknown relation type
return return
parent_id = relation.get("event_id") parent_id = relation.get("event_id")
if not parent_id: if not isinstance(parent_id, str):
# Invalid relation
return return
aggregation_key = relation.get("key") # Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,

View File

@ -1,4 +1,4 @@
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index, self._purged_chain_cover_index,
) )
# The event_thread_relation background update was replaced with the
# event_arbitrary_relations one, which handles any relation to avoid
# needed to potentially crawl the entire events table in the future.
self.db_pool.updates.register_noop_background_update("event_thread_relation")
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"event_thread_relation", self._event_thread_relation "event_arbitrary_relations",
self._event_arbitrary_relations,
) )
################################################################################ ################################################################################
@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result return result
async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int: async def _event_arbitrary_relations(
"""Background update handler which will store thread relations for existing events.""" self, progress: JsonDict, batch_size: int
) -> int:
"""Background update handler which will store previously unknown relations for existing events."""
last_event_id = progress.get("last_event_id", "") last_event_id = progress.get("last_event_id", "")
def _event_thread_relation_txn(txn: LoggingTransaction) -> int: def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
# Fetch events and then filter based on whether the event has a
# relation or not.
txn.execute( txn.execute(
""" """
SELECT event_id, json FROM event_json SELECT event_id, json FROM event_json
LEFT JOIN event_relations USING (event_id) WHERE event_id > ?
WHERE event_id > ? AND event_relations.event_id IS NULL
ORDER BY event_id LIMIT ? ORDER BY event_id LIMIT ?
""", """,
(last_event_id, batch_size), (last_event_id, batch_size),
) )
results = list(txn) results = list(txn)
missing_thread_relations = [] # (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
for (event_id, event_json_raw) in results: for (event_id, event_json_raw) in results:
try: try:
event_json = db_to_json(event_json_raw) event_json = db_to_json(event_json_raw)
@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
) )
continue continue
# If there's no relation (or it is not a thread), skip! # If there's no relation, skip!
relates_to = event_json["content"].get("m.relates_to") relates_to = event_json["content"].get("m.relates_to")
if not relates_to or not isinstance(relates_to, dict): if not relates_to or not isinstance(relates_to, dict):
continue continue
if relates_to.get("rel_type") != RelationTypes.THREAD:
# If the relation type or parent event ID is not a string, skip it.
#
# Do not consider relation types that have existed for a long time,
# since they will already be listed in the `event_relations` table.
rel_type = relates_to.get("rel_type")
if not isinstance(rel_type, str) or rel_type in (
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
):
continue continue
# Get the parent ID.
parent_id = relates_to.get("event_id") parent_id = relates_to.get("event_id")
if not isinstance(parent_id, str): if not isinstance(parent_id, str):
continue continue
missing_thread_relations.append((event_id, parent_id)) relations_to_insert.append((event_id, parent_id, rel_type))
# Insert the missing data. # Insert the missing data, note that we upsert here in case the event
self.db_pool.simple_insert_many_txn( # has already been processed.
txn=txn, if relations_to_insert:
table="event_relations", self.db_pool.simple_upsert_many_txn(
values=[ txn=txn,
{ table="event_relations",
"event_id": event_id, key_names=("event_id",),
"relates_to_Id": parent_id, key_values=[(r[0],) for r in relations_to_insert],
"relation_type": RelationTypes.THREAD, value_names=("relates_to_id", "relation_type"),
} value_values=[r[1:] for r in relations_to_insert],
for event_id, parent_id in missing_thread_relations )
],
) # Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
self._invalidate_cache_and_stream(
txn, self.get_relations_for_event, cache_tuple
)
self._invalidate_cache_and_stream(
txn, self.get_aggregation_groups_for_event, cache_tuple
)
self._invalidate_cache_and_stream(
txn, self.get_thread_summary, cache_tuple
)
if results: if results:
latest_event_id = results[-1][0] latest_event_id = results[-1][0]
self.db_pool.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "event_thread_relation", {"last_event_id": latest_event_id} txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
) )
return len(results) return len(results)
num_rows = await self.db_pool.runInteraction( num_rows = await self.db_pool.runInteraction(
desc="event_thread_relation", func=_event_thread_relation_txn desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
) )
if not num_rows: if not num_rows:
await self.db_pool.updates._end_background_update("event_thread_relation") await self.db_pool.updates._end_background_update(
"event_arbitrary_relations"
)
return num_rows return num_rows

View File

@ -15,4 +15,4 @@
-- Check old events for thread relations. -- Check old events for thread relations.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(6502, 'event_thread_relation', '{}'); (6507, 'event_arbitrary_relations', '{}');

View File

@ -1,4 +1,5 @@
# Copyright 2019 New Vector Ltd # Copyright 2019 New Vector Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return config return config
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id, self.user_token = self._create_user("alice") self.user_id, self.user_token = self._create_user("alice")
self.user2_id, self.user2_token = self._create_user("bob") self.user2_id, self.user2_token = self._create_user("bob")
@ -765,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body) self.assertIn("chunk", channel.json_body)
self.assertEquals(channel.json_body["chunk"], []) self.assertEquals(channel.json_body["chunk"], [])
def test_unknown_relations(self):
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
# We expect to get back a single pagination result, which is the full
# relation event we sent above.
self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
self.assert_dict(
{"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
channel.json_body["chunk"][0],
)
# We also expect to get the original event (the id of which is self.parent_id)
self.assertEquals(
channel.json_body["original_event"]["event_id"], self.parent_id
)
# When bundling the unknown relation is not included.
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])
# But unknown relations can be directly queried.
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
def _send_relation( def _send_relation(
self, self,
relation_type: str, relation_type: str,
@ -811,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token = self.login(localpart, "abc123") access_token = self.login(localpart, "abc123")
return user_id, access_token return user_id, access_token
def test_background_update(self):
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEquals(200, channel.code, channel.json_body)
annotation_event_id_good = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
self.assertEquals(200, channel.code, channel.json_body)
annotation_event_id_bad = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)
thread_event_id = channel.json_body["event_id"]
# Clean-up the table as if the inserts did not happen during event creation.
self.get_success(
self.store.db_pool.simple_delete_many(
table="event_relations",
column="event_id",
iterable=(annotation_event_id_bad, thread_event_id),
keyvalues={},
desc="RelationsTestCase.test_background_update",
)
)
# Only the "good" annotation should be found.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
[ev["event_id"] for ev in channel.json_body["chunk"]],
[annotation_event_id_good],
)
# Insert and run the background update.
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "event_arbitrary_relations", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
self.wait_for_background_updates()
# The "good" annotation and the thread should be found, but not the "bad"
# annotation.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertCountEqual(
[ev["event_id"] for ev in channel.json_body["chunk"]],
[annotation_event_id_good, thread_event_id],
)

View File

@ -331,7 +331,12 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01) time.sleep(0.01)
def wait_for_background_updates(self) -> None: def wait_for_background_updates(self) -> None:
"""Block until all background database updates have completed.""" """
Block until all background database updates have completed.
Note that callers must ensure that's a store property created on the
testcase.
"""
while not self.get_success( while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates() self.store.db_pool.updates.has_completed_background_updates()
): ):