diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index d55733a4cd..3299607910 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -1017,29 +1017,41 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ - with self._device_list_id_gen.get_next() as stream_id: + if not device_ids: + return + + with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: yield self.db.runInteraction( - "add_device_change_to_streams", - self._add_device_change_txn, + "add_device_change_to_stream", + self._add_device_change_to_stream_txn, + user_id, + device_ids, + stream_ids, + ) + + if not hosts: + return stream_ids[-1] + + context = get_active_span_text_map() + with self._device_list_id_gen.get_next_mult( + len(hosts) * len(device_ids) + ) as stream_ids: + yield self.db.runInteraction( + "add_device_outbound_poke_to_stream", + self._add_device_outbound_poke_to_stream_txn, user_id, device_ids, hosts, - stream_id, + stream_ids, + context, ) - return stream_id - def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): - now = self._clock.time_msec() + return stream_ids[-1] + def _add_device_change_to_stream_txn(self, txn, user_id, device_ids, stream_ids): txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_id + self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], ) - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_id, - ) # Delete older entries in the table, as we really only care about # when the latest change happened. @@ -1048,7 +1060,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? """, - [(user_id, device_id, stream_id) for device_id in device_ids], + [(user_id, device_id, stream_ids[0]) for device_id in device_ids], ) self.db.simple_insert_many_txn( @@ -1056,11 +1068,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_stream", values=[ {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} - for device_id in device_ids + for stream_id, device_id in zip(stream_ids, device_ids) ], ) - context = get_active_span_text_map() + def _add_device_outbound_poke_to_stream_txn( + self, txn, user_id, device_ids, hosts, stream_ids, context, + ): + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) + + now = self._clock.time_msec() + next_stream_id = iter(stream_ids) self.db.simple_insert_many_txn( txn, @@ -1068,7 +1091,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=[ { "destination": destination, - "stream_id": stream_id, + "stream_id": next(next_stream_id), "user_id": user_id, "device_id": device_id, "sent": False,