Skip to content

Commit

Permalink
Do not migrate if the router check does not pass
Browse files Browse the repository at this point in the history
Prior to this change, we weren't checking if the database router allowed
the migration to proceed.

We thus add the router.allow_migrate() check, which is the same that the
RunSQL operation runs internally. Ref:

- https://github.com/django/django/blob/7adb6dd98d50a238f3eca8c15b16b5aec12575fd/django/db/migrations/operations/special.py#L105

A possible error that would raise before this fix happens when the
Django applicatoin has multiple databases and the migration should be
routed to only apply on a particular database.
  • Loading branch information
marcelofern committed Aug 22, 2024
1 parent e45c43c commit 9180c4e
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 26 deletions.
11 changes: 9 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project
adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
The format is based on [Keep a
Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

## [0.1.2] - 2024-08-23

- Fixes a bug where the `SaferAddIndexConcurrently` class would try to perform
a migration regardless of whether the router wouldn't allow it (through
`router.allow_migrate`).

## [0.1.1] - 2024-08-14

- Non-functional changes for the documentation to be properly linked in PyPI.
Expand Down
4 changes: 3 additions & 1 deletion docs/usage/operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Provides custom migration operations that help developers perform idempotent and
Class Definitions
-----------------

.. py:class:: SaferAddIndexConcurrently(model_name: str, index: models.Index)
.. py:class:: SaferAddIndexConcurrently(model_name: str, index: models.Index, hints: Any = None)
Performs CREATE INDEX CONCURRENTLY IF NOT EXISTS without a lock_timeout
value to guarantee the index creation won't be affected by any pre-set
Expand All @@ -16,6 +16,8 @@ Class Definitions
:type model_name: str
:param index: Any type of index supported by Django.
:type index: models.Index
:param hints: Hints to be passed to the router as per https://docs.djangoproject.com/en/5.1/topics/db/multi-db/#hints
:type hints: Any

**Why use this SaferAddIndexConcurrently operation?**
-----------------------------------------------------
Expand Down
53 changes: 32 additions & 21 deletions src/django_pg_migration_tools/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from django.contrib.postgres import operations as psql_operations
from django.db import migrations, models
from django.db import migrations, models, router
from django.db.backends.base import schema as base_schema
from django.db.migrations.operations import base as migrations_base

Expand Down Expand Up @@ -49,10 +49,11 @@ class due to limitations of Django's AddIndexConcurrently operation.

DROP_INDEX_QUERY = 'DROP INDEX CONCURRENTLY IF EXISTS "{}";'

def __init__(self, model_name: str, index: models.Index) -> None:
def __init__(self, model_name: str, index: models.Index, hints: Any = None) -> None:
self.model_name = model_name
self.index = index
self.original_lock_timeout = ""
self.hints = hints or {}

def describe(self) -> str:
return (
Expand All @@ -70,17 +71,22 @@ def database_forwards(
to_state: migrations.state.ProjectState,
) -> None:
self._ensure_not_in_transaction(schema_editor)
self._ensure_no_lock_timeout_set(schema_editor)
self._ensure_not_an_invalid_index(schema_editor)
model = from_state.apps.get_model(app_label, self.model_name)
index_sql = str(self.index.create_sql(model, schema_editor, concurrently=True))
# Inject the IF NOT EXISTS because Django doesn't provide a handy
# if_not_exists: bool parameter for us to use.
index_sql = index_sql.replace(
"CREATE INDEX CONCURRENTLY", "CREATE INDEX CONCURRENTLY IF NOT EXISTS"
)
schema_editor.execute(index_sql)
self._ensure_original_lock_timeout_is_reset(schema_editor)
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
self._ensure_no_lock_timeout_set(schema_editor)
self._ensure_not_an_invalid_index(schema_editor)
model = from_state.apps.get_model(app_label, self.model_name)
index_sql = str(
self.index.create_sql(model, schema_editor, concurrently=True)
)
# Inject the IF NOT EXISTS because Django doesn't provide a handy
# if_not_exists: bool parameter for us to use.
index_sql = index_sql.replace(
"CREATE INDEX CONCURRENTLY", "CREATE INDEX CONCURRENTLY IF NOT EXISTS"
)
schema_editor.execute(index_sql)
self._ensure_original_lock_timeout_is_reset(schema_editor)

def database_backwards(
self,
Expand All @@ -90,14 +96,19 @@ def database_backwards(
to_state: migrations.state.ProjectState,
) -> None:
self._ensure_not_in_transaction(schema_editor)
self._ensure_no_lock_timeout_set(schema_editor)
model = from_state.apps.get_model(app_label, self.model_name)
index_sql = str(self.index.remove_sql(model, schema_editor, concurrently=True))
# Differently from the CREATE INDEX operation, Django already provides
# us with IF EXISTS when dropping an index... We don't have to do that
# .replace() call here.
schema_editor.execute(index_sql)
self._ensure_original_lock_timeout_is_reset(schema_editor)
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
self._ensure_no_lock_timeout_set(schema_editor)
model = from_state.apps.get_model(app_label, self.model_name)
index_sql = str(
self.index.remove_sql(model, schema_editor, concurrently=True)
)
# Differently from the CREATE INDEX operation, Django already provides
# us with IF EXISTS when dropping an index... We don't have to do that
# .replace() call here.
schema_editor.execute(index_sql)
self._ensure_original_lock_timeout_is_reset(schema_editor)

def _ensure_no_lock_timeout_set(
self,
Expand Down
86 changes: 84 additions & 2 deletions tests/django_pg_migration_tools/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from textwrap import dedent
from typing import Any

import pytest
from django.db import (
Expand All @@ -10,7 +11,7 @@
ProjectState,
)
from django.db.models import Index
from django.test import utils
from django.test import override_settings, utils

from django_pg_migration_tools import operations
from tests.example_app.models import IntModel
Expand All @@ -34,6 +35,16 @@
);
"""

_CHECK_INVALID_INDEX_EXISTS_QUERY = """
SELECT relname
FROM pg_class, pg_index
WHERE (
pg_index.indisvalid = false
AND pg_index.indexrelid = pg_class.oid
AND relname = 'int_field_idx'
);
"""

_CREATE_INDEX_QUERY = """
CREATE INDEX "int_field_idx"
ON "example_app_intmodel" ("int_field");
Expand All @@ -55,6 +66,16 @@
"""


class AllowDefaultOnly:
"""
A router that only allows a migration to happen if the instance is the
"default" instance.
"""

def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool:
return bool(hints["instance"] == "default")


class TestSaferAddIndexConcurrently:
app_label = "example_app"

Expand All @@ -74,6 +95,7 @@ def test_requires_atomic_false(self):
# Disable the overall test transaction because a concurrent index cannot
# be triggered/tested inside of a transaction.
@pytest.mark.django_db(transaction=True)
@override_settings(DATABASE_ROUTERS=[AllowDefaultOnly()])
def test_add(self):
with connection.cursor() as cursor:
# We first create the index and set it to invalid, to make sure it
Expand All @@ -100,7 +122,9 @@ def test_add(self):
# Set the operation that will drop the invalid index and re-create it
# (without lock timeouts).
index = Index(fields=["int_field"], name="int_field_idx")
operation = operations.SaferAddIndexConcurrently("IntModel", index)
operation = operations.SaferAddIndexConcurrently(
"IntModel", index, hints={"instance": "default"}
)

assert operation.describe() == (
"Concurrently creates index int_field_idx on field(s) "
Expand Down Expand Up @@ -182,3 +206,61 @@ def test_add(self):
with connection.cursor() as cursor:
cursor.execute(_CHECK_INDEX_EXISTS_QUERY)
assert not cursor.fetchone()

# Disable the overall test transaction because a concurrent index cannot
# be triggered/tested inside of a transaction.
@pytest.mark.django_db(transaction=True)
@override_settings(DATABASE_ROUTERS=[AllowDefaultOnly()])
def test_when_not_allowed_to_migrate(self):
with connection.cursor() as cursor:
# We first create the index and set it to invalid, to make sure it
# will not be removed automatically because the operation is not
# allowed to run.
cursor.execute(_CREATE_INDEX_QUERY)
cursor.execute(_SET_INDEX_INVALID)

# Prove that the invalid index exists before the operation runs:
with connection.cursor() as cursor:
cursor.execute(
operations.SaferAddIndexConcurrently.CHECK_INVALID_INDEX_QUERY,
{"index_name": "int_field_idx"},
)
assert cursor.fetchone()

project_state = ProjectState()
project_state.add_model(ModelState.from_model(IntModel))
new_state = project_state.clone()

# Set the operation that will drop the invalid index and re-create it
# (without lock timeouts).
index = Index(fields=["int_field"], name="int_field_idx")
operation = operations.SaferAddIndexConcurrently(
# Our migration should only be allowed to run if the instance
# equals "default" - which isn't the case here.
"IntModel",
index,
hints={"instance": "replica"},
)

operation.state_forwards(self.app_label, new_state)
assert len(new_state.models[self.app_label, "intmodel"].options["indexes"]) == 1
assert (
new_state.models[self.app_label, "intmodel"].options["indexes"][0].name
== "int_field_idx"
)
# Proceed to add the index:
with connection.schema_editor(atomic=False, collect_sql=False) as editor:
with utils.CaptureQueriesContext(connection) as queries:
operation.database_forwards(
self.app_label, editor, project_state, new_state
)

# No queries have run, because the migration wasn't allowed to run by
# the router.
assert len(queries) == 0

# Make sure the invalid index was NOT been replaced by a valid index.
# (because the router didn't allow this migration to run).
with connection.cursor() as cursor:
cursor.execute(_CHECK_INVALID_INDEX_EXISTS_QUERY)
assert cursor.fetchone()

0 comments on commit 9180c4e

Please sign in to comment.