diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 57202a5bda..988577a334 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -27,28 +27,89 @@ logger = logging.getLogger(__name__) class DeviceInboxStore(SQLBaseStore): @defer.inlineCallbacks - def add_messages_to_device_inbox(self, messages_by_user_then_device): - """ + def add_messages_to_device_inbox(self, local_messages_by_user_then_device, + remote_messages_by_destination): + """Used to send messages from this server. + Args: - messages_by_user_and_device(dict): + sender_user_id(str): The ID of the user sending these messages. + local_messages_by_user_and_device(dict): Dictionary of user_id to device_id to message. + remote_messages_by_destination(dict): + Dictionary of destination server_name to the EDU JSON to send. Returns: A deferred stream_id that resolves when the messages have been inserted. """ + def add_messages_to_device_federation_outbox(txn, now_ms, stream_id): + sql = ( + "INSERT INTO device_federation_outbox" + " (destination, stream_id, queued_ts, messages_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for destination, edu in remote_messages_by_destination.items(): + edu_json = ujson.dumps(edu) + rows.append((destination, stream_id, now_ms, edu_json)) + + txn.executemany(sql, rows) + + def add_messages_txn(txn, now_ms, stream_id): + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + add_messages_to_device_federation_outbox(now_ms, stream_id) + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_now_ms() yield self.runInteraction( "add_messages_to_device_inbox", - self._add_messages_to_device_inbox_txn, + add_messages_txn, + now_ms, stream_id, - messages_by_user_then_device, ) defer.returnValue(self._device_inbox_id_gen.get_current_token()) - def _add_messages_to_device_inbox_txn(self, txn, stream_id, - messages_by_user_then_device): + @defer.inlineCallbacks + def add_messages_from_remote_to_device_inbox( + self, origin, message_id, local_messages_by_user_then_device + ): + def add_messages_txn(txn, now_ms, stream_id): + already_inserted = self._simple_select_one_txn( + txn, table="device_federation_inbox", + keyvalues={"origin": origin, "message_id": message_id}, + retcols=("message_id",), + allow_none=True, + ) + if already_inserted is not None: + return + + self._simple_insert_txn( + txn, table="device_federation_inbox", + values={ + "origin": origin, + "message_id": message_id, + "received_ts": now_ms, + }, + ) + + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_now_ms() + yield self.runInteraction( + "add_messages_from_remote_to_device_inbox", + add_messages_txn, + now_ms, + stream_id, + ) + + def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, + messages_by_user_then_device): local_users_and_devices = set() for user_id, messages_by_device in messages_by_user_then_device.items(): devices = messages_by_device.keys() @@ -177,3 +238,67 @@ class DeviceInboxStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() + + @defer.inlineCallbacks + def get_new_device_messages_for_remote_destination( + self, destination, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + destination(str): The name of the remote server. + last_stream_id(int): The last position of the device message stream + that the server sent up to. + current_stream_id(int): The current position of the device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + def get_new_messages_for_remote_destination_txn(txn): + sql = ( + "SELECT stream_id, messages_json FROM device_federation_outbox" + " WHERE destination = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + destination, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn.fetchall(): + stream_pos = row[0] + messages.append(ujson.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_device_messages_for_remote_destination", + get_new_messages_for_remote_destination_txn, + ) + + @defer.inlineCallbacks + def delete_device_messages_for_remote_destination(self, destination, + up_to_stream_id): + """Used to delete messages when the remote destination acknowledges + their receipt. + + Args: + destination(str): The destination server_name + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + def delete_messages_for_remote_destination_txn(txn): + sql = ( + "DELETE FROM device_federation_outbox" + " WHERE destination = ? AND" + " AND stream_id <= ?" + ) + txn.execute(sql, (destination, up_to_stream_id)) + + return self.runInteraction( + "delete_device_messages_for_remote_destination", + delete_messages_for_remote_destination_txn + )