From 5b219eb81e5a9a949d3b55b7dff2c69e1e8e77da Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 28 Mar 2025 12:49:10 -0500 Subject: [PATCH 1/2] Add functional options to constructors This is a preparatory commit that adds the ability to add new options to any of the constructor functions, without breaking backward compatibility. An actual option is going to be added in the next commit, this just introduces the mechanism. --- migrate.go | 47 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/migrate.go b/migrate.go index 266cc04eb..784709906 100644 --- a/migrate.go +++ b/migrate.go @@ -55,12 +55,29 @@ func (e ErrDirty) Error() string { return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) } +// options is a set of optional options that can be set when a Migrate instance +// is created. +type options struct { +} + +// defaultOptions returns a new options struct with default values. +func defaultOptions() options { + return options{} +} + +// Option is a function that can be used to set options on a Migrate instance. +type Option func(*options) + type Migrate struct { sourceName string sourceDrv source.Driver databaseName string databaseDrv database.Driver + // opts is a set of options that can be used to modify the behavior + // of the Migrate instance. + opts options + // Log accepts a Logger interface Log Logger @@ -84,8 +101,8 @@ type Migrate struct { // New returns a new Migrate instance from a source URL and a database URL. // The URL scheme is defined by each driver. -func New(sourceURL, databaseURL string) (*Migrate, error) { - m := newCommon() +func New(sourceURL, databaseURL string, opts ...Option) (*Migrate, error) { + m := newMigrateWithOptions(opts) sourceName, err := iurl.SchemeFromURL(sourceURL) if err != nil { @@ -118,8 +135,10 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { // and an existing database instance. The source URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as databaseName. // You are responsible for closing the underlying database client if necessary. -func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() +func NewWithDatabaseInstance(sourceURL string, databaseName string, + databaseInstance database.Driver, opts ...Option) (*Migrate, error) { + + m := newMigrateWithOptions(opts) sourceName, err := iurl.SchemeFromURL(sourceURL) if err != nil { @@ -144,8 +163,10 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst // and a database URL. The database URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as sourceName. // You are responsible for closing the underlying source client if necessary. -func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { - m := newCommon() +func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, + databaseURL string, opts ...Option) (*Migrate, error) { + + m := newMigrateWithOptions(opts) databaseName, err := iurl.SchemeFromURL(databaseURL) if err != nil { @@ -170,8 +191,11 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data // database instance. Use any string that can serve as an identifier during logging // as sourceName and databaseName. You are responsible for closing down // the underlying source and database client if necessary. -func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() +func NewWithInstance(sourceName string, sourceInstance source.Driver, + databaseName string, databaseInstance database.Driver, + opts ...Option) (*Migrate, error) { + + m := newMigrateWithOptions(opts) m.sourceName = sourceName m.databaseName = databaseName @@ -182,8 +206,13 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa return m, nil } -func newCommon() *Migrate { +func newMigrateWithOptions(optFunctions []Option) *Migrate { + opts := defaultOptions() + for _, opt := range optFunctions { + opt(&opts) + } return &Migrate{ + opts: opts, GracefulStop: make(chan bool, 1), PrefetchMigrations: DefaultPrefetchMigrations, LockTimeout: DefaultLockTimeout, From d6580fad33462405bea07839036ae0b7b1e9ecef Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 28 Mar 2025 12:49:14 -0500 Subject: [PATCH 2/2] Implement post migration step callbacks --- migrate.go | 63 ++++++++++++++++++++++++++- migrate_test.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 1 deletion(-) diff --git a/migrate.go b/migrate.go index 784709906..83f28cd35 100644 --- a/migrate.go +++ b/migrate.go @@ -55,19 +55,61 @@ func (e ErrDirty) Error() string { return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) } +// PostStepCallback is a callback function type that can be used to execute a +// Golang based migration step after a SQL based migration step has been +// executed. The callback function receives the migration and the database +// driver as arguments. +type PostStepCallback func(migr *Migration, driver database.Driver) error + // options is a set of optional options that can be set when a Migrate instance // is created. type options struct { + // postStepCallbacks is a map of PostStepCallback functions that can be + // used to execute a Golang based migration step after a SQL based + // migration step has been executed. The key is the migration version + // and the value is the callback function that should be run _after_ the + // step was executed (but within the same database transaction). + postStepCallbacks map[uint]PostStepCallback } // defaultOptions returns a new options struct with default values. func defaultOptions() options { - return options{} + return options{ + postStepCallbacks: make(map[uint]PostStepCallback), + } } // Option is a function that can be used to set options on a Migrate instance. type Option func(*options) +// WithPostStepCallbacks is an option that can be used to set a map of +// PostStepCallback functions that can be used to execute a Golang based +// migration step after a SQL based migration step has been executed. The key is +// the migration version and the value is the callback function that should be +// run _after_ the step was executed (but before the version is marked as +// cleanly executed). An error returned from the callback will cause the +// migration to fail and the step to be marked as dirty. +func WithPostStepCallbacks( + postStepCallbacks map[uint]PostStepCallback) Option { + + return func(o *options) { + o.postStepCallbacks = postStepCallbacks + } +} + +// WithPostStepCallback is an option that can be used to set a PostStepCallback +// function that can be used to execute a Golang based migration step after the +// SQL based migration step with the given version number has been executed. The +// callback is the function that should be run _after_ the step was executed +// (but before the version is marked as cleanly executed). An error returned +// from the callback will cause the migration to fail and the step to be marked +// as dirty. +func WithPostStepCallback(version uint, callback PostStepCallback) Option { + return func(o *options) { + o.postStepCallbacks[version] = callback + } +} + type Migrate struct { sourceName string sourceDrv source.Driver @@ -775,6 +817,25 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { return err } + + // If there is a post execution function for + // this migration, run it now. + cb, ok := m.opts.postStepCallbacks[migr.Version] + if ok { + m.logVerbosePrintf("Running post step "+ + "callback for %v\n", migr.LogString()) + + err := cb(migr, m.databaseDrv) + if err != nil { + return fmt.Errorf("failed to "+ + "execute post "+ + "step callback: %w", + err) + } + + m.logVerbosePrintf("Post step callback "+ + "finished for %v\n", migr.LogString()) + } } // set clean state diff --git a/migrate_test.go b/migrate_test.go index f2728179e..a19c4f214 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -10,6 +10,7 @@ import ( "strings" "testing" + "github.com/golang-migrate/migrate/v4/database" dStub "github.com/golang-migrate/migrate/v4/database/stub" "github.com/golang-migrate/migrate/v4/source" sStub "github.com/golang-migrate/migrate/v4/source/stub" @@ -878,6 +879,116 @@ func TestUpAndDown(t *testing.T) { equalDbSeq(t, 1, expectedSequence, dbDrv) } +func TestPostStepCallback(t *testing.T) { + m, _ := New("stub://", "stub://", WithPostStepCallbacks( + map[uint]PostStepCallback{ + 1: func(m *Migration, driver database.Driver) error { + return driver.Run( + strings.NewReader("CALLBACK 1"), + ) + }, + 7: func(m *Migration, driver database.Driver) error { + return driver.Run( + strings.NewReader("CALLBACK 7"), + ) + }, + }, + )) + m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations + dbDrv := m.databaseDrv.(*dStub.Stub) + + // go Up first + if err := m.Up(); err != nil { + t.Fatal(err) + } + expectedSequence := migrationSequence{ + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + } + equalDbSeq(t, 0, expectedSequence, dbDrv) + + if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 7")) { + t.Fatalf("expected database last migration to be callback 7, "+ + "got %s", dbDrv.LastRunMigration) + } + + // go Down + if err := m.Down(); err != nil { + t.Fatal(err) + } + expectedSequence = migrationSequence{ + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + mr("DROP 7"), + mr("CALLBACK 7"), + mr("DROP 5"), + mr("DROP 4"), + mr("DROP 1"), + mr("CALLBACK 1"), + } + equalDbSeq(t, 1, expectedSequence, dbDrv) + + if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 1")) { + t.Fatalf("expected database last migration to be callback 1, "+ + "got %s", dbDrv.LastRunMigration) + } + + // go 1 Up and then all the way Up + if err := m.Steps(1); err != nil { + t.Fatal(err) + } + expectedSequence = migrationSequence{ + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + mr("DROP 7"), + mr("CALLBACK 7"), + mr("DROP 5"), + mr("DROP 4"), + mr("DROP 1"), + mr("CALLBACK 1"), + mr("CREATE 1"), + mr("CALLBACK 1"), + } + equalDbSeq(t, 2, expectedSequence, dbDrv) + + if err := m.Up(); err != nil { + t.Fatal(err) + } + expectedSequence = migrationSequence{ + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + mr("DROP 7"), + mr("CALLBACK 7"), + mr("DROP 5"), + mr("DROP 4"), + mr("DROP 1"), + mr("CALLBACK 1"), + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + } + equalDbSeq(t, 3, expectedSequence, dbDrv) +} + func TestUpDirty(t *testing.T) { m, _ := New("stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub)