diff --git a/CHANGELOG.md b/CHANGELOG.md index 8eff1f5..aad6392 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,11 +2,18 @@ All notable changes to this project will be documented in this file. -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project -adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +The format is based on [Keep a +Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to +[Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [0.1.2] - 2024-08-23 + +- Fixes a bug where the `SaferAddIndexConcurrently` class would try to perform + a migration regardless of whether the router would allow it (through + `router.allow_migrate`). + ## [0.1.1] - 2024-08-14 - Non-functional changes for the documentation to be properly linked in PyPI. diff --git a/docs/usage/operations.rst b/docs/usage/operations.rst index b7bcbee..e874af2 100644 --- a/docs/usage/operations.rst +++ b/docs/usage/operations.rst @@ -6,7 +6,7 @@ Provides custom migration operations that help developers perform idempotent and Class Definitions ----------------- -.. py:class:: SaferAddIndexConcurrently(model_name: str, index: models.Index) +.. py:class:: SaferAddIndexConcurrently(model_name: str, index: models.Index, hints: Any = None) Performs CREATE INDEX CONCURRENTLY IF NOT EXISTS without a lock_timeout value to guarantee the index creation won't be affected by any pre-set @@ -16,6 +16,8 @@ Class Definitions :type model_name: str :param index: Any type of index supported by Django. :type index: models.Index + :param hints: Hints to be passed to the router as per https://docs.djangoproject.com/en/5.1/topics/db/multi-db/#hints + :type hints: Any **Why use this SaferAddIndexConcurrently operation?** ----------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 48518af..66b4f95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ where = ["src"] [project] name = "django_pg_migration_tools" -version = "0.1.1" +version = "0.1.2" description = "Tools for making Django migrations safer and more scalable." license.file = "LICENSE" readme = "README.md" diff --git a/src/django_pg_migration_tools/operations.py b/src/django_pg_migration_tools/operations.py index 4163620..cf6c34c 100644 --- a/src/django_pg_migration_tools/operations.py +++ b/src/django_pg_migration_tools/operations.py @@ -4,7 +4,7 @@ from typing import Any from django.contrib.postgres import operations as psql_operations -from django.db import migrations, models +from django.db import migrations, models, router from django.db.backends.base import schema as base_schema from django.db.migrations.operations import base as migrations_base @@ -49,10 +49,11 @@ class due to limitations of Django's AddIndexConcurrently operation. DROP_INDEX_QUERY = 'DROP INDEX CONCURRENTLY IF EXISTS "{}";' - def __init__(self, model_name: str, index: models.Index) -> None: + def __init__(self, model_name: str, index: models.Index, hints: Any = None) -> None: self.model_name = model_name self.index = index self.original_lock_timeout = "" + self.hints = {} if hints is None else hints def describe(self) -> str: return ( @@ -70,6 +71,12 @@ def database_forwards( to_state: migrations.state.ProjectState, ) -> None: self._ensure_not_in_transaction(schema_editor) + + if not router.allow_migrate( + schema_editor.connection.alias, app_label, **self.hints + ): + return + self._ensure_no_lock_timeout_set(schema_editor) self._ensure_not_an_invalid_index(schema_editor) model = from_state.apps.get_model(app_label, self.model_name) @@ -90,6 +97,12 @@ def database_backwards( to_state: migrations.state.ProjectState, ) -> None: self._ensure_not_in_transaction(schema_editor) + + if not router.allow_migrate( + schema_editor.connection.alias, app_label, **self.hints + ): + return + self._ensure_no_lock_timeout_set(schema_editor) model = from_state.apps.get_model(app_label, self.model_name) index_sql = str(self.index.remove_sql(model, schema_editor, concurrently=True)) diff --git a/tests/django_pg_migration_tools/test_operations.py b/tests/django_pg_migration_tools/test_operations.py index dd2d8e4..7188663 100644 --- a/tests/django_pg_migration_tools/test_operations.py +++ b/tests/django_pg_migration_tools/test_operations.py @@ -1,4 +1,5 @@ from textwrap import dedent +from typing import Any import pytest from django.db import ( @@ -10,7 +11,7 @@ ProjectState, ) from django.db.models import Index -from django.test import utils +from django.test import override_settings, utils from django_pg_migration_tools import operations from tests.example_app.models import IntModel @@ -34,6 +35,16 @@ ); """ +_CHECK_INVALID_INDEX_EXISTS_QUERY = """ +SELECT relname +FROM pg_class, pg_index +WHERE ( + pg_index.indisvalid = false + AND pg_index.indexrelid = pg_class.oid + AND relname = 'int_field_idx' +); +""" + _CREATE_INDEX_QUERY = """ CREATE INDEX "int_field_idx" ON "example_app_intmodel" ("int_field"); @@ -55,6 +66,16 @@ """ +class AllowDefaultOnly: + """ + A router that only allows a migration to happen if the instance is the + "default" instance. + """ + + def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: + return bool(hints["instance"] == "default") + + class TestSaferAddIndexConcurrently: app_label = "example_app" @@ -74,6 +95,7 @@ def test_requires_atomic_false(self): # Disable the overall test transaction because a concurrent index cannot # be triggered/tested inside of a transaction. @pytest.mark.django_db(transaction=True) + @override_settings(DATABASE_ROUTERS=[AllowDefaultOnly()]) def test_add(self): with connection.cursor() as cursor: # We first create the index and set it to invalid, to make sure it @@ -100,7 +122,9 @@ def test_add(self): # Set the operation that will drop the invalid index and re-create it # (without lock timeouts). index = Index(fields=["int_field"], name="int_field_idx") - operation = operations.SaferAddIndexConcurrently("IntModel", index) + operation = operations.SaferAddIndexConcurrently( + "IntModel", index, hints={"instance": "default"} + ) assert operation.describe() == ( "Concurrently creates index int_field_idx on field(s) " @@ -182,3 +206,59 @@ def test_add(self): with connection.cursor() as cursor: cursor.execute(_CHECK_INDEX_EXISTS_QUERY) assert not cursor.fetchone() + + # Disable the overall test transaction because a concurrent index cannot + # be triggered/tested inside of a transaction. + @pytest.mark.django_db(transaction=True) + @override_settings(DATABASE_ROUTERS=[AllowDefaultOnly()]) + def test_when_not_allowed_to_migrate(self): + with connection.cursor() as cursor: + # We first create the index and set it to invalid, to make sure it + # will not be removed automatically because the operation is not + # allowed to run. + cursor.execute(_CREATE_INDEX_QUERY) + cursor.execute(_SET_INDEX_INVALID) + + # Prove that the invalid index exists before the operation runs: + with connection.cursor() as cursor: + cursor.execute( + operations.SaferAddIndexConcurrently.CHECK_INVALID_INDEX_QUERY, + {"index_name": "int_field_idx"}, + ) + assert cursor.fetchone() + + project_state = ProjectState() + project_state.add_model(ModelState.from_model(IntModel)) + new_state = project_state.clone() + + index = Index(fields=["int_field"], name="int_field_idx") + operation = operations.SaferAddIndexConcurrently( + # Our migration should only be allowed to run if the instance + # equals "default" - which isn't the case here. + "IntModel", + index, + hints={"instance": "replica"}, + ) + + operation.state_forwards(self.app_label, new_state) + assert len(new_state.models[self.app_label, "intmodel"].options["indexes"]) == 1 + assert ( + new_state.models[self.app_label, "intmodel"].options["indexes"][0].name + == "int_field_idx" + ) + # Proceed to try and add the index: + with connection.schema_editor(atomic=False, collect_sql=False) as editor: + with utils.CaptureQueriesContext(connection) as queries: + operation.database_forwards( + self.app_label, editor, project_state, new_state + ) + + # No queries have run, because the migration wasn't allowed to run by + # the router. + assert len(queries) == 0 + + # Make sure the invalid index was NOT been replaced by a valid index. + # (because the router didn't allow this migration to run). + with connection.cursor() as cursor: + cursor.execute(_CHECK_INVALID_INDEX_EXISTS_QUERY) + assert cursor.fetchone()