Skip to content

Allow running Golang based post migration steps #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 99 additions & 9 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,71 @@ func (e ErrDirty) Error() string {
return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version)
}

// PostStepCallback is a callback function type that can be used to execute a
// Golang based migration step after a SQL based migration step has been
// executed. The callback function receives the migration and the database
// driver as arguments.
type PostStepCallback func(migr *Migration, driver database.Driver) error

// options is a set of optional options that can be set when a Migrate instance
// is created.
type options struct {
// postStepCallbacks is a map of PostStepCallback functions that can be
// used to execute a Golang based migration step after a SQL based
// migration step has been executed. The key is the migration version
// and the value is the callback function that should be run _after_ the
// step was executed (but within the same database transaction).
postStepCallbacks map[uint]PostStepCallback
}

// defaultOptions returns a new options struct with default values.
func defaultOptions() options {
return options{
postStepCallbacks: make(map[uint]PostStepCallback),
}
}

// Option is a function that can be used to set options on a Migrate instance.
type Option func(*options)

// WithPostStepCallbacks is an option that can be used to set a map of
// PostStepCallback functions that can be used to execute a Golang based
// migration step after a SQL based migration step has been executed. The key is
// the migration version and the value is the callback function that should be
// run _after_ the step was executed (but before the version is marked as
// cleanly executed). An error returned from the callback will cause the
// migration to fail and the step to be marked as dirty.
func WithPostStepCallbacks(
postStepCallbacks map[uint]PostStepCallback) Option {

return func(o *options) {
o.postStepCallbacks = postStepCallbacks
}
}

// WithPostStepCallback is an option that can be used to set a PostStepCallback
// function that can be used to execute a Golang based migration step after the
// SQL based migration step with the given version number has been executed. The
// callback is the function that should be run _after_ the step was executed
// (but before the version is marked as cleanly executed). An error returned
// from the callback will cause the migration to fail and the step to be marked
// as dirty.
func WithPostStepCallback(version uint, callback PostStepCallback) Option {
return func(o *options) {
o.postStepCallbacks[version] = callback
}
}

type Migrate struct {
sourceName string
sourceDrv source.Driver
databaseName string
databaseDrv database.Driver

// opts is a set of options that can be used to modify the behavior
// of the Migrate instance.
opts options

// Log accepts a Logger interface
Log Logger

Expand All @@ -84,8 +143,8 @@ type Migrate struct {

// New returns a new Migrate instance from a source URL and a database URL.
// The URL scheme is defined by each driver.
func New(sourceURL, databaseURL string) (*Migrate, error) {
m := newCommon()
func New(sourceURL, databaseURL string, opts ...Option) (*Migrate, error) {
m := newMigrateWithOptions(opts)

sourceName, err := iurl.SchemeFromURL(sourceURL)
if err != nil {
Expand Down Expand Up @@ -118,8 +177,10 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
// and an existing database instance. The source URL scheme is defined by each driver.
// Use any string that can serve as an identifier during logging as databaseName.
// You are responsible for closing the underlying database client if necessary.
func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
m := newCommon()
func NewWithDatabaseInstance(sourceURL string, databaseName string,
databaseInstance database.Driver, opts ...Option) (*Migrate, error) {

m := newMigrateWithOptions(opts)

sourceName, err := iurl.SchemeFromURL(sourceURL)
if err != nil {
Expand All @@ -144,8 +205,10 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst
// and a database URL. The database URL scheme is defined by each driver.
// Use any string that can serve as an identifier during logging as sourceName.
// You are responsible for closing the underlying source client if necessary.
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
m := newCommon()
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver,
databaseURL string, opts ...Option) (*Migrate, error) {

m := newMigrateWithOptions(opts)

databaseName, err := iurl.SchemeFromURL(databaseURL)
if err != nil {
Expand All @@ -170,8 +233,11 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data
// database instance. Use any string that can serve as an identifier during logging
// as sourceName and databaseName. You are responsible for closing down
// the underlying source and database client if necessary.
func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
m := newCommon()
func NewWithInstance(sourceName string, sourceInstance source.Driver,
databaseName string, databaseInstance database.Driver,
opts ...Option) (*Migrate, error) {

m := newMigrateWithOptions(opts)

m.sourceName = sourceName
m.databaseName = databaseName
Expand All @@ -182,8 +248,13 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa
return m, nil
}

func newCommon() *Migrate {
func newMigrateWithOptions(optFunctions []Option) *Migrate {
opts := defaultOptions()
for _, opt := range optFunctions {
opt(&opts)
}
return &Migrate{
opts: opts,
GracefulStop: make(chan bool, 1),
PrefetchMigrations: DefaultPrefetchMigrations,
LockTimeout: DefaultLockTimeout,
Expand Down Expand Up @@ -746,6 +817,25 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
return err
}

// If there is a post execution function for
// this migration, run it now.
cb, ok := m.opts.postStepCallbacks[migr.Version]
if ok {
m.logVerbosePrintf("Running post step "+
"callback for %v\n", migr.LogString())

err := cb(migr, m.databaseDrv)
if err != nil {
return fmt.Errorf("failed to "+
"execute post "+
"step callback: %w",
err)
}

m.logVerbosePrintf("Post step callback "+
"finished for %v\n", migr.LogString())
}
}

// set clean state
Expand Down
111 changes: 111 additions & 0 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"testing"

"github.com/golang-migrate/migrate/v4/database"
dStub "github.com/golang-migrate/migrate/v4/database/stub"
"github.com/golang-migrate/migrate/v4/source"
sStub "github.com/golang-migrate/migrate/v4/source/stub"
Expand Down Expand Up @@ -878,6 +879,116 @@ func TestUpAndDown(t *testing.T) {
equalDbSeq(t, 1, expectedSequence, dbDrv)
}

func TestPostStepCallback(t *testing.T) {
m, _ := New("stub://", "stub://", WithPostStepCallbacks(
map[uint]PostStepCallback{
1: func(m *Migration, driver database.Driver) error {
return driver.Run(
strings.NewReader("CALLBACK 1"),
)
},
7: func(m *Migration, driver database.Driver) error {
return driver.Run(
strings.NewReader("CALLBACK 7"),
)
},
},
))
m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations
dbDrv := m.databaseDrv.(*dStub.Stub)

// go Up first
if err := m.Up(); err != nil {
t.Fatal(err)
}
expectedSequence := migrationSequence{
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
}
equalDbSeq(t, 0, expectedSequence, dbDrv)

if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 7")) {
t.Fatalf("expected database last migration to be callback 7, "+
"got %s", dbDrv.LastRunMigration)
}

// go Down
if err := m.Down(); err != nil {
t.Fatal(err)
}
expectedSequence = migrationSequence{
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
mr("DROP 7"),
mr("CALLBACK 7"),
mr("DROP 5"),
mr("DROP 4"),
mr("DROP 1"),
mr("CALLBACK 1"),
}
equalDbSeq(t, 1, expectedSequence, dbDrv)

if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 1")) {
t.Fatalf("expected database last migration to be callback 1, "+
"got %s", dbDrv.LastRunMigration)
}

// go 1 Up and then all the way Up
if err := m.Steps(1); err != nil {
t.Fatal(err)
}
expectedSequence = migrationSequence{
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
mr("DROP 7"),
mr("CALLBACK 7"),
mr("DROP 5"),
mr("DROP 4"),
mr("DROP 1"),
mr("CALLBACK 1"),
mr("CREATE 1"),
mr("CALLBACK 1"),
}
equalDbSeq(t, 2, expectedSequence, dbDrv)

if err := m.Up(); err != nil {
t.Fatal(err)
}
expectedSequence = migrationSequence{
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
mr("DROP 7"),
mr("CALLBACK 7"),
mr("DROP 5"),
mr("DROP 4"),
mr("DROP 1"),
mr("CALLBACK 1"),
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
}
equalDbSeq(t, 3, expectedSequence, dbDrv)
}

func TestUpDirty(t *testing.T) {
m, _ := New("stub://", "stub://")
dbDrv := m.databaseDrv.(*dStub.Stub)
Expand Down