Skip to content

Commit

Permalink
Merge pull request #16 from kraken-tech/prevent-index-creation-if-rou…
Browse files Browse the repository at this point in the history
…ter-says-no

Do not migrate if the router check does not pass
  • Loading branch information
marcelofern authored Aug 22, 2024
2 parents e45c43c + 9e6f2c6 commit 6a215d3
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 8 deletions.
11 changes: 9 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project
adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
The format is based on [Keep a
Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

## [0.1.2] - 2024-08-23

- Fixes a bug where the `SaferAddIndexConcurrently` class would try to perform
a migration regardless of whether the router would allow it (through
`router.allow_migrate`).

## [0.1.1] - 2024-08-14

- Non-functional changes for the documentation to be properly linked in PyPI.
Expand Down
4 changes: 3 additions & 1 deletion docs/usage/operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Provides custom migration operations that help developers perform idempotent and
Class Definitions
-----------------

.. py:class:: SaferAddIndexConcurrently(model_name: str, index: models.Index)
.. py:class:: SaferAddIndexConcurrently(model_name: str, index: models.Index, hints: Any = None)
Performs CREATE INDEX CONCURRENTLY IF NOT EXISTS without a lock_timeout
value to guarantee the index creation won't be affected by any pre-set
Expand All @@ -16,6 +16,8 @@ Class Definitions
:type model_name: str
:param index: Any type of index supported by Django.
:type index: models.Index
:param hints: Hints to be passed to the router as per https://docs.djangoproject.com/en/5.1/topics/db/multi-db/#hints
:type hints: Any

**Why use this SaferAddIndexConcurrently operation?**
-----------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ where = ["src"]

[project]
name = "django_pg_migration_tools"
version = "0.1.1"
version = "0.1.2"
description = "Tools for making Django migrations safer and more scalable."
license.file = "LICENSE"
readme = "README.md"
Expand Down
17 changes: 15 additions & 2 deletions src/django_pg_migration_tools/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from django.contrib.postgres import operations as psql_operations
from django.db import migrations, models
from django.db import migrations, models, router
from django.db.backends.base import schema as base_schema
from django.db.migrations.operations import base as migrations_base

Expand Down Expand Up @@ -49,10 +49,11 @@ class due to limitations of Django's AddIndexConcurrently operation.

DROP_INDEX_QUERY = 'DROP INDEX CONCURRENTLY IF EXISTS "{}";'

def __init__(self, model_name: str, index: models.Index) -> None:
def __init__(self, model_name: str, index: models.Index, hints: Any = None) -> None:
self.model_name = model_name
self.index = index
self.original_lock_timeout = ""
self.hints = {} if hints is None else hints

def describe(self) -> str:
return (
Expand All @@ -70,6 +71,12 @@ def database_forwards(
to_state: migrations.state.ProjectState,
) -> None:
self._ensure_not_in_transaction(schema_editor)

if not router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
return

self._ensure_no_lock_timeout_set(schema_editor)
self._ensure_not_an_invalid_index(schema_editor)
model = from_state.apps.get_model(app_label, self.model_name)
Expand All @@ -90,6 +97,12 @@ def database_backwards(
to_state: migrations.state.ProjectState,
) -> None:
self._ensure_not_in_transaction(schema_editor)

if not router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
return

self._ensure_no_lock_timeout_set(schema_editor)
model = from_state.apps.get_model(app_label, self.model_name)
index_sql = str(self.index.remove_sql(model, schema_editor, concurrently=True))
Expand Down
84 changes: 82 additions & 2 deletions tests/django_pg_migration_tools/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from textwrap import dedent
from typing import Any

import pytest
from django.db import (
Expand All @@ -10,7 +11,7 @@
ProjectState,
)
from django.db.models import Index
from django.test import utils
from django.test import override_settings, utils

from django_pg_migration_tools import operations
from tests.example_app.models import IntModel
Expand All @@ -34,6 +35,16 @@
);
"""

_CHECK_INVALID_INDEX_EXISTS_QUERY = """
SELECT relname
FROM pg_class, pg_index
WHERE (
pg_index.indisvalid = false
AND pg_index.indexrelid = pg_class.oid
AND relname = 'int_field_idx'
);
"""

_CREATE_INDEX_QUERY = """
CREATE INDEX "int_field_idx"
ON "example_app_intmodel" ("int_field");
Expand All @@ -55,6 +66,16 @@
"""


class AllowDefaultOnly:
"""
A router that only allows a migration to happen if the instance is the
"default" instance.
"""

def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool:
return bool(hints["instance"] == "default")


class TestSaferAddIndexConcurrently:
app_label = "example_app"

Expand All @@ -74,6 +95,7 @@ def test_requires_atomic_false(self):
# Disable the overall test transaction because a concurrent index cannot
# be triggered/tested inside of a transaction.
@pytest.mark.django_db(transaction=True)
@override_settings(DATABASE_ROUTERS=[AllowDefaultOnly()])
def test_add(self):
with connection.cursor() as cursor:
# We first create the index and set it to invalid, to make sure it
Expand All @@ -100,7 +122,9 @@ def test_add(self):
# Set the operation that will drop the invalid index and re-create it
# (without lock timeouts).
index = Index(fields=["int_field"], name="int_field_idx")
operation = operations.SaferAddIndexConcurrently("IntModel", index)
operation = operations.SaferAddIndexConcurrently(
"IntModel", index, hints={"instance": "default"}
)

assert operation.describe() == (
"Concurrently creates index int_field_idx on field(s) "
Expand Down Expand Up @@ -182,3 +206,59 @@ def test_add(self):
with connection.cursor() as cursor:
cursor.execute(_CHECK_INDEX_EXISTS_QUERY)
assert not cursor.fetchone()

# Disable the overall test transaction because a concurrent index cannot
# be triggered/tested inside of a transaction.
@pytest.mark.django_db(transaction=True)
@override_settings(DATABASE_ROUTERS=[AllowDefaultOnly()])
def test_when_not_allowed_to_migrate(self):
with connection.cursor() as cursor:
# We first create the index and set it to invalid, to make sure it
# will not be removed automatically because the operation is not
# allowed to run.
cursor.execute(_CREATE_INDEX_QUERY)
cursor.execute(_SET_INDEX_INVALID)

# Prove that the invalid index exists before the operation runs:
with connection.cursor() as cursor:
cursor.execute(
operations.SaferAddIndexConcurrently.CHECK_INVALID_INDEX_QUERY,
{"index_name": "int_field_idx"},
)
assert cursor.fetchone()

project_state = ProjectState()
project_state.add_model(ModelState.from_model(IntModel))
new_state = project_state.clone()

index = Index(fields=["int_field"], name="int_field_idx")
operation = operations.SaferAddIndexConcurrently(
# Our migration should only be allowed to run if the instance
# equals "default" - which isn't the case here.
"IntModel",
index,
hints={"instance": "replica"},
)

operation.state_forwards(self.app_label, new_state)
assert len(new_state.models[self.app_label, "intmodel"].options["indexes"]) == 1
assert (
new_state.models[self.app_label, "intmodel"].options["indexes"][0].name
== "int_field_idx"
)
# Proceed to try and add the index:
with connection.schema_editor(atomic=False, collect_sql=False) as editor:
with utils.CaptureQueriesContext(connection) as queries:
operation.database_forwards(
self.app_label, editor, project_state, new_state
)

# No queries have run, because the migration wasn't allowed to run by
# the router.
assert len(queries) == 0

# Make sure the invalid index was NOT been replaced by a valid index.
# (because the router didn't allow this migration to run).
with connection.cursor() as cursor:
cursor.execute(_CHECK_INVALID_INDEX_EXISTS_QUERY)
assert cursor.fetchone()

0 comments on commit 6a215d3

Please sign in to comment.