Convert additional databases to async/await (#8199)
							parent
							
								
									5bf8e5f55b
								
							
						
					
					
						commit
						54f8d73c00
					
				|  | @ -0,0 +1 @@ | |||
| Convert various parts of the codebase to async/await. | ||||
|  | @ -18,7 +18,7 @@ | |||
| import calendar | ||||
| import logging | ||||
| import time | ||||
| from typing import Any, Dict, List, Optional | ||||
| from typing import Any, Dict, List, Optional, Tuple | ||||
| 
 | ||||
| from synapse.api.constants import PresenceState | ||||
| from synapse.config.homeserver import HomeServerConfig | ||||
|  | @ -294,16 +294,16 @@ class DataStore( | |||
| 
 | ||||
|         return [UserPresenceState(**row) for row in rows] | ||||
| 
 | ||||
|     def count_daily_users(self): | ||||
|     async def count_daily_users(self) -> int: | ||||
|         """ | ||||
|         Counts the number of users who used this homeserver in the last 24 hours. | ||||
|         """ | ||||
|         yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "count_daily_users", self._count_users, yesterday | ||||
|         ) | ||||
| 
 | ||||
|     def count_monthly_users(self): | ||||
|     async def count_monthly_users(self) -> int: | ||||
|         """ | ||||
|         Counts the number of users who used this homeserver in the last 30 days. | ||||
|         Note this method is intended for phonehome metrics only and is different | ||||
|  | @ -311,7 +311,7 @@ class DataStore( | |||
|         amongst other things, includes a 3 day grace period before a user counts. | ||||
|         """ | ||||
|         thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "count_monthly_users", self._count_users, thirty_days_ago | ||||
|         ) | ||||
| 
 | ||||
|  | @ -330,15 +330,15 @@ class DataStore( | |||
|         (count,) = txn.fetchone() | ||||
|         return count | ||||
| 
 | ||||
|     def count_r30_users(self): | ||||
|     async def count_r30_users(self) -> Dict[str, int]: | ||||
|         """ | ||||
|         Counts the number of 30 day retained users, defined as:- | ||||
|          * Users who have created their accounts more than 30 days ago | ||||
|          * Where last seen at most 30 days ago | ||||
|          * Where account creation and last_seen are > 30 days apart | ||||
| 
 | ||||
|          Returns counts globaly for a given user as well as breaking | ||||
|          by platform | ||||
|         Returns: | ||||
|              A mapping of counts globally as well as broken out by platform. | ||||
|         """ | ||||
| 
 | ||||
|         def _count_r30_users(txn): | ||||
|  | @ -411,7 +411,7 @@ class DataStore( | |||
| 
 | ||||
|             return results | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("count_r30_users", _count_r30_users) | ||||
|         return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) | ||||
| 
 | ||||
|     def _get_start_of_day(self): | ||||
|         """ | ||||
|  | @ -421,7 +421,7 @@ class DataStore( | |||
|         today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) | ||||
|         return today_start * 1000 | ||||
| 
 | ||||
|     def generate_user_daily_visits(self): | ||||
|     async def generate_user_daily_visits(self) -> None: | ||||
|         """ | ||||
|         Generates daily visit data for use in cohort/ retention analysis | ||||
|         """ | ||||
|  | @ -476,7 +476,7 @@ class DataStore( | |||
|             # frequently | ||||
|             self._last_user_visit_update = now | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "generate_user_daily_visits", _generate_user_daily_visits | ||||
|         ) | ||||
| 
 | ||||
|  | @ -500,22 +500,28 @@ class DataStore( | |||
|             desc="get_users", | ||||
|         ) | ||||
| 
 | ||||
|     def get_users_paginate( | ||||
|         self, start, limit, user_id=None, name=None, guests=True, deactivated=False | ||||
|     ): | ||||
|     async def get_users_paginate( | ||||
|         self, | ||||
|         start: int, | ||||
|         limit: int, | ||||
|         user_id: Optional[str] = None, | ||||
|         name: Optional[str] = None, | ||||
|         guests: bool = True, | ||||
|         deactivated: bool = False, | ||||
|     ) -> Tuple[List[Dict[str, Any]], int]: | ||||
|         """Function to retrieve a paginated list of users from | ||||
|         users list. This will return a json list of users and the | ||||
|         total number of users matching the filter criteria. | ||||
| 
 | ||||
|         Args: | ||||
|             start (int): start number to begin the query from | ||||
|             limit (int): number of rows to retrieve | ||||
|             user_id (string): search for user_id. ignored if name is not None | ||||
|             name (string): search for local part of user_id or display name | ||||
|             guests (bool): whether to in include guest users | ||||
|             deactivated (bool): whether to include deactivated users | ||||
|             start: start number to begin the query from | ||||
|             limit: number of rows to retrieve | ||||
|             user_id: search for user_id. ignored if name is not None | ||||
|             name: search for local part of user_id or display name | ||||
|             guests: whether to in include guest users | ||||
|             deactivated: whether to include deactivated users | ||||
|         Returns: | ||||
|             defer.Deferred: resolves to list[dict[str, Any]], int | ||||
|             A tuple of a list of mappings from user to information and a count of total users. | ||||
|         """ | ||||
| 
 | ||||
|         def get_users_paginate_txn(txn): | ||||
|  | @ -558,7 +564,7 @@ class DataStore( | |||
|             users = self.db_pool.cursor_to_dict(txn) | ||||
|             return users, count | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_users_paginate_txn", get_users_paginate_txn | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -313,9 +313,9 @@ class DeviceWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         return results | ||||
| 
 | ||||
|     def _get_last_device_update_for_remote_user( | ||||
|     async def _get_last_device_update_for_remote_user( | ||||
|         self, destination: str, user_id: str, from_stream_id: int | ||||
|     ): | ||||
|     ) -> int: | ||||
|         def f(txn): | ||||
|             prev_sent_id_sql = """ | ||||
|                 SELECT coalesce(max(stream_id), 0) as stream_id | ||||
|  | @ -326,12 +326,16 @@ class DeviceWorkerStore(SQLBaseStore): | |||
|             rows = txn.fetchall() | ||||
|             return rows[0][0] | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_last_device_update_for_remote_user", f | ||||
|         ) | ||||
| 
 | ||||
|     def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): | ||||
|     async def mark_as_sent_devices_by_remote( | ||||
|         self, destination: str, stream_id: int | ||||
|     ) -> None: | ||||
|         """Mark that updates have successfully been sent to the destination. | ||||
|         """ | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "mark_as_sent_devices_by_remote", | ||||
|             self._mark_as_sent_devices_by_remote_txn, | ||||
|             destination, | ||||
|  | @ -684,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
|             desc="make_remote_user_device_cache_as_stale", | ||||
|         ) | ||||
| 
 | ||||
|     def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): | ||||
|     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: | ||||
|         """Mark that we no longer track device lists for remote user. | ||||
|         """ | ||||
| 
 | ||||
|  | @ -698,7 +702,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
|                 txn, self.get_device_list_last_stream_id_for_remote, (user_id,) | ||||
|             ) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "mark_remote_user_device_list_as_unsubscribed", | ||||
|             _mark_remote_user_device_list_as_unsubscribed_txn, | ||||
|         ) | ||||
|  | @ -959,9 +963,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
|             desc="update_device", | ||||
|         ) | ||||
| 
 | ||||
|     def update_remote_device_list_cache_entry( | ||||
|     async def update_remote_device_list_cache_entry( | ||||
|         self, user_id: str, device_id: str, content: JsonDict, stream_id: int | ||||
|     ): | ||||
|     ) -> None: | ||||
|         """Updates a single device in the cache of a remote user's devicelist. | ||||
| 
 | ||||
|         Note: assumes that we are the only thread that can be updating this user's | ||||
|  | @ -972,11 +976,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
|             device_id: ID of decivice being updated | ||||
|             content: new data on this device | ||||
|             stream_id: the version of the device list | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[None] | ||||
|         """ | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "update_remote_device_list_cache_entry", | ||||
|             self._update_remote_device_list_cache_entry_txn, | ||||
|             user_id, | ||||
|  | @ -1028,9 +1029,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
|             lock=False, | ||||
|         ) | ||||
| 
 | ||||
|     def update_remote_device_list_cache( | ||||
|     async def update_remote_device_list_cache( | ||||
|         self, user_id: str, devices: List[dict], stream_id: int | ||||
|     ): | ||||
|     ) -> None: | ||||
|         """Replace the entire cache of the remote user's devices. | ||||
| 
 | ||||
|         Note: assumes that we are the only thread that can be updating this user's | ||||
|  | @ -1040,11 +1041,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
|             user_id: User to update device list for | ||||
|             devices: list of device objects supplied over federation | ||||
|             stream_id: the version of the device list | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[None] | ||||
|         """ | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "update_remote_device_list_cache", | ||||
|             self._update_remote_device_list_cache_txn, | ||||
|             user_id, | ||||
|  | @ -1054,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
| 
 | ||||
|     def _update_remote_device_list_cache_txn( | ||||
|         self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int | ||||
|     ): | ||||
|     ) -> None: | ||||
|         self.db_pool.simple_delete_txn( | ||||
|             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} | ||||
|         ) | ||||
|  |  | |||
|  | @ -823,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         return event_dict | ||||
| 
 | ||||
|     def _maybe_redact_event_row(self, original_ev, redactions, event_map): | ||||
|     def _maybe_redact_event_row( | ||||
|         self, | ||||
|         original_ev: EventBase, | ||||
|         redactions: Iterable[str], | ||||
|         event_map: Dict[str, EventBase], | ||||
|     ) -> Optional[EventBase]: | ||||
|         """Given an event object and a list of possible redacting event ids, | ||||
|         determine whether to honour any of those redactions and if so return a redacted | ||||
|         event. | ||||
| 
 | ||||
|         Args: | ||||
|              original_ev (EventBase): | ||||
|              redactions (iterable[str]): list of event ids of potential redaction events | ||||
|              event_map (dict[str, EventBase]): other events which have been fetched, in | ||||
|                  which we can look up the redaaction events. Map from event id to event. | ||||
|              original_ev: The original event. | ||||
|              redactions: list of event ids of potential redaction events | ||||
|              event_map: other events which have been fetched, in which we can | ||||
|                 look up the redaaction events. Map from event id to event. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[EventBase|None]: if the event should be redacted, a pruned | ||||
|                 event object. Otherwise, None. | ||||
|             If the event should be redacted, a pruned event object. Otherwise, None. | ||||
|         """ | ||||
|         if original_ev.type == "m.room.create": | ||||
|             # we choose to ignore redactions of m.room.create events. | ||||
|  | @ -946,17 +950,17 @@ class EventsWorkerStore(SQLBaseStore): | |||
|         row = txn.fetchone() | ||||
|         return row[0] if row else 0 | ||||
| 
 | ||||
|     def get_current_state_event_counts(self, room_id): | ||||
|     async def get_current_state_event_counts(self, room_id: str) -> int: | ||||
|         """ | ||||
|         Gets the current number of state events in a room. | ||||
| 
 | ||||
|         Args: | ||||
|             room_id (str) | ||||
|             room_id: The room ID to query. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[int] | ||||
|             The current number of state events. | ||||
|         """ | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_current_state_event_counts", | ||||
|             self._get_current_state_event_counts_txn, | ||||
|             room_id, | ||||
|  | @ -991,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore): | |||
|         """The current maximum token that events have reached""" | ||||
|         return self._stream_id_gen.get_current_token() | ||||
| 
 | ||||
|     def get_all_new_forward_event_rows(self, last_id, current_id, limit): | ||||
|     async def get_all_new_forward_event_rows( | ||||
|         self, last_id: int, current_id: int, limit: int | ||||
|     ) -> List[Tuple]: | ||||
|         """Returns new events, for the Events replication stream | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -999,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
|             current_id: the maximum stream_id to return up to | ||||
|             limit: the maximum number of rows to return | ||||
| 
 | ||||
|         Returns: Deferred[List[Tuple]] | ||||
|         Returns: | ||||
|             a list of events stream rows. Each tuple consists of a stream id as | ||||
|             the first element, followed by fields suitable for casting into an | ||||
|             EventsStreamRow. | ||||
|  | @ -1020,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore): | |||
|             txn.execute(sql, (last_id, current_id, limit)) | ||||
|             return txn.fetchall() | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_all_new_forward_event_rows", get_all_new_forward_event_rows | ||||
|         ) | ||||
| 
 | ||||
|     def get_ex_outlier_stream_rows(self, last_id, current_id): | ||||
|     async def get_ex_outlier_stream_rows( | ||||
|         self, last_id: int, current_id: int | ||||
|     ) -> List[Tuple]: | ||||
|         """Returns de-outliered events, for the Events replication stream | ||||
| 
 | ||||
|         Args: | ||||
|             last_id: the last stream_id from the previous batch. | ||||
|             current_id: the maximum stream_id to return up to | ||||
| 
 | ||||
|         Returns: Deferred[List[Tuple]] | ||||
|         Returns: | ||||
|             a list of events stream rows. Each tuple consists of a stream id as | ||||
|             the first element, followed by fields suitable for casting into an | ||||
|             EventsStreamRow. | ||||
|  | @ -1054,7 +1062,7 @@ class EventsWorkerStore(SQLBaseStore): | |||
|             txn.execute(sql, (last_id, current_id)) | ||||
|             return txn.fetchall() | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn | ||||
|         ) | ||||
| 
 | ||||
|  | @ -1226,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         return (int(res["topological_ordering"]), int(res["stream_ordering"])) | ||||
| 
 | ||||
|     def get_next_event_to_expire(self): | ||||
|     async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: | ||||
|         """Retrieve the entry with the lowest expiry timestamp in the event_expiry | ||||
|         table, or None if there's no more event to expire. | ||||
| 
 | ||||
|         Returns: Deferred[Optional[Tuple[str, int]]] | ||||
|         Returns: | ||||
|             A tuple containing the event ID as its first element and an expiry timestamp | ||||
|             as its second one, if there's at least one row in the event_expiry table. | ||||
|             None otherwise. | ||||
|  | @ -1246,6 +1254,6 @@ class EventsWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|             return txn.fetchone() | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn | ||||
|         ) | ||||
|  |  | |||
|  | @ -14,7 +14,7 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| from typing import Any, Tuple | ||||
| from typing import Any, List, Set, Tuple | ||||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.storage._base import SQLBaseStore | ||||
|  | @ -25,25 +25,24 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| 
 | ||||
| class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): | ||||
|     def purge_history(self, room_id, token, delete_local_events): | ||||
|     async def purge_history( | ||||
|         self, room_id: str, token: str, delete_local_events: bool | ||||
|     ) -> Set[int]: | ||||
|         """Deletes room history before a certain point | ||||
| 
 | ||||
|         Args: | ||||
|             room_id (str): | ||||
| 
 | ||||
|             token (str): A topological token to delete events before | ||||
| 
 | ||||
|             delete_local_events (bool): | ||||
|             room_id: | ||||
|             token: A topological token to delete events before | ||||
|             delete_local_events: | ||||
|                 if True, we will delete local events as well as remote ones | ||||
|                 (instead of just marking them as outliers and deleting their | ||||
|                 state groups). | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[set[int]]: The set of state groups that are referenced by | ||||
|             deleted events. | ||||
|             The set of state groups that are referenced by deleted events. | ||||
|         """ | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "purge_history", | ||||
|             self._purge_history_txn, | ||||
|             room_id, | ||||
|  | @ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): | |||
| 
 | ||||
|         return referenced_state_groups | ||||
| 
 | ||||
|     def purge_room(self, room_id): | ||||
|     async def purge_room(self, room_id: str) -> List[int]: | ||||
|         """Deletes all record of a room | ||||
| 
 | ||||
|         Args: | ||||
|             room_id (str) | ||||
|             room_id | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[List[int]]: The list of state groups to delete. | ||||
|             The list of state groups to delete. | ||||
|         """ | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id) | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "purge_room", self._purge_room_txn, room_id | ||||
|         ) | ||||
| 
 | ||||
|     def _purge_room_txn(self, txn, room_id): | ||||
|         # First we fetch all the state groups that should be deleted, before | ||||
|  |  | |||
|  | @ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
|         } | ||||
|         return results | ||||
| 
 | ||||
|     def get_users_sent_receipts_between(self, last_id: int, current_id: int): | ||||
|     async def get_users_sent_receipts_between( | ||||
|         self, last_id: int, current_id: int | ||||
|     ) -> List[str]: | ||||
|         """Get all users who sent receipts between `last_id` exclusive and | ||||
|         `current_id` inclusive. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[List[str]] | ||||
|             The list of users. | ||||
|         """ | ||||
| 
 | ||||
|         if last_id == current_id: | ||||
|  | @ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|             return [r[0] for r in txn] | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn | ||||
|         ) | ||||
| 
 | ||||
|  | @ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore): | |||
| 
 | ||||
|         return stream_id, max_persisted_id | ||||
| 
 | ||||
|     def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): | ||||
|         return self.db_pool.runInteraction( | ||||
|     async def insert_graph_receipt( | ||||
|         self, room_id, receipt_type, user_id, event_ids, data | ||||
|     ): | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "insert_graph_receipt", | ||||
|             self.insert_graph_receipt_txn, | ||||
|             room_id, | ||||
|  |  | |||
|  | @ -34,38 +34,33 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| class RelationsWorkerStore(SQLBaseStore): | ||||
|     @cached(tree=True) | ||||
|     def get_relations_for_event( | ||||
|     async def get_relations_for_event( | ||||
|         self, | ||||
|         event_id, | ||||
|         relation_type=None, | ||||
|         event_type=None, | ||||
|         aggregation_key=None, | ||||
|         limit=5, | ||||
|         direction="b", | ||||
|         from_token=None, | ||||
|         to_token=None, | ||||
|     ): | ||||
|         event_id: str, | ||||
|         relation_type: Optional[str] = None, | ||||
|         event_type: Optional[str] = None, | ||||
|         aggregation_key: Optional[str] = None, | ||||
|         limit: int = 5, | ||||
|         direction: str = "b", | ||||
|         from_token: Optional[RelationPaginationToken] = None, | ||||
|         to_token: Optional[RelationPaginationToken] = None, | ||||
|     ) -> PaginationChunk: | ||||
|         """Get a list of relations for an event, ordered by topological ordering. | ||||
| 
 | ||||
|         Args: | ||||
|             event_id (str): Fetch events that relate to this event ID. | ||||
|             relation_type (str|None): Only fetch events with this relation | ||||
|                 type, if given. | ||||
|             event_type (str|None): Only fetch events with this event type, if | ||||
|                 given. | ||||
|             aggregation_key (str|None): Only fetch events with this aggregation | ||||
|                 key, if given. | ||||
|             limit (int): Only fetch the most recent `limit` events. | ||||
|             direction (str): Whether to fetch the most recent first (`"b"`) or | ||||
|                 the oldest first (`"f"`). | ||||
|             from_token (RelationPaginationToken|None): Fetch rows from the given | ||||
|                 token, or from the start if None. | ||||
|             to_token (RelationPaginationToken|None): Fetch rows up to the given | ||||
|                 token, or up to the end if None. | ||||
|             event_id: Fetch events that relate to this event ID. | ||||
|             relation_type: Only fetch events with this relation type, if given. | ||||
|             event_type: Only fetch events with this event type, if given. | ||||
|             aggregation_key: Only fetch events with this aggregation key, if given. | ||||
|             limit: Only fetch the most recent `limit` events. | ||||
|             direction: Whether to fetch the most recent first (`"b"`) or the | ||||
|                 oldest first (`"f"`). | ||||
|             from_token: Fetch rows from the given token, or from the start if None. | ||||
|             to_token: Fetch rows up to the given token, or up to the end if None. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[PaginationChunk]: List of event IDs that match relations | ||||
|             requested. The rows are of the form `{"event_id": "..."}`. | ||||
|             List of event IDs that match relations requested. The rows are of | ||||
|             the form `{"event_id": "..."}`. | ||||
|         """ | ||||
| 
 | ||||
|         where_clause = ["relates_to_id = ?"] | ||||
|  | @ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore): | |||
|                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token | ||||
|             ) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_recent_references_for_event", _get_recent_references_for_event_txn | ||||
|         ) | ||||
| 
 | ||||
|     @cached(tree=True) | ||||
|     def get_aggregation_groups_for_event( | ||||
|     async def get_aggregation_groups_for_event( | ||||
|         self, | ||||
|         event_id, | ||||
|         event_type=None, | ||||
|         limit=5, | ||||
|         direction="b", | ||||
|         from_token=None, | ||||
|         to_token=None, | ||||
|     ): | ||||
|         event_id: str, | ||||
|         event_type: Optional[str] = None, | ||||
|         limit: int = 5, | ||||
|         direction: str = "b", | ||||
|         from_token: Optional[AggregationPaginationToken] = None, | ||||
|         to_token: Optional[AggregationPaginationToken] = None, | ||||
|     ) -> PaginationChunk: | ||||
|         """Get a list of annotations on the event, grouped by event type and | ||||
|         aggregation key, sorted by count. | ||||
| 
 | ||||
|  | @ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore): | |||
|         on an event. | ||||
| 
 | ||||
|         Args: | ||||
|             event_id (str): Fetch events that relate to this event ID. | ||||
|             event_type (str|None): Only fetch events with this event type, if | ||||
|                 given. | ||||
|             limit (int): Only fetch the `limit` groups. | ||||
|             direction (str): Whether to fetch the highest count first (`"b"`) or | ||||
|             event_id: Fetch events that relate to this event ID. | ||||
|             event_type: Only fetch events with this event type, if given. | ||||
|             limit: Only fetch the `limit` groups. | ||||
|             direction: Whether to fetch the highest count first (`"b"`) or | ||||
|                 the lowest count first (`"f"`). | ||||
|             from_token (AggregationPaginationToken|None): Fetch rows from the | ||||
|                 given token, or from the start if None. | ||||
|             to_token (AggregationPaginationToken|None): Fetch rows up to the | ||||
|                 given token, or up to the end if None. | ||||
| 
 | ||||
|             from_token: Fetch rows from the given token, or from the start if None. | ||||
|             to_token: Fetch rows up to the given token, or up to the end if None. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[PaginationChunk]: List of groups of annotations that | ||||
|             match. Each row is a dict with `type`, `key` and `count` fields. | ||||
|             List of groups of annotations that match. Each row is a dict with | ||||
|             `type`, `key` and `count` fields. | ||||
|         """ | ||||
| 
 | ||||
|         where_clause = ["relates_to_id = ?", "relation_type = ?"] | ||||
|  | @ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
|                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token | ||||
|             ) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn | ||||
|         ) | ||||
| 
 | ||||
|  | @ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         return await self.get_event(edit_id, allow_none=True) | ||||
| 
 | ||||
|     def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): | ||||
|     async def has_user_annotated_event( | ||||
|         self, parent_id: str, event_type: str, aggregation_key: str, sender: str | ||||
|     ) -> bool: | ||||
|         """Check if a user has already annotated an event with the same key | ||||
|         (e.g. already liked an event). | ||||
| 
 | ||||
|         Args: | ||||
|             parent_id (str): The event being annotated | ||||
|             event_type (str): The event type of the annotation | ||||
|             aggregation_key (str): The aggregation key of the annotation | ||||
|             sender (str): The sender of the annotation | ||||
|             parent_id: The event being annotated | ||||
|             event_type: The event type of the annotation | ||||
|             aggregation_key: The aggregation key of the annotation | ||||
|             sender: The sender of the annotation | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[bool] | ||||
|             True if the event is already annotated. | ||||
|         """ | ||||
| 
 | ||||
|         sql = """ | ||||
|  | @ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|             return bool(txn.fetchone()) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Patrick Cloke
						Patrick Cloke