diff --git a/migrate.go b/migrate.go index 44efe14e3..35e4a3442 100644 --- a/migrate.go +++ b/migrate.go @@ -55,12 +55,71 @@ func (e ErrDirty) Error() string { return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version) } +// PostStepCallback is a callback function type that can be used to execute a +// Golang based migration step after a SQL based migration step has been +// executed. The callback function receives the migration and the database +// driver as arguments. +type PostStepCallback func(migr *Migration, driver database.Driver) error + +// options is a set of optional options that can be set when a Migrate instance +// is created. +type options struct { + // postStepCallbacks is a map of PostStepCallback functions that can be + // used to execute a Golang based migration step after a SQL based + // migration step has been executed. The key is the migration version + // and the value is the callback function that should be run _after_ the + // step was executed (but within the same database transaction). + postStepCallbacks map[uint]PostStepCallback +} + +// defaultOptions returns a new options struct with default values. +func defaultOptions() options { + return options{ + postStepCallbacks: make(map[uint]PostStepCallback), + } +} + +// Option is a function that can be used to set options on a Migrate instance. +type Option func(*options) + +// WithPostStepCallbacks is an option that can be used to set a map of +// PostStepCallback functions that can be used to execute a Golang based +// migration step after a SQL based migration step has been executed. The key is +// the migration version and the value is the callback function that should be +// run _after_ the step was executed (but before the version is marked as +// cleanly executed). An error returned from the callback will cause the +// migration to fail and the step to be marked as dirty. +func WithPostStepCallbacks( + postStepCallbacks map[uint]PostStepCallback) Option { + + return func(o *options) { + o.postStepCallbacks = postStepCallbacks + } +} + +// WithPostStepCallback is an option that can be used to set a PostStepCallback +// function that can be used to execute a Golang based migration step after the +// SQL based migration step with the given version number has been executed. The +// callback is the function that should be run _after_ the step was executed +// (but before the version is marked as cleanly executed). An error returned +// from the callback will cause the migration to fail and the step to be marked +// as dirty. +func WithPostStepCallback(version uint, callback PostStepCallback) Option { + return func(o *options) { + o.postStepCallbacks[version] = callback + } +} + type Migrate struct { sourceName string sourceDrv source.Driver databaseName string databaseDrv database.Driver + // opts is a set of options that can be used to modify the behavior + // of the Migrate instance. + opts options + // Log accepts a Logger interface Log Logger @@ -84,8 +143,8 @@ type Migrate struct { // New returns a new Migrate instance from a source URL and a database URL. // The URL scheme is defined by each driver. -func New(sourceURL, databaseURL string) (*Migrate, error) { - m := newCommon() +func New(sourceURL, databaseURL string, opts ...Option) (*Migrate, error) { + m := newMigrateWithOptions(opts) sourceName, err := iurl.SchemeFromURL(sourceURL) if err != nil { @@ -118,8 +177,10 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { // and an existing database instance. The source URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as databaseName. // You are responsible for closing the underlying database client if necessary. -func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() +func NewWithDatabaseInstance(sourceURL string, databaseName string, + databaseInstance database.Driver, opts ...Option) (*Migrate, error) { + + m := newMigrateWithOptions(opts) sourceName, err := iurl.SchemeFromURL(sourceURL) if err != nil { @@ -144,8 +205,10 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst // and a database URL. The database URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as sourceName. // You are responsible for closing the underlying source client if necessary. -func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { - m := newCommon() +func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, + databaseURL string, opts ...Option) (*Migrate, error) { + + m := newMigrateWithOptions(opts) databaseName, err := iurl.SchemeFromURL(databaseURL) if err != nil { @@ -170,8 +233,11 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data // database instance. Use any string that can serve as an identifier during logging // as sourceName and databaseName. You are responsible for closing down // the underlying source and database client if necessary. -func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() +func NewWithInstance(sourceName string, sourceInstance source.Driver, + databaseName string, databaseInstance database.Driver, + opts ...Option) (*Migrate, error) { + + m := newMigrateWithOptions(opts) m.sourceName = sourceName m.databaseName = databaseName @@ -182,8 +248,13 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa return m, nil } -func newCommon() *Migrate { +func newMigrateWithOptions(optFunctions []Option) *Migrate { + opts := defaultOptions() + for _, opt := range optFunctions { + opt(&opts) + } return &Migrate{ + opts: opts, GracefulStop: make(chan bool, 1), PrefetchMigrations: DefaultPrefetchMigrations, LockTimeout: DefaultLockTimeout, @@ -746,6 +817,25 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { return err } + + // If there is a post execution function for + // this migration, run it now. + cb, ok := m.opts.postStepCallbacks[migr.Version] + if ok { + m.logVerbosePrintf("Running post step "+ + "callback for %v\n", migr.LogString()) + + err := cb(migr, m.databaseDrv) + if err != nil { + return fmt.Errorf("failed to "+ + "execute post "+ + "step callback: %w", + err) + } + + m.logVerbosePrintf("Post step callback "+ + "finished for %v\n", migr.LogString()) + } } // set clean state diff --git a/migrate_test.go b/migrate_test.go index f2728179e..a19c4f214 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -10,6 +10,7 @@ import ( "strings" "testing" + "github.com/golang-migrate/migrate/v4/database" dStub "github.com/golang-migrate/migrate/v4/database/stub" "github.com/golang-migrate/migrate/v4/source" sStub "github.com/golang-migrate/migrate/v4/source/stub" @@ -878,6 +879,116 @@ func TestUpAndDown(t *testing.T) { equalDbSeq(t, 1, expectedSequence, dbDrv) } +func TestPostStepCallback(t *testing.T) { + m, _ := New("stub://", "stub://", WithPostStepCallbacks( + map[uint]PostStepCallback{ + 1: func(m *Migration, driver database.Driver) error { + return driver.Run( + strings.NewReader("CALLBACK 1"), + ) + }, + 7: func(m *Migration, driver database.Driver) error { + return driver.Run( + strings.NewReader("CALLBACK 7"), + ) + }, + }, + )) + m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations + dbDrv := m.databaseDrv.(*dStub.Stub) + + // go Up first + if err := m.Up(); err != nil { + t.Fatal(err) + } + expectedSequence := migrationSequence{ + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + } + equalDbSeq(t, 0, expectedSequence, dbDrv) + + if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 7")) { + t.Fatalf("expected database last migration to be callback 7, "+ + "got %s", dbDrv.LastRunMigration) + } + + // go Down + if err := m.Down(); err != nil { + t.Fatal(err) + } + expectedSequence = migrationSequence{ + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + mr("DROP 7"), + mr("CALLBACK 7"), + mr("DROP 5"), + mr("DROP 4"), + mr("DROP 1"), + mr("CALLBACK 1"), + } + equalDbSeq(t, 1, expectedSequence, dbDrv) + + if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 1")) { + t.Fatalf("expected database last migration to be callback 1, "+ + "got %s", dbDrv.LastRunMigration) + } + + // go 1 Up and then all the way Up + if err := m.Steps(1); err != nil { + t.Fatal(err) + } + expectedSequence = migrationSequence{ + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + mr("DROP 7"), + mr("CALLBACK 7"), + mr("DROP 5"), + mr("DROP 4"), + mr("DROP 1"), + mr("CALLBACK 1"), + mr("CREATE 1"), + mr("CALLBACK 1"), + } + equalDbSeq(t, 2, expectedSequence, dbDrv) + + if err := m.Up(); err != nil { + t.Fatal(err) + } + expectedSequence = migrationSequence{ + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + mr("DROP 7"), + mr("CALLBACK 7"), + mr("DROP 5"), + mr("DROP 4"), + mr("DROP 1"), + mr("CALLBACK 1"), + mr("CREATE 1"), + mr("CALLBACK 1"), + mr("CREATE 3"), + mr("CREATE 4"), + mr("CREATE 7"), + mr("CALLBACK 7"), + } + equalDbSeq(t, 3, expectedSequence, dbDrv) +} + func TestUpDirty(t *testing.T) { m, _ := New("stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub)