Store arbitrary relations from events. (#11391)

Instead of only known relation types. This also reworks the background
update for thread relations to crawl events and search for any relation
type, not just threaded relations.
pull/11415/head
Patrick Cloke 2021-11-22 12:01:47 -05:00 committed by GitHub
parent d9e9771d6b
commit 3d893b8cf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 193 additions and 45 deletions

View File

@ -0,0 +1 @@
Store and allow querying of arbitrary event relations.

View File

@ -1,6 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -1696,34 +1696,33 @@ class PersistEventsStore:
},
)
def _handle_event_relations(self, txn, event):
"""Handles inserting relation data during peristence of events
def _handle_event_relations(
self, txn: LoggingTransaction, event: EventBase
) -> None:
"""Handles inserting relation data during persistence of events
Args:
txn
event (EventBase)
txn: The current database transaction.
event: The event which might have relations.
"""
relation = event.content.get("m.relates_to")
if not relation:
# No relations
return
# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
if rel_type not in (
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
RelationTypes.THREAD,
):
# Unknown relation type
if not isinstance(rel_type, str):
return
parent_id = relation.get("event_id")
if not parent_id:
# Invalid relation
if not isinstance(parent_id, str):
return
aggregation_key = relation.get("key")
# Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")
self.db_pool.simple_insert_txn(
txn,

View File

@ -1,4 +1,4 @@
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
# The event_thread_relation background update was replaced with the
# event_arbitrary_relations one, which handles any relation to avoid
# needed to potentially crawl the entire events table in the future.
self.db_pool.updates.register_noop_background_update("event_thread_relation")
self.db_pool.updates.register_background_update_handler(
"event_thread_relation", self._event_thread_relation
"event_arbitrary_relations",
self._event_arbitrary_relations,
)
################################################################################
@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store thread relations for existing events."""
async def _event_arbitrary_relations(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update handler which will store previously unknown relations for existing events."""
last_event_id = progress.get("last_event_id", "")
def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
# Fetch events and then filter based on whether the event has a
# relation or not.
txn.execute(
"""
SELECT event_id, json FROM event_json
LEFT JOIN event_relations USING (event_id)
WHERE event_id > ? AND event_relations.event_id IS NULL
WHERE event_id > ?
ORDER BY event_id LIMIT ?
""",
(last_event_id, batch_size),
)
results = list(txn)
missing_thread_relations = []
# (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
for (event_id, event_json_raw) in results:
try:
event_json = db_to_json(event_json_raw)
@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
continue
# If there's no relation (or it is not a thread), skip!
# If there's no relation, skip!
relates_to = event_json["content"].get("m.relates_to")
if not relates_to or not isinstance(relates_to, dict):
continue
if relates_to.get("rel_type") != RelationTypes.THREAD:
# If the relation type or parent event ID is not a string, skip it.
#
# Do not consider relation types that have existed for a long time,
# since they will already be listed in the `event_relations` table.
rel_type = relates_to.get("rel_type")
if not isinstance(rel_type, str) or rel_type in (
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
):
continue
# Get the parent ID.
parent_id = relates_to.get("event_id")
if not isinstance(parent_id, str):
continue
missing_thread_relations.append((event_id, parent_id))
relations_to_insert.append((event_id, parent_id, rel_type))
# Insert the missing data.
self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_relations",
values=[
{
"event_id": event_id,
"relates_to_Id": parent_id,
"relation_type": RelationTypes.THREAD,
}
for event_id, parent_id in missing_thread_relations
],
)
# Insert the missing data, note that we upsert here in case the event
# has already been processed.
if relations_to_insert:
self.db_pool.simple_upsert_many_txn(
txn=txn,
table="event_relations",
key_names=("event_id",),
key_values=[(r[0],) for r in relations_to_insert],
value_names=("relates_to_id", "relation_type"),
value_values=[r[1:] for r in relations_to_insert],
)
# Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
self._invalidate_cache_and_stream(
txn, self.get_relations_for_event, cache_tuple
)
self._invalidate_cache_and_stream(
txn, self.get_aggregation_groups_for_event, cache_tuple
)
self._invalidate_cache_and_stream(
txn, self.get_thread_summary, cache_tuple
)
if results:
latest_event_id = results[-1][0]
self.db_pool.updates._background_update_progress_txn(
txn, "event_thread_relation", {"last_event_id": latest_event_id}
txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
)
return len(results)
num_rows = await self.db_pool.runInteraction(
desc="event_thread_relation", func=_event_thread_relation_txn
desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
)
if not num_rows:
await self.db_pool.updates._end_background_update("event_thread_relation")
await self.db_pool.updates._end_background_update(
"event_arbitrary_relations"
)
return num_rows

View File

@ -15,4 +15,4 @@
-- Check old events for thread relations.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(6502, 'event_thread_relation', '{}');
(6507, 'event_arbitrary_relations', '{}');

View File

@ -1,4 +1,5 @@
# Copyright 2019 New Vector Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return config
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id, self.user_token = self._create_user("alice")
self.user2_id, self.user2_token = self._create_user("bob")
@ -765,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
def test_unknown_relations(self):
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
# We expect to get back a single pagination result, which is the full
# relation event we sent above.
self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
self.assert_dict(
{"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
channel.json_body["chunk"][0],
)
# We also expect to get the original event (the id of which is self.parent_id)
self.assertEquals(
channel.json_body["original_event"]["event_id"], self.parent_id
)
# When bundling the unknown relation is not included.
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])
# But unknown relations can be directly queried.
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
def _send_relation(
self,
relation_type: str,
@ -811,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token = self.login(localpart, "abc123")
return user_id, access_token
def test_background_update(self):
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEquals(200, channel.code, channel.json_body)
annotation_event_id_good = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
self.assertEquals(200, channel.code, channel.json_body)
annotation_event_id_bad = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)
thread_event_id = channel.json_body["event_id"]
# Clean-up the table as if the inserts did not happen during event creation.
self.get_success(
self.store.db_pool.simple_delete_many(
table="event_relations",
column="event_id",
iterable=(annotation_event_id_bad, thread_event_id),
keyvalues={},
desc="RelationsTestCase.test_background_update",
)
)
# Only the "good" annotation should be found.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
[ev["event_id"] for ev in channel.json_body["chunk"]],
[annotation_event_id_good],
)
# Insert and run the background update.
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "event_arbitrary_relations", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
self.wait_for_background_updates()
# The "good" annotation and the thread should be found, but not the "bad"
# annotation.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertCountEqual(
[ev["event_id"] for ev in channel.json_body["chunk"]],
[annotation_event_id_good, thread_event_id],
)

View File

@ -331,7 +331,12 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
"""Block until all background database updates have completed."""
"""
Block until all background database updates have completed.
Note that callers must ensure that's a store property created on the
testcase.
"""
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):