Skip to content

Commit d63a8ea

Browse files
authored
Merge pull request #2 from lightninglabs/fork-post-migrate-exec
Allow running Golang based post migration steps
2 parents 9023d66 + d6580fa commit d63a8ea

File tree

2 files changed

+210
-9
lines changed

2 files changed

+210
-9
lines changed

migrate.go

Lines changed: 99 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,71 @@ func (e ErrDirty) Error() string {
5555
return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version)
5656
}
5757

58+
// PostStepCallback is a callback function type that can be used to execute a
59+
// Golang based migration step after a SQL based migration step has been
60+
// executed. The callback function receives the migration and the database
61+
// driver as arguments.
62+
type PostStepCallback func(migr *Migration, driver database.Driver) error
63+
64+
// options is a set of optional options that can be set when a Migrate instance
65+
// is created.
66+
type options struct {
67+
// postStepCallbacks is a map of PostStepCallback functions that can be
68+
// used to execute a Golang based migration step after a SQL based
69+
// migration step has been executed. The key is the migration version
70+
// and the value is the callback function that should be run _after_ the
71+
// step was executed (but within the same database transaction).
72+
postStepCallbacks map[uint]PostStepCallback
73+
}
74+
75+
// defaultOptions returns a new options struct with default values.
76+
func defaultOptions() options {
77+
return options{
78+
postStepCallbacks: make(map[uint]PostStepCallback),
79+
}
80+
}
81+
82+
// Option is a function that can be used to set options on a Migrate instance.
83+
type Option func(*options)
84+
85+
// WithPostStepCallbacks is an option that can be used to set a map of
86+
// PostStepCallback functions that can be used to execute a Golang based
87+
// migration step after a SQL based migration step has been executed. The key is
88+
// the migration version and the value is the callback function that should be
89+
// run _after_ the step was executed (but before the version is marked as
90+
// cleanly executed). An error returned from the callback will cause the
91+
// migration to fail and the step to be marked as dirty.
92+
func WithPostStepCallbacks(
93+
postStepCallbacks map[uint]PostStepCallback) Option {
94+
95+
return func(o *options) {
96+
o.postStepCallbacks = postStepCallbacks
97+
}
98+
}
99+
100+
// WithPostStepCallback is an option that can be used to set a PostStepCallback
101+
// function that can be used to execute a Golang based migration step after the
102+
// SQL based migration step with the given version number has been executed. The
103+
// callback is the function that should be run _after_ the step was executed
104+
// (but before the version is marked as cleanly executed). An error returned
105+
// from the callback will cause the migration to fail and the step to be marked
106+
// as dirty.
107+
func WithPostStepCallback(version uint, callback PostStepCallback) Option {
108+
return func(o *options) {
109+
o.postStepCallbacks[version] = callback
110+
}
111+
}
112+
58113
type Migrate struct {
59114
sourceName string
60115
sourceDrv source.Driver
61116
databaseName string
62117
databaseDrv database.Driver
63118

119+
// opts is a set of options that can be used to modify the behavior
120+
// of the Migrate instance.
121+
opts options
122+
64123
// Log accepts a Logger interface
65124
Log Logger
66125

@@ -84,8 +143,8 @@ type Migrate struct {
84143

85144
// New returns a new Migrate instance from a source URL and a database URL.
86145
// The URL scheme is defined by each driver.
87-
func New(sourceURL, databaseURL string) (*Migrate, error) {
88-
m := newCommon()
146+
func New(sourceURL, databaseURL string, opts ...Option) (*Migrate, error) {
147+
m := newMigrateWithOptions(opts)
89148

90149
sourceName, err := iurl.SchemeFromURL(sourceURL)
91150
if err != nil {
@@ -118,8 +177,10 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
118177
// and an existing database instance. The source URL scheme is defined by each driver.
119178
// Use any string that can serve as an identifier during logging as databaseName.
120179
// You are responsible for closing the underlying database client if necessary.
121-
func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
122-
m := newCommon()
180+
func NewWithDatabaseInstance(sourceURL string, databaseName string,
181+
databaseInstance database.Driver, opts ...Option) (*Migrate, error) {
182+
183+
m := newMigrateWithOptions(opts)
123184

124185
sourceName, err := iurl.SchemeFromURL(sourceURL)
125186
if err != nil {
@@ -144,8 +205,10 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst
144205
// and a database URL. The database URL scheme is defined by each driver.
145206
// Use any string that can serve as an identifier during logging as sourceName.
146207
// You are responsible for closing the underlying source client if necessary.
147-
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
148-
m := newCommon()
208+
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver,
209+
databaseURL string, opts ...Option) (*Migrate, error) {
210+
211+
m := newMigrateWithOptions(opts)
149212

150213
databaseName, err := iurl.SchemeFromURL(databaseURL)
151214
if err != nil {
@@ -170,8 +233,11 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data
170233
// database instance. Use any string that can serve as an identifier during logging
171234
// as sourceName and databaseName. You are responsible for closing down
172235
// the underlying source and database client if necessary.
173-
func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
174-
m := newCommon()
236+
func NewWithInstance(sourceName string, sourceInstance source.Driver,
237+
databaseName string, databaseInstance database.Driver,
238+
opts ...Option) (*Migrate, error) {
239+
240+
m := newMigrateWithOptions(opts)
175241

176242
m.sourceName = sourceName
177243
m.databaseName = databaseName
@@ -182,8 +248,13 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa
182248
return m, nil
183249
}
184250

185-
func newCommon() *Migrate {
251+
func newMigrateWithOptions(optFunctions []Option) *Migrate {
252+
opts := defaultOptions()
253+
for _, opt := range optFunctions {
254+
opt(&opts)
255+
}
186256
return &Migrate{
257+
opts: opts,
187258
GracefulStop: make(chan bool, 1),
188259
PrefetchMigrations: DefaultPrefetchMigrations,
189260
LockTimeout: DefaultLockTimeout,
@@ -746,6 +817,25 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
746817
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
747818
return err
748819
}
820+
821+
// If there is a post execution function for
822+
// this migration, run it now.
823+
cb, ok := m.opts.postStepCallbacks[migr.Version]
824+
if ok {
825+
m.logVerbosePrintf("Running post step "+
826+
"callback for %v\n", migr.LogString())
827+
828+
err := cb(migr, m.databaseDrv)
829+
if err != nil {
830+
return fmt.Errorf("failed to "+
831+
"execute post "+
832+
"step callback: %w",
833+
err)
834+
}
835+
836+
m.logVerbosePrintf("Post step callback "+
837+
"finished for %v\n", migr.LogString())
838+
}
749839
}
750840

751841
// set clean state

migrate_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111
"testing"
1212

13+
"github.com/golang-migrate/migrate/v4/database"
1314
dStub "github.com/golang-migrate/migrate/v4/database/stub"
1415
"github.com/golang-migrate/migrate/v4/source"
1516
sStub "github.com/golang-migrate/migrate/v4/source/stub"
@@ -878,6 +879,116 @@ func TestUpAndDown(t *testing.T) {
878879
equalDbSeq(t, 1, expectedSequence, dbDrv)
879880
}
880881

882+
func TestPostStepCallback(t *testing.T) {
883+
m, _ := New("stub://", "stub://", WithPostStepCallbacks(
884+
map[uint]PostStepCallback{
885+
1: func(m *Migration, driver database.Driver) error {
886+
return driver.Run(
887+
strings.NewReader("CALLBACK 1"),
888+
)
889+
},
890+
7: func(m *Migration, driver database.Driver) error {
891+
return driver.Run(
892+
strings.NewReader("CALLBACK 7"),
893+
)
894+
},
895+
},
896+
))
897+
m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations
898+
dbDrv := m.databaseDrv.(*dStub.Stub)
899+
900+
// go Up first
901+
if err := m.Up(); err != nil {
902+
t.Fatal(err)
903+
}
904+
expectedSequence := migrationSequence{
905+
mr("CREATE 1"),
906+
mr("CALLBACK 1"),
907+
mr("CREATE 3"),
908+
mr("CREATE 4"),
909+
mr("CREATE 7"),
910+
mr("CALLBACK 7"),
911+
}
912+
equalDbSeq(t, 0, expectedSequence, dbDrv)
913+
914+
if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 7")) {
915+
t.Fatalf("expected database last migration to be callback 7, "+
916+
"got %s", dbDrv.LastRunMigration)
917+
}
918+
919+
// go Down
920+
if err := m.Down(); err != nil {
921+
t.Fatal(err)
922+
}
923+
expectedSequence = migrationSequence{
924+
mr("CREATE 1"),
925+
mr("CALLBACK 1"),
926+
mr("CREATE 3"),
927+
mr("CREATE 4"),
928+
mr("CREATE 7"),
929+
mr("CALLBACK 7"),
930+
mr("DROP 7"),
931+
mr("CALLBACK 7"),
932+
mr("DROP 5"),
933+
mr("DROP 4"),
934+
mr("DROP 1"),
935+
mr("CALLBACK 1"),
936+
}
937+
equalDbSeq(t, 1, expectedSequence, dbDrv)
938+
939+
if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 1")) {
940+
t.Fatalf("expected database last migration to be callback 1, "+
941+
"got %s", dbDrv.LastRunMigration)
942+
}
943+
944+
// go 1 Up and then all the way Up
945+
if err := m.Steps(1); err != nil {
946+
t.Fatal(err)
947+
}
948+
expectedSequence = migrationSequence{
949+
mr("CREATE 1"),
950+
mr("CALLBACK 1"),
951+
mr("CREATE 3"),
952+
mr("CREATE 4"),
953+
mr("CREATE 7"),
954+
mr("CALLBACK 7"),
955+
mr("DROP 7"),
956+
mr("CALLBACK 7"),
957+
mr("DROP 5"),
958+
mr("DROP 4"),
959+
mr("DROP 1"),
960+
mr("CALLBACK 1"),
961+
mr("CREATE 1"),
962+
mr("CALLBACK 1"),
963+
}
964+
equalDbSeq(t, 2, expectedSequence, dbDrv)
965+
966+
if err := m.Up(); err != nil {
967+
t.Fatal(err)
968+
}
969+
expectedSequence = migrationSequence{
970+
mr("CREATE 1"),
971+
mr("CALLBACK 1"),
972+
mr("CREATE 3"),
973+
mr("CREATE 4"),
974+
mr("CREATE 7"),
975+
mr("CALLBACK 7"),
976+
mr("DROP 7"),
977+
mr("CALLBACK 7"),
978+
mr("DROP 5"),
979+
mr("DROP 4"),
980+
mr("DROP 1"),
981+
mr("CALLBACK 1"),
982+
mr("CREATE 1"),
983+
mr("CALLBACK 1"),
984+
mr("CREATE 3"),
985+
mr("CREATE 4"),
986+
mr("CREATE 7"),
987+
mr("CALLBACK 7"),
988+
}
989+
equalDbSeq(t, 3, expectedSequence, dbDrv)
990+
}
991+
881992
func TestUpDirty(t *testing.T) {
882993
m, _ := New("stub://", "stub://")
883994
dbDrv := m.databaseDrv.(*dStub.Stub)

0 commit comments

Comments
 (0)