diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 16ba545740..c9ce6f05a3 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -939,7 +939,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): token: str, device_id: Optional[str], valid_until_ms: Optional[int], - ) -> None: + ) -> int: """Adds an access token for the given user. Args: @@ -949,6 +949,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): valid_until_ms: when the token is valid until. None for no expiry. Raises: StoreError if there was a problem adding this. + Returns: + The token ID """ next_id = self._access_tokens_id_gen.get_next() @@ -964,6 +966,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_access_token_to_user", ) + return next_id + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, "access_tokens", {"token": token}, "device_id" diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql index 31e81314b4..82c00dd908 100644 --- a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql +++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql @@ -17,16 +17,23 @@ -- A map of recent events persisted with transaction IDs. Used to deduplicate -- send event requests with the same transaction ID. -- --- Note, transaction IDs are scoped to the user ID/access token that was used to --- make the request. -CREATE TABLE event_txn_id ( +-- Note: transaction IDs are scoped to the room ID/user ID/access token that was +-- used to make the request. +-- +-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the +-- events or access token we don't want to try and de-duplicate the event. +CREATE TABLE IF NOT EXISTS event_txn_id ( event_id TEXT NOT NULL, user_id TEXT NOT NULL, token_id BIGINT NOT NULL, txn_id TEXT NOT NULL, - inserted_ts BIGINT NOT NULL + inserted_ts BIGINT NOT NULL, + FOREIGN KEY (event_id) + REFERENCES events (event_id) ON DELETE CASCADE, + FOREIGN KEY (token_id) + REFERENCES access_tokens (id) ON DELETE CASCADE ); -CREATE UNIQUE INDEX event_txn_id_event_id ON event_txn_id(event_id); -CREATE UNIQUE INDEX event_txn_id_txn_id ON event_txn_id(user_id, token_id, txn_id); -CREATE INDEX event_txn_id_ts ON event_txn_id(inserted_ts); +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id); +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(user_id, token_id, txn_id); +CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts); diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index dbad7dc158..d32de57515 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -44,12 +44,15 @@ class EventCreationTestCase(unittest.HomeserverTestCase): access_token = self.login("tester", "foobar") room_id = self.helper.create_room_as(user_id, tok=access_token) - # We make the IDs up here, which is fine. - token_id = 4957834 - txn_id = "something_suitably_random" + info = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token,) + ) + token_id = info["token_id"] requester = create_requester(user_id, access_token_id=token_id) + txn_id = "something_suitably_random" + def create_duplicate_event(): return self.get_success( handler.create_event( diff --git a/tests/unittest.py b/tests/unittest.py index 5c87f6097e..6c1661c92c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -254,17 +254,24 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: + # We need a valid token ID to satisfy foreign key constraints. + token_id = self.get_success( + self.hs.get_datastore().add_access_token_to_user( + self.helper.auth_user_id, "some_fake_token", None, None, + ) + ) + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, + "token_id": token_id, "is_guest": False, } async def get_user_by_req(request, allow_guest=False, rights="access"): return create_requester( UserID.from_string(self.helper.auth_user_id), - 1, + token_id, False, False, None,