Make StreamIdGen `get_next` and `get_next_mult` async (#8161)
This is mainly so that `StreamIdGenerator` and `MultiWriterIdGenerator` will have the same interface, allowing them to be used interchangeably.pull/8167/head
parent
74bf8d4d06
commit
2231dffee6
|
@ -0,0 +1 @@
|
|||
Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.
|
|
@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as room_account_data has a unique constraint
|
||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||
# retry if there is a conflict.
|
||||
|
@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as account_data has a unique constraint on
|
||||
# (user_id, account_data_type) so simple_upsert will retry if
|
||||
# there is a conflict.
|
||||
|
|
|
@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
|||
rows.append((destination, stream_id, now_ms, edu_json))
|
||||
txn.executemany(sql, rows)
|
||||
|
||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||
|
@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
|||
txn, stream_id, local_messages_by_user_then_device
|
||||
)
|
||||
|
||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_from_remote_to_device_inbox",
|
||||
|
|
|
@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
THe new stream ID.
|
||||
"""
|
||||
|
||||
with self._device_list_id_gen.get_next() as stream_id:
|
||||
with await self._device_list_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_user_sig_change_to_streams",
|
||||
self._add_user_signature_change_txn,
|
||||
|
@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
if not device_ids:
|
||||
return
|
||||
|
||||
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_change_to_stream",
|
||||
self._add_device_change_to_stream_txn,
|
||||
|
@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
return stream_ids[-1]
|
||||
|
||||
context = get_active_span_text_map()
|
||||
with self._device_list_id_gen.get_next_mult(
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
len(hosts) * len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
|
|
|
@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
)
|
||||
|
||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
|
||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
|
||||
"""Set a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
|
@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
for a master key, 'self_signing' for a self-signing key, or
|
||||
'user_signing' for a user-signing key
|
||||
key (dict): the key data
|
||||
stream_id (int)
|
||||
"""
|
||||
# the 'key' dict will look something like:
|
||||
# {
|
||||
|
@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
)
|
||||
|
||||
# and finally, store the key itself
|
||||
with self._cross_signing_id_gen.get_next() as stream_id:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"e2e_cross_signing_keys",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"e2e_cross_signing_keys",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
||||
)
|
||||
|
||||
def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||
async def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||
"""Set a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
|
@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
key_type (str): the type of cross-signing key to set
|
||||
key (dict): the key data
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"add_e2e_cross_signing_key",
|
||||
self._set_e2e_cross_signing_key_txn,
|
||||
user_id,
|
||||
key_type,
|
||||
key,
|
||||
)
|
||||
|
||||
with await self._cross_signing_id_gen.get_next() as stream_id:
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_e2e_cross_signing_key",
|
||||
self._set_e2e_cross_signing_key_txn,
|
||||
user_id,
|
||||
key_type,
|
||||
key,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
def store_e2e_cross_signing_signatures(self, user_id, signatures):
|
||||
"""Stores cross-signing signatures.
|
||||
|
|
|
@ -153,11 +153,11 @@ class PersistEventsStore:
|
|||
# Note: Multiple instances of this function cannot be in flight at
|
||||
# the same time for the same room.
|
||||
if backfilled:
|
||||
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
else:
|
||||
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
|
||||
|
|
|
@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
|
||||
return next_id
|
||||
|
||||
with self._group_updates_id_gen.get_next() as next_id:
|
||||
with await self._group_updates_id_gen.get_next() as next_id:
|
||||
res = await self.db_pool.runInteraction(
|
||||
"register_user_group_membership",
|
||||
_register_user_group_membership_txn,
|
||||
|
|
|
@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
|
|||
|
||||
class PresenceStore(SQLBaseStore):
|
||||
async def update_presence(self, presence_states):
|
||||
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
|
||||
len(presence_states)
|
||||
)
|
||||
|
||||
|
|
|
@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
if before or after:
|
||||
|
@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
)
|
||||
|
||||
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
data={"actions": actions_json},
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
|
|
@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
|
|||
last_stream_ordering,
|
||||
profile_tag="",
|
||||
) -> None:
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so simple_upsert will retry
|
||||
await self.db_pool.simple_upsert(
|
||||
|
@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
|
|||
},
|
||||
)
|
||||
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_pusher", delete_pusher_txn, stream_id
|
||||
)
|
||||
|
|
|
@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
|||
"insert_receipt_conv", graph_to_linear
|
||||
)
|
||||
|
||||
stream_id_manager = self._receipts_id_gen.get_next()
|
||||
with stream_id_manager as stream_id:
|
||||
with await self._receipts_id_gen.get_next() as stream_id:
|
||||
event_ts = await self.db_pool.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
|
|
|
@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"store_room_txn", store_room_txn, next_id
|
||||
)
|
||||
|
@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public", set_room_is_public_txn, next_id
|
||||
)
|
||||
|
@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn,
|
||||
|
|
|
@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
|
|||
)
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
|
@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
|
|||
txn.execute(sql, (user_id, room_id, tag))
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
|
|
|
@ -80,7 +80,7 @@ class StreamIdGenerator(object):
|
|||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
with stream_id_gen.get_next() as stream_id:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
|
@ -95,10 +95,10 @@ class StreamIdGenerator(object):
|
|||
)
|
||||
self._unfinished_ids = deque() # type: Deque[int]
|
||||
|
||||
def get_next(self):
|
||||
async def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with stream_id_gen.get_next() as stream_id:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
with self._lock:
|
||||
|
@ -117,10 +117,10 @@ class StreamIdGenerator(object):
|
|||
|
||||
return manager()
|
||||
|
||||
def get_next_mult(self, n):
|
||||
async def get_next_mult(self, n):
|
||||
"""
|
||||
Usage:
|
||||
with stream_id_gen.get_next(n) as stream_ids:
|
||||
with await stream_id_gen.get_next(n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
with self._lock:
|
||||
|
|
Loading…
Reference in New Issue