Skip to content

Commit 6f73721

Browse files
committed
[in progress] Bulk load for MySQL
1 parent c3e8cde commit 6f73721

File tree

3 files changed

+130
-4
lines changed

3 files changed

+130
-4
lines changed

drivers/drivers_test.go

+21-3
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ func TestCopy(t *testing.T) {
435435

436436
testCases := []struct {
437437
dbName string
438+
testCase string
438439
setupQueries []setupQuery
439440
src string
440441
dest string
@@ -449,7 +450,8 @@ func TestCopy(t *testing.T) {
449450
dest: "staff_copy",
450451
},
451452
{
452-
dbName: "pgsql",
453+
dbName: "pgsql",
454+
testCase: "schemaInDest",
453455
setupQueries: []setupQuery{
454456
{query: "DROP TABLE staff_copy"},
455457
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
@@ -467,7 +469,8 @@ func TestCopy(t *testing.T) {
467469
dest: "staff_copy",
468470
},
469471
{
470-
dbName: "pgx",
472+
dbName: "pgx",
473+
testCase: "schemaInDest",
471474
setupQueries: []setupQuery{
472475
{query: "DROP TABLE staff_copy"},
473476
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
@@ -484,6 +487,17 @@ func TestCopy(t *testing.T) {
484487
src: "select staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff",
485488
dest: "staff_copy(staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)",
486489
},
490+
{
491+
dbName: "mysql",
492+
testCase: "bulkCopy",
493+
setupQueries: []setupQuery{
494+
{query: "SET GLOBAL local_infile = ON"},
495+
{query: "DROP TABLE staff_copy"},
496+
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
497+
},
498+
src: "select staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
499+
dest: "staff_copy(staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update)",
500+
},
487501
{
488502
dbName: "sqlserver",
489503
setupQueries: []setupQuery{
@@ -508,7 +522,11 @@ func TestCopy(t *testing.T) {
508522
continue
509523
}
510524

511-
t.Run(test.dbName, func(t *testing.T) {
525+
testName := test.dbName
526+
if test.testCase != "" {
527+
testName += "-" + test.testCase
528+
}
529+
t.Run(testName, func(t *testing.T) {
512530

513531
// TODO test copy from a different DB, maybe csvq?
514532
// TODO test copy from same DB

drivers/mysql/copy.go

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package mysql
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"encoding/csv"
7+
"fmt"
8+
"io"
9+
"os"
10+
"reflect"
11+
"strings"
12+
13+
"github.com/go-sql-driver/mysql"
14+
"github.com/xo/usql/drivers"
15+
)
16+
17+
func copyRows(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
18+
localInfileSupported := false
19+
row := db.QueryRowContext(ctx, "SELECT @@GLOBAL.local_infile")
20+
err := row.Scan(&localInfileSupported)
21+
if err == nil && localInfileSupported && !hasBlobColumn(rows) {
22+
return bulkCopy(ctx, db, rows, table)
23+
} else {
24+
return drivers.CopyWithInsert(func(int) string { return "?" })(ctx, db, rows, table)
25+
}
26+
}
27+
28+
func bulkCopy(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
29+
mysql.RegisterReaderHandler("data", func() io.Reader {
30+
return toCsvReader(rows)
31+
})
32+
defer mysql.DeregisterReaderHandler("data")
33+
tx, err := db.BeginTx(ctx, nil)
34+
if err != nil {
35+
return 0, err
36+
}
37+
var cnt int64
38+
res, err := tx.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE %s",
39+
strings.Replace(table, "(", " FIELDS TERMINATED BY ',' (", 1)))
40+
if err != nil {
41+
tx.Rollback()
42+
} else {
43+
err = tx.Commit()
44+
if err == nil {
45+
cnt, err = res.RowsAffected()
46+
}
47+
}
48+
return cnt, err
49+
}
50+
51+
// hasBlobColumn reports whether any column of rows has a binary database
// type, in which case the CSV-based bulk load must not be used.
//
// A column counts as binary when its DatabaseTypeName, compared
// case-insensitively, ends in "BLOB" (BLOB, TINYBLOB, MEDIUMBLOB, LONGBLOB),
// ends in "BINARY" (BINARY, VARBINARY), or equals "BYTEA". Matching more than
// strictly necessary is harmless: the caller then falls back to the
// INSERT-based copy.
//
// If the column types cannot be obtained it returns false.
func hasBlobColumn(rows *sql.Rows) bool {
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return false
	}
	for _, ct := range columnTypes {
		name := strings.ToUpper(ct.DatabaseTypeName())
		if strings.HasSuffix(name, "BLOB") || strings.HasSuffix(name, "BINARY") || name == "BYTEA" {
			return true
		}
	}
	return false
}
63+
64+
func toCsvReader(rows *sql.Rows) io.Reader {
65+
r, w := io.Pipe()
66+
go writeAsCsv(rows, w)
67+
return r
68+
}
69+
70+
// writeAsCsv writes the rows in a CSV format compatible with LOAD DATA INFILE.
//
// Every remaining row of rows is scanned into values allocated from each
// column's ScanType, formatted with %v, and written to w as one CSV record
// using the encoding/csv defaults.
//
// w is always closed before returning: without an error on success, or with
// the first error encountered while reading the column types, iterating,
// scanning, or writing, so that the reading side of the pipe observes the
// failure instead of a silently truncated stream.
func writeAsCsv(rows *sql.Rows, w *io.PipeWriter) {
	defer w.Close() // noop if already closed
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		w.CloseWithError(err)
		return
	}
	values := make([]interface{}, len(columnTypes))
	valueRefs := make([]reflect.Value, len(columnTypes))
	for i, ct := range columnTypes {
		valueRefs[i] = reflect.New(ct.ScanType())
		values[i] = valueRefs[i].Interface()
	}
	record := make([]string, len(values))
	// NOTE(review): the tee to os.Stdout echoes every record to standard
	// output; this looks like a debugging aid — confirm before release.
	csvWriter := csv.NewWriter(io.MultiWriter(w, os.Stdout))
	for rows.Next() {
		if err = rows.Scan(values...); err != nil {
			break
		}
		for i, valueRef := range valueRefs {
			// NB: Does not work for BLOBs. Use regular copy if there are BLOB columns
			record[i] = fmt.Sprintf("%v", valueRef.Elem().Interface())
		}
		if err = csvWriter.Write(record); err != nil {
			break
		}
	}
	if err == nil {
		// rows.Next returns false both at the end of the result set and on
		// failure; only rows.Err tells the two apart.
		err = rows.Err()
	}
	if err == nil {
		csvWriter.Flush()
		err = csvWriter.Error()
	}
	w.CloseWithError(err) // same as w.Close(), if err is nil
}

drivers/mysql/mysql.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func init() {
4545
NewMetadataWriter: func(db drivers.DB, w io.Writer, opts ...metadata.ReaderOption) metadata.Writer {
4646
return metadata.NewDefaultWriter(mymeta.NewReader(db, opts...))(db, w)
4747
},
48-
Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
48+
Copy: copyRows,
4949
NewCompleter: mymeta.NewCompleter,
5050
}, "memsql", "vitess", "tidb")
5151
}

0 commit comments

Comments
 (0)