Add foreign key constraint to `event_forward_extremities`. (#15751)
							parent
							
								
									c303eca8cc
								
							
						
					
					
						commit
						95a96b21eb
					
				|  | @ -0,0 +1 @@ | |||
| Add foreign key constraint to `event_forward_extremities`. | ||||
|  | @ -61,6 +61,7 @@ from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpda | |||
| from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore | ||||
| from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore | ||||
| from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore | ||||
| from synapse.storage.databases.main.event_federation import EventFederationWorkerStore | ||||
| from synapse.storage.databases.main.event_push_actions import EventPushActionsStore | ||||
| from synapse.storage.databases.main.events_bg_updates import ( | ||||
|     EventsBackgroundUpdatesStore, | ||||
|  | @ -239,6 +240,7 @@ class Store( | |||
|     PresenceBackgroundUpdateStore, | ||||
|     ReceiptsBackgroundUpdateStore, | ||||
|     RelationsWorkerStore, | ||||
|     EventFederationWorkerStore, | ||||
| ): | ||||
|     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: | ||||
|         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) | ||||
|  |  | |||
|  | @ -11,8 +11,9 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import abc | ||||
| import logging | ||||
| from enum import IntEnum | ||||
| from enum import Enum, IntEnum | ||||
| from types import TracebackType | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|  | @ -24,12 +25,16 @@ from typing import ( | |||
|     Iterable, | ||||
|     List, | ||||
|     Optional, | ||||
|     Sequence, | ||||
|     Tuple, | ||||
|     Type, | ||||
| ) | ||||
| 
 | ||||
| import attr | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
| from synapse.metrics.background_process_metrics import run_as_background_process | ||||
| from synapse.storage.engines import PostgresEngine | ||||
| from synapse.storage.types import Connection, Cursor | ||||
| from synapse.types import JsonDict | ||||
| from synapse.util import Clock, json_encoder | ||||
|  | @ -48,6 +53,78 @@ DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] | |||
| MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] | ||||
| 
 | ||||
| 
 | ||||
| class Constraint(metaclass=abc.ABCMeta): | ||||
|     """Base class representing different constraints. | ||||
| 
 | ||||
|     Used by `register_background_validate_constraint_and_delete_rows`. | ||||
|     """ | ||||
| 
 | ||||
|     @abc.abstractmethod | ||||
|     def make_check_clause(self, table: str) -> str: | ||||
|         """Returns an SQL expression that checks the row passes the constraint.""" | ||||
|         pass | ||||
| 
 | ||||
|     @abc.abstractmethod | ||||
|     def make_constraint_clause_postgres(self) -> str: | ||||
|         """Returns an SQL clause for creating the constraint. | ||||
| 
 | ||||
|         Only used on Postgres DBs | ||||
|         """ | ||||
|         pass | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(auto_attribs=True) | ||||
| class ForeignKeyConstraint(Constraint): | ||||
|     """A foreign key constraint. | ||||
| 
 | ||||
|     Attributes: | ||||
|         referenced_table: The "parent" table name. | ||||
|         columns: The list of mappings of columns from table to referenced table | ||||
|     """ | ||||
| 
 | ||||
|     referenced_table: str | ||||
|     columns: Sequence[Tuple[str, str]] | ||||
| 
 | ||||
|     def make_check_clause(self, table: str) -> str: | ||||
|         join_clause = " AND ".join( | ||||
|             f"{col1} = {table}.{col2}" for col1, col2 in self.columns | ||||
|         ) | ||||
|         return f"EXISTS (SELECT 1 FROM {self.referenced_table} WHERE {join_clause})" | ||||
| 
 | ||||
|     def make_constraint_clause_postgres(self) -> str: | ||||
|         column1_list = ", ".join(col1 for col1, col2 in self.columns) | ||||
|         column2_list = ", ".join(col2 for col1, col2 in self.columns) | ||||
|         return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list})" | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(auto_attribs=True) | ||||
| class NotNullConstraint(Constraint): | ||||
|     """A NOT NULL column constraint""" | ||||
| 
 | ||||
|     column: str | ||||
| 
 | ||||
|     def make_check_clause(self, table: str) -> str: | ||||
|         return f"{self.column} IS NOT NULL" | ||||
| 
 | ||||
|     def make_constraint_clause_postgres(self) -> str: | ||||
|         return f"CHECK ({self.column} IS NOT NULL)" | ||||
| 
 | ||||
| 
 | ||||
| class ValidateConstraintProgress(BaseModel): | ||||
|     """The format of the progress JSON for validate constraint background | ||||
|     updates. | ||||
| 
 | ||||
|     Used by `register_background_validate_constraint_and_delete_rows`. | ||||
|     """ | ||||
| 
 | ||||
|     class State(str, Enum): | ||||
|         check = "check" | ||||
|         validate = "validate" | ||||
| 
 | ||||
|     state: State = State.validate | ||||
|     lower_bound: Sequence[Any] = () | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(slots=True, frozen=True, auto_attribs=True) | ||||
| class _BackgroundUpdateHandler: | ||||
|     """A handler for a given background update. | ||||
|  | @ -740,6 +817,179 @@ class BackgroundUpdater: | |||
|         logger.info("Adding index %s to %s", index_name, table) | ||||
|         await self.db_pool.runWithConnection(runner) | ||||
| 
 | ||||
|     def register_background_validate_constraint_and_delete_rows( | ||||
|         self, | ||||
|         update_name: str, | ||||
|         table: str, | ||||
|         constraint_name: str, | ||||
|         constraint: Constraint, | ||||
|         unique_columns: Sequence[str], | ||||
|     ) -> None: | ||||
|         """Helper for store classes to do a background validate constraint, and | ||||
|         delete rows that do not pass the constraint check. | ||||
| 
 | ||||
|         Note: This deletes rows that don't match the constraint. This may not be | ||||
|         appropriate in all situations, and so the suitability of using this | ||||
|         method should be considered on a case-by-case basis. | ||||
| 
 | ||||
|         This only applies on PostgreSQL. | ||||
| 
 | ||||
|         For SQLite the table gets recreated as part of the schema delta and the | ||||
|         data is copied over synchronously (or whatever the correct way to | ||||
|         describe it as). | ||||
| 
 | ||||
|         Args: | ||||
|             update_name: The name of the background update. | ||||
|             table: The table with the invalid constraint. | ||||
|             constraint_name: The name of the constraint | ||||
|             constraint: A `Constraint` object matching the type of constraint. | ||||
|             unique_columns: A sequence of columns that form a unique constraint | ||||
|               on the table. Used to iterate over the table. | ||||
|         """ | ||||
| 
 | ||||
|         assert isinstance( | ||||
|             self.db_pool.engine, engines.PostgresEngine | ||||
|         ), "validate constraint background update registered for non-Postres database" | ||||
| 
 | ||||
|         async def updater(progress: JsonDict, batch_size: int) -> int: | ||||
|             return await self.validate_constraint_and_delete_in_background( | ||||
|                 update_name=update_name, | ||||
|                 table=table, | ||||
|                 constraint_name=constraint_name, | ||||
|                 constraint=constraint, | ||||
|                 unique_columns=unique_columns, | ||||
|                 progress=progress, | ||||
|                 batch_size=batch_size, | ||||
|             ) | ||||
| 
 | ||||
|         self._background_update_handlers[update_name] = _BackgroundUpdateHandler( | ||||
|             updater, oneshot=True | ||||
|         ) | ||||
| 
 | ||||
|     async def validate_constraint_and_delete_in_background( | ||||
|         self, | ||||
|         update_name: str, | ||||
|         table: str, | ||||
|         constraint_name: str, | ||||
|         constraint: Constraint, | ||||
|         unique_columns: Sequence[str], | ||||
|         progress: JsonDict, | ||||
|         batch_size: int, | ||||
|     ) -> int: | ||||
|         """Validates a table constraint that has been marked as `NOT VALID`, | ||||
|         deleting rows that don't pass the constraint check. | ||||
| 
 | ||||
|         This will delete rows that do not meet the validation check. | ||||
| 
 | ||||
|         update_name: str, | ||||
|         table: str, | ||||
|         constraint_name: str, | ||||
|         constraint: Constraint, | ||||
|         unique_columns: Sequence[str], | ||||
|         """ | ||||
| 
 | ||||
|         # We validate the constraint by: | ||||
|         #   1. Trying to validate the constraint as is. If this succeeds then | ||||
|         #      we're done. | ||||
|         #   2. Otherwise, we manually scan the table to remove rows that don't | ||||
|         #      match the constraint. | ||||
|         #   3. We try re-validating the constraint. | ||||
| 
 | ||||
|         parsed_progress = ValidateConstraintProgress.parse_obj(progress) | ||||
| 
 | ||||
|         if parsed_progress.state == ValidateConstraintProgress.State.check: | ||||
|             return_columns = ", ".join(unique_columns) | ||||
|             order_columns = ", ".join(unique_columns) | ||||
| 
 | ||||
|             where_clause = "" | ||||
|             args: List[Any] = [] | ||||
|             if parsed_progress.lower_bound: | ||||
|                 where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})""" | ||||
|                 args.extend(parsed_progress.lower_bound) | ||||
| 
 | ||||
|             args.append(batch_size) | ||||
| 
 | ||||
|             sql = f""" | ||||
|                 SELECT | ||||
|                     {return_columns}, | ||||
|                     {constraint.make_check_clause(table)} AS check | ||||
|                 FROM {table} | ||||
|                 {where_clause} | ||||
|                 ORDER BY {order_columns} | ||||
|                 LIMIT ? | ||||
|             """ | ||||
| 
 | ||||
|             def validate_constraint_in_background_check( | ||||
|                 txn: "LoggingTransaction", | ||||
|             ) -> None: | ||||
|                 txn.execute(sql, args) | ||||
|                 rows = txn.fetchall() | ||||
| 
 | ||||
|                 new_progress = parsed_progress.copy() | ||||
| 
 | ||||
|                 if not rows: | ||||
|                     new_progress.state = ValidateConstraintProgress.State.validate | ||||
|                     self._background_update_progress_txn( | ||||
|                         txn, update_name, new_progress.dict() | ||||
|                     ) | ||||
|                     return | ||||
| 
 | ||||
|                 new_progress.lower_bound = rows[-1][:-1] | ||||
| 
 | ||||
|                 to_delete = [row[:-1] for row in rows if not row[-1]] | ||||
| 
 | ||||
|                 if to_delete: | ||||
|                     logger.warning( | ||||
|                         "Deleting %d rows that do not pass new constraint", | ||||
|                         len(to_delete), | ||||
|                     ) | ||||
| 
 | ||||
|                     self.db_pool.simple_delete_many_batch_txn( | ||||
|                         txn, table=table, keys=unique_columns, values=to_delete | ||||
|                     ) | ||||
| 
 | ||||
|                 self._background_update_progress_txn( | ||||
|                     txn, update_name, new_progress.dict() | ||||
|                 ) | ||||
| 
 | ||||
|             await self.db_pool.runInteraction( | ||||
|                 "validate_constraint_in_background_check", | ||||
|                 validate_constraint_in_background_check, | ||||
|             ) | ||||
| 
 | ||||
|             return batch_size | ||||
| 
 | ||||
|         elif parsed_progress.state == ValidateConstraintProgress.State.validate: | ||||
|             sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}" | ||||
| 
 | ||||
|             def validate_constraint_in_background_validate( | ||||
|                 txn: "LoggingTransaction", | ||||
|             ) -> None: | ||||
|                 txn.execute(sql) | ||||
| 
 | ||||
|             try: | ||||
|                 await self.db_pool.runInteraction( | ||||
|                     "validate_constraint_in_background_validate", | ||||
|                     validate_constraint_in_background_validate, | ||||
|                 ) | ||||
| 
 | ||||
|                 await self._end_background_update(update_name) | ||||
|             except self.db_pool.engine.module.IntegrityError as e: | ||||
|                 # If we get an integrity error here, then we go back and recheck the table. | ||||
|                 logger.warning("Integrity error when validating constraint: %s", e) | ||||
|                 await self._background_update_progress( | ||||
|                     update_name, | ||||
|                     ValidateConstraintProgress( | ||||
|                         state=ValidateConstraintProgress.State.check | ||||
|                     ).dict(), | ||||
|                 ) | ||||
| 
 | ||||
|             return batch_size | ||||
|         else: | ||||
|             raise Exception( | ||||
|                 f"Unrecognized state '{parsed_progress.state}' when trying to validate_constraint_and_delete_in_background" | ||||
|             ) | ||||
| 
 | ||||
|     async def _end_background_update(self, update_name: str) -> None: | ||||
|         """Removes a completed background update task from the queue. | ||||
| 
 | ||||
|  | @ -795,3 +1045,86 @@ class BackgroundUpdater: | |||
|             keyvalues={"update_name": update_name}, | ||||
|             updatevalues={"progress_json": progress_json}, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def run_validate_constraint_and_delete_rows_schema_delta( | ||||
|     txn: "LoggingTransaction", | ||||
|     ordering: int, | ||||
|     update_name: str, | ||||
|     table: str, | ||||
|     constraint_name: str, | ||||
|     constraint: Constraint, | ||||
|     sqlite_table_name: str, | ||||
|     sqlite_table_schema: str, | ||||
| ) -> None: | ||||
|     """Runs a schema delta to add a constraint to the table. This should be run | ||||
|     in a schema delta file. | ||||
| 
 | ||||
|     For PostgreSQL the constraint is added and validated in the background. | ||||
| 
 | ||||
|     For SQLite the table is recreated and data copied across immediately. This | ||||
|     is done by the caller passing in a script to create the new table. Note that | ||||
|     table indexes and triggers are copied over automatically. | ||||
| 
 | ||||
|     There must be a corresponding call to | ||||
|     `register_background_validate_constraint_and_delete_rows` to register the | ||||
|     background update in one of the data store classes. | ||||
| 
 | ||||
|     Attributes: | ||||
|         txn ordering, update_name: For adding a row to background_updates table. | ||||
|         table: The table to add constraint to. constraint_name: The name of the | ||||
|         new constraint constraint: A `Constraint` object describing the | ||||
|         constraint sqlite_table_name: For SQLite the name of the empty copy of | ||||
|         table sqlite_table_schema: A SQL script for creating the above table. | ||||
|     """ | ||||
| 
 | ||||
|     if isinstance(txn.database_engine, PostgresEngine): | ||||
|         # For postgres we can just add the constraint and mark it as NOT VALID, | ||||
|         # and then insert a background update to go and check the validity in | ||||
|         # the background. | ||||
|         txn.execute( | ||||
|             f""" | ||||
|             ALTER TABLE {table} | ||||
|             ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()} | ||||
|             NOT VALID | ||||
|             """ | ||||
|         ) | ||||
| 
 | ||||
|         txn.execute( | ||||
|             "INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')", | ||||
|             (ordering, update_name), | ||||
|         ) | ||||
|     else: | ||||
|         # For SQLite, we: | ||||
|         #   1. fetch all indexes/triggers/etc related to the table | ||||
|         #   2. create an empty copy of the table | ||||
|         #   3. copy across the rows (that satisfy the check) | ||||
|         #   4. replace the old table with the new able. | ||||
|         #   5. add back all the indexes/triggers/etc | ||||
| 
 | ||||
|         # Fetch the indexes/triggers/etc. Note that `sql` column being null is | ||||
|         # due to indexes being auto created based on the class definition (e.g. | ||||
|         # PRIMARY KEY), and so don't need to be recreated. | ||||
|         txn.execute( | ||||
|             """ | ||||
|             SELECT sql FROM sqlite_master | ||||
|             WHERE tbl_name = ? AND type != 'table' AND sql IS NOT NULL | ||||
|             """, | ||||
|             (table,), | ||||
|         ) | ||||
|         extras = [row[0] for row in txn] | ||||
| 
 | ||||
|         txn.execute(sqlite_table_schema) | ||||
| 
 | ||||
|         sql = f""" | ||||
|             INSERT INTO {sqlite_table_name} SELECT * FROM {table} | ||||
|             WHERE {constraint.make_check_clause(table)} | ||||
|         """ | ||||
| 
 | ||||
|         txn.execute(sql) | ||||
| 
 | ||||
|         txn.execute(f"DROP TABLE {table}") | ||||
|         txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}") | ||||
| 
 | ||||
|         for extra in extras: | ||||
|             txn.execute(extra) | ||||
|  |  | |||
|  | @ -2313,6 +2313,43 @@ class DatabasePool: | |||
| 
 | ||||
|         return txn.rowcount | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def simple_delete_many_batch_txn( | ||||
|         txn: LoggingTransaction, | ||||
|         table: str, | ||||
|         keys: Collection[str], | ||||
|         values: Iterable[Iterable[Any]], | ||||
|     ) -> None: | ||||
|         """Executes a DELETE query on the named table. | ||||
| 
 | ||||
|         The input is given as a list of rows, where each row is a list of values. | ||||
|         (Actually any iterable is fine.) | ||||
| 
 | ||||
|         Args: | ||||
|             txn: The transaction to use. | ||||
|             table: string giving the table name | ||||
|             keys: list of column names | ||||
|             values: for each row, a list of values in the same order as `keys` | ||||
|         """ | ||||
| 
 | ||||
|         if isinstance(txn.database_engine, PostgresEngine): | ||||
|             # We use `execute_values` as it can be a lot faster than `execute_batch`, | ||||
|             # but it's only available on postgres. | ||||
|             sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % ( | ||||
|                 table, | ||||
|                 ", ".join(k for k in keys), | ||||
|             ) | ||||
| 
 | ||||
|             txn.execute_values(sql, values, fetch=False) | ||||
|         else: | ||||
|             sql = "DELETE FROM %s WHERE (%s) = (%s)" % ( | ||||
|                 table, | ||||
|                 ", ".join(k for k in keys), | ||||
|                 ", ".join("?" for _ in keys), | ||||
|             ) | ||||
| 
 | ||||
|             txn.execute_batch(sql, values) | ||||
| 
 | ||||
|     def get_cache_dict( | ||||
|         self, | ||||
|         db_conn: LoggingDatabaseConnection, | ||||
|  |  | |||
|  | @ -38,6 +38,7 @@ from synapse.events import EventBase, make_event_from_dict | |||
| from synapse.logging.opentracing import tag_args, trace | ||||
| from synapse.metrics.background_process_metrics import wrap_as_background_process | ||||
| from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause | ||||
| from synapse.storage.background_updates import ForeignKeyConstraint | ||||
| from synapse.storage.database import ( | ||||
|     DatabasePool, | ||||
|     LoggingDatabaseConnection, | ||||
|  | @ -140,6 +141,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
| 
 | ||||
|         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) | ||||
| 
 | ||||
|         if isinstance(self.database_engine, PostgresEngine): | ||||
|             self.db_pool.updates.register_background_validate_constraint_and_delete_rows( | ||||
|                 update_name="event_forward_extremities_event_id_foreign_key_constraint_update", | ||||
|                 table="event_forward_extremities", | ||||
|                 constraint_name="event_forward_extremities_event_id", | ||||
|                 constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]), | ||||
|                 unique_columns=("event_id", "room_id"), | ||||
|             ) | ||||
| 
 | ||||
|     async def get_auth_chain( | ||||
|         self, room_id: str, event_ids: Collection[str], include_given: bool = False | ||||
|     ) -> List[EventBase]: | ||||
|  |  | |||
|  | @ -415,12 +415,6 @@ class PersistEventsStore: | |||
|                 backfilled=False, | ||||
|             ) | ||||
| 
 | ||||
|         self._update_forward_extremities_txn( | ||||
|             txn, | ||||
|             new_forward_extremities=new_forward_extremities, | ||||
|             max_stream_order=max_stream_order, | ||||
|         ) | ||||
| 
 | ||||
|         # Ensure that we don't have the same event twice. | ||||
|         events_and_contexts = self._filter_events_and_contexts_for_duplicates( | ||||
|             events_and_contexts | ||||
|  | @ -439,6 +433,12 @@ class PersistEventsStore: | |||
| 
 | ||||
|         self._store_event_txn(txn, events_and_contexts=events_and_contexts) | ||||
| 
 | ||||
|         self._update_forward_extremities_txn( | ||||
|             txn, | ||||
|             new_forward_extremities=new_forward_extremities, | ||||
|             max_stream_order=max_stream_order, | ||||
|         ) | ||||
| 
 | ||||
|         self._persist_transaction_ids_txn(txn, events_and_contexts) | ||||
| 
 | ||||
|         # Insert into event_to_state_groups. | ||||
|  |  | |||
|  | @ -0,0 +1,51 @@ | |||
| # Copyright 2023 The Matrix.org Foundation C.I.C. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| 
 | ||||
| """ | ||||
| This migration adds foreign key constraint to `event_forward_extremities` table. | ||||
| """ | ||||
| from synapse.storage.background_updates import ( | ||||
|     ForeignKeyConstraint, | ||||
|     run_validate_constraint_and_delete_rows_schema_delta, | ||||
| ) | ||||
| from synapse.storage.database import LoggingTransaction | ||||
| from synapse.storage.engines import BaseDatabaseEngine | ||||
| 
 | ||||
| FORWARD_EXTREMITIES_TABLE_SCHEMA = """ | ||||
|     CREATE TABLE event_forward_extremities2( | ||||
|         event_id TEXT NOT NULL, | ||||
|         room_id TEXT NOT NULL, | ||||
|         UNIQUE (event_id, room_id), | ||||
|         CONSTRAINT event_forward_extremities_event_id FOREIGN KEY (event_id) REFERENCES events (event_id) | ||||
|     ) | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: | ||||
|     run_validate_constraint_and_delete_rows_schema_delta( | ||||
|         cur, | ||||
|         ordering=7803, | ||||
|         update_name="event_forward_extremities_event_id_foreign_key_constraint_update", | ||||
|         table="event_forward_extremities", | ||||
|         constraint_name="event_forward_extremities_event_id", | ||||
|         constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]), | ||||
|         sqlite_table_name="event_forward_extremities2", | ||||
|         sqlite_table_schema=FORWARD_EXTREMITIES_TABLE_SCHEMA, | ||||
|     ) | ||||
| 
 | ||||
|     # We can't add a similar constraint to `event_backward_extremities` as the | ||||
|     # events in there don't exist in the `events` table and `event_edges` | ||||
|     # doesn't have a unique constraint on `prev_event_id` (so we can't make a | ||||
|     # foreign key point to it). | ||||
|  | @ -20,7 +20,14 @@ from twisted.internet.defer import Deferred, ensureDeferred | |||
| from twisted.test.proto_helpers import MemoryReactor | ||||
| 
 | ||||
| from synapse.server import HomeServer | ||||
| from synapse.storage.background_updates import BackgroundUpdater | ||||
| from synapse.storage.background_updates import ( | ||||
|     BackgroundUpdater, | ||||
|     ForeignKeyConstraint, | ||||
|     NotNullConstraint, | ||||
|     run_validate_constraint_and_delete_rows_schema_delta, | ||||
| ) | ||||
| from synapse.storage.database import LoggingTransaction | ||||
| from synapse.storage.engines import PostgresEngine, Sqlite3Engine | ||||
| from synapse.types import JsonDict | ||||
| from synapse.util import Clock | ||||
| 
 | ||||
|  | @ -404,3 +411,221 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): | |||
|         self.pump() | ||||
|         self._update_ctx_manager.__aexit__.assert_called() | ||||
|         self.get_success(do_update_d) | ||||
| 
 | ||||
| 
 | ||||
| class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): | ||||
|     """Tests the validate contraint and delete background handlers.""" | ||||
| 
 | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates | ||||
|         # the base test class should have run the real bg updates for us | ||||
|         self.assertTrue( | ||||
|             self.get_success(self.updates.has_completed_background_updates()) | ||||
|         ) | ||||
| 
 | ||||
|         self.store = self.hs.get_datastores().main | ||||
| 
 | ||||
|     def test_not_null_constraint(self) -> None: | ||||
|         # Create the initial tables, where we have some invalid data. | ||||
|         """Tests adding a not null constraint.""" | ||||
|         table_sql = """ | ||||
|             CREATE TABLE test_constraint( | ||||
|                 a INT PRIMARY KEY, | ||||
|                 b INT | ||||
|             ); | ||||
|         """ | ||||
|         self.get_success( | ||||
|             self.store.db_pool.execute( | ||||
|                 "test_not_null_constraint", lambda _: None, table_sql | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         # We add an index so that we can check that its correctly recreated when | ||||
|         # using SQLite. | ||||
|         index_sql = "CREATE INDEX test_index ON test_constraint(a)" | ||||
|         self.get_success( | ||||
|             self.store.db_pool.execute( | ||||
|                 "test_not_null_constraint", lambda _: None, index_sql | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         self.get_success( | ||||
|             self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1}) | ||||
|         ) | ||||
|         self.get_success( | ||||
|             self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}) | ||||
|         ) | ||||
|         self.get_success( | ||||
|             self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3}) | ||||
|         ) | ||||
| 
 | ||||
|         # Now lets do the migration | ||||
| 
 | ||||
|         table2_sqlite = """ | ||||
|             CREATE TABLE test_constraint2( | ||||
|                 a INT PRIMARY KEY, | ||||
|                 b INT, | ||||
|                 CONSTRAINT test_constraint_name CHECK (b is NOT NULL) | ||||
|             ); | ||||
|         """ | ||||
| 
 | ||||
|         def delta(txn: LoggingTransaction) -> None: | ||||
|             run_validate_constraint_and_delete_rows_schema_delta( | ||||
|                 txn, | ||||
|                 ordering=1000, | ||||
|                 update_name="test_bg_update", | ||||
|                 table="test_constraint", | ||||
|                 constraint_name="test_constraint_name", | ||||
|                 constraint=NotNullConstraint("b"), | ||||
|                 sqlite_table_name="test_constraint2", | ||||
|                 sqlite_table_schema=table2_sqlite, | ||||
|             ) | ||||
| 
 | ||||
|         self.get_success( | ||||
|             self.store.db_pool.runInteraction( | ||||
|                 "test_not_null_constraint", | ||||
|                 delta, | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         if isinstance(self.store.database_engine, PostgresEngine): | ||||
|             # Postgres uses a background update | ||||
|             self.updates.register_background_validate_constraint_and_delete_rows( | ||||
|                 "test_bg_update", | ||||
|                 table="test_constraint", | ||||
|                 constraint_name="test_constraint_name", | ||||
|                 constraint=NotNullConstraint("b"), | ||||
|                 unique_columns=["a"], | ||||
|             ) | ||||
| 
 | ||||
|             # Tell the DataStore that it hasn't finished all updates yet | ||||
|             self.store.db_pool.updates._all_done = False | ||||
| 
 | ||||
|             # Now let's actually drive the updates to completion | ||||
|             self.wait_for_background_updates() | ||||
| 
 | ||||
|         # Check the correct values are in the new table. | ||||
|         rows = self.get_success( | ||||
|             self.store.db_pool.simple_select_list( | ||||
|                 table="test_constraint", | ||||
|                 keyvalues={}, | ||||
|                 retcols=("a", "b"), | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) | ||||
| 
 | ||||
|         # And check that invalid rows get correctly rejected. | ||||
|         self.get_failure( | ||||
|             self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}), | ||||
|             exc=self.store.database_engine.module.IntegrityError, | ||||
|         ) | ||||
| 
 | ||||
|         # Check the index is still there for SQLite. | ||||
|         if isinstance(self.store.database_engine, Sqlite3Engine): | ||||
|             # Ensure the index exists in the schema. | ||||
|             self.get_success( | ||||
|                 self.store.db_pool.simple_select_one_onecol( | ||||
|                     table="sqlite_master", | ||||
|                     keyvalues={"tbl_name": "test_constraint"}, | ||||
|                     retcol="name", | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|     def test_foreign_constraint(self) -> None: | ||||
|         """Tests adding a not foreign key constraint.""" | ||||
| 
 | ||||
|         # Create the initial tables, where we have some invalid data. | ||||
|         base_sql = """ | ||||
|             CREATE TABLE base_table( | ||||
|                 b INT PRIMARY KEY | ||||
|             ); | ||||
|         """ | ||||
| 
 | ||||
|         table_sql = """ | ||||
|             CREATE TABLE test_constraint( | ||||
|                 a INT PRIMARY KEY, | ||||
|                 b INT NOT NULL | ||||
|             ); | ||||
|         """ | ||||
|         self.get_success( | ||||
|             self.store.db_pool.execute( | ||||
|                 "test_foreign_key_constraint", lambda _: None, base_sql | ||||
|             ) | ||||
|         ) | ||||
|         self.get_success( | ||||
|             self.store.db_pool.execute( | ||||
|                 "test_foreign_key_constraint", lambda _: None, table_sql | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 1})) | ||||
|         self.get_success( | ||||
|             self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1}) | ||||
|         ) | ||||
|         self.get_success( | ||||
|             self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}) | ||||
|         ) | ||||
|         self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 3})) | ||||
|         self.get_success( | ||||
|             self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3}) | ||||
|         ) | ||||
| 
 | ||||
|         table2_sqlite = """ | ||||
|             CREATE TABLE test_constraint2( | ||||
|                 a INT PRIMARY KEY, | ||||
|                 b INT NOT NULL, | ||||
|                 CONSTRAINT test_constraint_name FOREIGN KEY (b) REFERENCES base_table (b) | ||||
|             ); | ||||
|         """ | ||||
| 
 | ||||
|         def delta(txn: LoggingTransaction) -> None: | ||||
|             run_validate_constraint_and_delete_rows_schema_delta( | ||||
|                 txn, | ||||
|                 ordering=1000, | ||||
|                 update_name="test_bg_update", | ||||
|                 table="test_constraint", | ||||
|                 constraint_name="test_constraint_name", | ||||
|                 constraint=ForeignKeyConstraint("base_table", [("b", "b")]), | ||||
|                 sqlite_table_name="test_constraint2", | ||||
|                 sqlite_table_schema=table2_sqlite, | ||||
|             ) | ||||
| 
 | ||||
|         self.get_success( | ||||
|             self.store.db_pool.runInteraction( | ||||
|                 "test_foreign_key_constraint", | ||||
|                 delta, | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         if isinstance(self.store.database_engine, PostgresEngine): | ||||
|             # Postgres uses a background update | ||||
|             self.updates.register_background_validate_constraint_and_delete_rows( | ||||
|                 "test_bg_update", | ||||
|                 table="test_constraint", | ||||
|                 constraint_name="test_constraint_name", | ||||
|                 constraint=ForeignKeyConstraint("base_table", [("b", "b")]), | ||||
|                 unique_columns=["a"], | ||||
|             ) | ||||
| 
 | ||||
|             # Tell the DataStore that it hasn't finished all updates yet | ||||
|             self.store.db_pool.updates._all_done = False | ||||
| 
 | ||||
|             # Now let's actually drive the updates to completion | ||||
|             self.wait_for_background_updates() | ||||
| 
 | ||||
|         # Check the correct values are in the new table. | ||||
|         rows = self.get_success( | ||||
|             self.store.db_pool.simple_select_list( | ||||
|                 table="test_constraint", | ||||
|                 keyvalues={}, | ||||
|                 retcols=("a", "b"), | ||||
|             ) | ||||
|         ) | ||||
|         self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) | ||||
| 
 | ||||
|         # And check that invalid rows get correctly rejected. | ||||
|         self.get_failure( | ||||
|             self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}), | ||||
|             exc=self.store.database_engine.module.IntegrityError, | ||||
|         ) | ||||
|  |  | |||
|  | @ -20,6 +20,7 @@ from parameterized import parameterized | |||
| 
 | ||||
| from twisted.test.proto_helpers import MemoryReactor | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.api.room_versions import ( | ||||
|     KNOWN_ROOM_VERSIONS, | ||||
|     EventFormatVersions, | ||||
|  | @ -98,8 +99,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
|         room2 = "#room2" | ||||
|         room3 = "#room3" | ||||
| 
 | ||||
|         def insert_event(txn: Cursor, i: int, room_id: str) -> None: | ||||
|         def insert_event(txn: LoggingTransaction, i: int, room_id: str) -> None: | ||||
|             event_id = "$event_%i:local" % i | ||||
| 
 | ||||
|             # We need to insert into events table to get around the foreign key constraint. | ||||
|             self.store.db_pool.simple_insert_txn( | ||||
|                 txn, | ||||
|                 table="events", | ||||
|                 values={ | ||||
|                     "instance_name": "master", | ||||
|                     "stream_ordering": self.store._stream_id_gen.get_next_txn(txn), | ||||
|                     "topological_ordering": 1, | ||||
|                     "depth": 1, | ||||
|                     "event_id": event_id, | ||||
|                     "room_id": room_id, | ||||
|                     "type": EventTypes.Message, | ||||
|                     "processed": True, | ||||
|                     "outlier": False, | ||||
|                     "origin_server_ts": 0, | ||||
|                     "received_ts": 0, | ||||
|                     "sender": "@user:local", | ||||
|                     "contains_url": False, | ||||
|                     "state_key": None, | ||||
|                     "rejection_reason": None, | ||||
|                 }, | ||||
|             ) | ||||
| 
 | ||||
|             txn.execute( | ||||
|                 ( | ||||
|                     "INSERT INTO event_forward_extremities (room_id, event_id) " | ||||
|  | @ -113,10 +138,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||
|                 self.store.db_pool.runInteraction("insert", insert_event, i, room1) | ||||
|             ) | ||||
|             self.get_success( | ||||
|                 self.store.db_pool.runInteraction("insert", insert_event, i, room2) | ||||
|                 self.store.db_pool.runInteraction( | ||||
|                     "insert", insert_event, i + 100, room2 | ||||
|                 ) | ||||
|             ) | ||||
|             self.get_success( | ||||
|                 self.store.db_pool.runInteraction("insert", insert_event, i, room3) | ||||
|                 self.store.db_pool.runInteraction( | ||||
|                     "insert", insert_event, i + 200, room3 | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         # Test simple case | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston