Clean-up some receipts code (#12888)

* Properly marks private methods as private.
* Adds missing docstrings.
* Rework inline methods.
pull/12897/head
Patrick Cloke 2022-05-27 07:44:10 -04:00 committed by GitHub
parent c52abc1cfd
commit 724e11d620
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 42 deletions

1
changelog.d/12888.misc Normal file
View File

@ -0,0 +1 @@
Refactor receipt linearization code.

View File

@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
@ -686,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rx_ts
def _graph_to_linear(
self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
) -> str:
"""
Generate a linearized event from a list of events (i.e. a list of forward
extremities in the room).
This should allow for calculation of the correct read receipt even if
servers have different event ordering.
Args:
txn: The transaction
room_id: The room ID the events are in.
event_ids: The list of event IDs to linearize.
Returns:
The linearized event ID.
"""
# TODO: Make this better.
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
async def insert_receipt(
self,
room_id: str,
@ -712,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
linearized_event_id = event_ids[0]
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
"insert_receipt_conv", self._graph_to_linear, room_id, event_ids
)
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
self._insert_linearized_receipt_txn,
room_id,
receipt_type,
user_id,
@ -761,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
now - event_ts,
)
await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
return stream_id, max_persisted_id
async def insert_graph_receipt(
self,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts
await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
self._insert_graph_receipt_txn,
room_id,
receipt_type,
user_id,
@ -787,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
def insert_graph_receipt_txn(
max_persisted_id = self._receipts_id_gen.get_current_token()
return stream_id, max_persisted_id
def _insert_graph_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,