Skip to content

Commit 2534234

Browse files
authored
Merge pull request #22 from upfluence/am/register-driver-wrapper
sqlutil: Add public API to register driver wrapper
2 parents 2f3862f + 2799d9d commit 2534234

File tree

5 files changed

+69
-13
lines changed

5 files changed

+69
-13
lines changed

sqlutil/open.go

+11-4
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,18 @@ var (
1919

2020
ErrNoDBProvided = errors.New("sql/sqlutil: No DB provided")
2121

22-
driverWrapperMu = &sync.Mutex{}
23-
driverWrapper = map[string]driverWrapperFunc{"postgres": postgres.NewDB}
22+
driverWrappersMu = &sync.Mutex{}
23+
driverWrappers = map[string]DriverWrapperFunc{"postgres": postgres.NewDB}
2424
)
2525

26-
type driverWrapperFunc func(sql.DB, sqlparser.SQLParser) sql.DB
26+
func RegisterDriverWrapper(d string, fn DriverWrapperFunc) {
27+
driverWrappersMu.Lock()
28+
defer driverWrappersMu.Unlock()
29+
30+
driverWrappers[d] = fn
31+
}
32+
33+
type DriverWrapperFunc func(sql.DB, sqlparser.SQLParser) sql.DB
2734

2835
type DBOption func(*dbInput)
2936

@@ -84,7 +91,7 @@ func (i *dbInput) buildDB(p sqlparser.SQLParser) (sql.DB, error) {
8491

8592
db := simple.FromStdDB(plainDB, i.driver)
8693

87-
if wfn, ok := driverWrapper[i.driver]; ok {
94+
if wfn, ok := driverWrappers[i.driver]; ok {
8895
db = wfn(db, p)
8996
}
9097

sqlutil/open_cgo.go

+1-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ import (
99
)
1010

1111
func init() {
12-
driverWrapperMu.Lock()
13-
driverWrapper["sqlite3"] = newSQLite3DB
14-
driverWrapperMu.Unlock()
12+
RegisterDriverWrapper("sqlite3", newSQLite3DB)
1513
}
1614

1715
func newSQLite3DB(db sql.DB, _ sqlparser.SQLParser) sql.DB {

sqlutil/open_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
package sqlutil
22

33
import (
4+
"database/sql"
45
"testing"
56

7+
"github.com/lib/pq"
8+
69
"github.com/upfluence/sql/backend/postgres"
710
)
811

@@ -17,3 +20,18 @@ func TestOpenPostgresDB(t *testing.T) {
1720
t.Errorf("invalid wrapping of the DB")
1821
}
1922
}
23+
24+
func TestRegisterDriverWrapper(t *testing.T) {
25+
sql.Register("bizbuz", &pq.Driver{})
26+
RegisterDriverWrapper("bizbuz", postgres.NewDB)
27+
28+
db, err := Open(WithMaster("bizbuz", "foobar"))
29+
30+
if err != nil {
31+
t.Errorf("Open() = (_, %+v) wanted nil", err)
32+
}
33+
34+
if !postgres.IsPostgresDB(db) {
35+
t.Errorf("invalid wrapping of the DB")
36+
}
37+
}

x/migration/driver.go

+20-6
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,41 @@
11
package migration
22

3+
import "sync"
4+
35
var (
46
defaultDriver = &driver{name: "default"}
57

6-
driverMap = map[string]Driver{
7-
"postgres": &driver{
8-
name: "postgres",
9-
extensions: []string{"postgres", "psql"},
10-
},
8+
PostgresDriver Driver = &driver{
9+
name: "postgres",
10+
extensions: []string{"postgres", "psql"},
11+
}
12+
13+
driversMu = &sync.Mutex{}
14+
drivers = map[string]Driver{
15+
"postgres": PostgresDriver,
1116
"sqlite3": &driver{
1217
name: "sqlite3",
1318
extensions: []string{"sqlite3", "sqlite"},
1419
},
1520
}
1621
)
1722

23+
func RegisterDriver(n string, d Driver) {
24+
driversMu.Lock()
25+
defer driversMu.Unlock()
26+
27+
drivers[n] = d
28+
}
29+
1830
type Driver interface {
1931
Name() string
2032
Extensions() []string
2133
}
2234

2335
func fetchDriver(dname string) Driver {
24-
d, ok := driverMap[dname]
36+
driversMu.Lock()
37+
d, ok := drivers[dname]
38+
driversMu.Unlock()
2539

2640
if ok {
2741
return d

x/migration/driver_test.go

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package migration
2+
3+
import "testing"
4+
5+
func TestRegisterDriver(t *testing.T) {
6+
d := fetchDriver("foo")
7+
8+
if n := d.Name(); n != "default" {
9+
t.Errorf("driver.Name() = %q [ want: default ]", n)
10+
}
11+
12+
RegisterDriver("foo", PostgresDriver)
13+
14+
d = fetchDriver("foo")
15+
16+
if n := d.Name(); n != "postgres" {
17+
t.Errorf("driver.Name() = %q [ want: postgres ]", n)
18+
}
19+
}

0 commit comments

Comments
 (0)