Revert "Make all `process_replication_rows` methods async (#13304)" (#13312)

This reverts commit 5d4028f217.
pull/13314/head
Erik Johnston 2022-07-18 14:28:14 +01:00 committed by GitHub
parent cf5fa5063d
commit f721f1baba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 25 additions and 40 deletions

View File

@ -1 +0,0 @@
Make all replication row processing methods asynchronous. Contributed by Nick @ Beeper (@fizzadar).

View File

@ -158,7 +158,7 @@ class FollowerTypingHandler:
except Exception: except Exception:
logger.exception("Error pushing typing notif to remotes") logger.exception("Error pushing typing notif to remotes")
async def process_replication_rows( def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow] self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None: ) -> None:
"""Should be called whenever we receive updates for typing stream.""" """Should be called whenever we receive updates for typing stream."""
@ -444,7 +444,7 @@ class TypingWriterHandler(FollowerTypingHandler):
return rows, current_id, limited return rows, current_id, limited
async def process_replication_rows( def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow] self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None: ) -> None:
# The writing process should never get updates from replication. # The writing process should never get updates from replication.

View File

@ -49,7 +49,7 @@ class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
def get_device_stream_token(self) -> int: def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token() return self._device_list_id_gen.get_current_token()
async def process_replication_rows( def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
@ -59,9 +59,7 @@ class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
self._device_list_id_gen.advance(instance_name, token) self._device_list_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token) self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return await super().process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, instance_name, token, rows
)
def _invalidate_caches_for_devices( def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]

View File

@ -24,7 +24,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self) -> int: def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
async def process_replication_rows( def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == PushRulesStream.NAME: if stream_name == PushRulesStream.NAME:
@ -33,6 +33,4 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token) self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return await super().process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, instance_name, token, rows
)

View File

@ -40,11 +40,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def get_pushers_stream_token(self) -> int: def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
async def process_replication_rows( def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == PushersStream.NAME: if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token) self._pushers_id_gen.advance(instance_name, token)
return await super().process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, instance_name, token, rows
)

View File

@ -144,15 +144,13 @@ class ReplicationDataHandler:
token: stream token for this batch of rows token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
""" """
await self.store.process_replication_rows( self.store.process_replication_rows(stream_name, instance_name, token, rows)
stream_name, instance_name, token, rows
)
if self.send_handler: if self.send_handler:
await self.send_handler.process_replication_rows(stream_name, token, rows) await self.send_handler.process_replication_rows(stream_name, token, rows)
if stream_name == TypingStream.NAME: if stream_name == TypingStream.NAME:
await self._typing_handler.process_replication_rows(token, rows) self._typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event( self.notifier.on_new_event(
StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows] StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows]
) )

View File

@ -47,7 +47,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self.database_engine = database.engine self.database_engine = database.engine
self.db_pool = database self.db_pool = database
async def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,

View File

@ -414,7 +414,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) )
) )
async def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -437,7 +437,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) )
self._account_data_stream_cache.entity_has_changed(row.user_id, token) self._account_data_stream_cache.entity_has_changed(row.user_id, token)
await super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room( async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict self, user_id: str, room_id: str, account_data_type: str, content: JsonDict

View File

@ -119,7 +119,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn "get_all_updated_caches", get_all_updated_caches_txn
) )
async def process_replication_rows( def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == EventsStream.NAME: if stream_name == EventsStream.NAME:
@ -154,7 +154,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else: else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys) self._attempt_to_invalidate_cache(row.cache_func, row.keys)
await super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data data = row.data

View File

@ -128,7 +128,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
prefilled_cache=device_outbox_prefill, prefilled_cache=device_outbox_prefill,
) )
async def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -148,9 +148,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._device_federation_outbox_stream_cache.entity_has_changed( self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token row.entity, token
) )
return await super().process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, instance_name, token, rows
)
def get_to_device_stream_token(self) -> int: def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()

View File

@ -280,7 +280,7 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id", id_column="chain_id",
) )
async def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -292,7 +292,7 @@ class EventsWorkerStore(SQLBaseStore):
elif stream_name == BackfillStream.NAME: elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token) self._backfill_id_gen.advance(instance_name, -token)
await super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
async def have_censored_event(self, event_id: str) -> bool: async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased """Check if an event has been censored, i.e. if the content of the event has been erased

View File

@ -431,7 +431,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._presence_on_startup = [] self._presence_on_startup = []
return active_on_startup return active_on_startup
async def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -443,6 +443,4 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
for row in rows: for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token) self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,)) self._get_presence_for_user.invalidate((row.user_id,))
return await super().process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, instance_name, token, rows
)

View File

@ -589,7 +589,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_unread_event_push_actions_by_room_for_user", (room_id,) "get_unread_event_push_actions_by_room_for_user", (room_id,)
) )
async def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -604,9 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
self._receipts_stream_cache.entity_has_changed(row.room_id, token) self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return await super().process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, instance_name, token, rows
)
def _insert_linearized_receipt_txn( def _insert_linearized_receipt_txn(
self, self,

View File

@ -292,7 +292,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
# than the id that the client has. # than the id that the client has.
pass pass
async def process_replication_rows( def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -305,7 +305,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
self.get_tags_for_user.invalidate((row.user_id,)) self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token) self._account_data_stream_cache.entity_has_changed(row.user_id, token)
await super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
class TagsStore(TagsWorkerStore): class TagsStore(TagsWorkerStore):