Skip to content

Commit 2e5d24c

Browse files
committed
[in progress] Bulk load for MySQL
1 parent c3e8cde commit 2e5d24c

File tree

3 files changed

+136
-4
lines changed

3 files changed

+136
-4
lines changed

drivers/drivers_test.go

+21-3
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ func TestCopy(t *testing.T) {
435435

436436
testCases := []struct {
437437
dbName string
438+
testCase string
438439
setupQueries []setupQuery
439440
src string
440441
dest string
@@ -449,7 +450,8 @@ func TestCopy(t *testing.T) {
449450
dest: "staff_copy",
450451
},
451452
{
452-
dbName: "pgsql",
453+
dbName: "pgsql",
454+
testCase: "schemaInDest",
453455
setupQueries: []setupQuery{
454456
{query: "DROP TABLE staff_copy"},
455457
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
@@ -467,7 +469,8 @@ func TestCopy(t *testing.T) {
467469
dest: "staff_copy",
468470
},
469471
{
470-
dbName: "pgx",
472+
dbName: "pgx",
473+
testCase: "schemaInDest",
471474
setupQueries: []setupQuery{
472475
{query: "DROP TABLE staff_copy"},
473476
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
@@ -484,6 +487,17 @@ func TestCopy(t *testing.T) {
484487
src: "select staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff",
485488
dest: "staff_copy(staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)",
486489
},
490+
{
491+
dbName: "mysql",
492+
testCase: "bulkCopy",
493+
setupQueries: []setupQuery{
494+
{query: "SET GLOBAL local_infile = ON"},
495+
{query: "DROP TABLE staff_copy"},
496+
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
497+
},
498+
src: "select staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
499+
dest: "staff_copy(staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update)",
500+
},
487501
{
488502
dbName: "sqlserver",
489503
setupQueries: []setupQuery{
@@ -508,7 +522,11 @@ func TestCopy(t *testing.T) {
508522
continue
509523
}
510524

511-
t.Run(test.dbName, func(t *testing.T) {
525+
testName := test.dbName
526+
if test.testCase != "" {
527+
testName += "-" + test.testCase
528+
}
529+
t.Run(testName, func(t *testing.T) {
512530

513531
// TODO test copy from a different DB, maybe csvq?
514532
// TODO test copy from same DB

drivers/mysql/copy.go

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package mysql
2+
3+
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/csv"
	"fmt"
	"io"
	"reflect"
	"strings"
	"time"

	"github.com/go-sql-driver/mysql"
	"github.com/xo/usql/drivers"
)
15+
16+
func copyRows(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
17+
localInfileSupported := false
18+
row := db.QueryRowContext(ctx, "SELECT @@GLOBAL.local_infile")
19+
err := row.Scan(&localInfileSupported)
20+
if err == nil && localInfileSupported && !hasBlobColumn(rows) {
21+
return bulkCopy(ctx, db, rows, table)
22+
} else {
23+
return drivers.CopyWithInsert(func(int) string { return "?" })(ctx, db, rows, table)
24+
}
25+
}
26+
27+
func bulkCopy(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
28+
mysql.RegisterReaderHandler("data", func() io.Reader {
29+
return toCsvReader(rows)
30+
})
31+
defer mysql.DeregisterReaderHandler("data")
32+
tx, err := db.BeginTx(ctx, nil)
33+
if err != nil {
34+
return 0, err
35+
}
36+
var cnt int64
37+
csvSpec := " FIELDS TERMINATED BY ',' "
38+
stmt := fmt.Sprintf("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE %s",
39+
// if there is a column list, csvSpec goes between the table name and the list
40+
strings.Replace(table, "(", csvSpec+" (", 1))
41+
// if there wasn't a column list in the table spec, csvSpec goes at the end
42+
if !strings.Contains(table, "(") {
43+
stmt += csvSpec
44+
}
45+
res, err := tx.ExecContext(ctx, stmt)
46+
if err != nil {
47+
tx.Rollback()
48+
} else {
49+
err = tx.Commit()
50+
if err == nil {
51+
cnt, err = res.RowsAffected()
52+
}
53+
}
54+
return cnt, err
55+
}
56+
57+
// hasBlobColumn reports whether any column of rows has a binary database
// type: a type name ending in "BLOB" (BLOB, TINYBLOB, MEDIUMBLOB, LONGBLOB) or
// in "BINARY" (BINARY, VARBINARY). The comparison is case-insensitive.
//
// Callers use it to keep binary data away from the CSV-based bulk load, which
// cannot represent it.
//
// If the column types cannot be determined it returns false.
func hasBlobColumn(rows *sql.Rows) bool {
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return false
	}
	for _, ct := range columnTypes {
		name := strings.ToUpper(ct.DatabaseTypeName())
		// Suffix matching covers every size variant, not just plain "BLOB".
		if strings.HasSuffix(name, "BLOB") || strings.HasSuffix(name, "BINARY") {
			return true
		}
	}
	return false
}
69+
70+
func toCsvReader(rows *sql.Rows) io.Reader {
71+
r, w := io.Pipe()
72+
go writeAsCsv(rows, w)
73+
return r
74+
}
75+
76+
// writeAsCsv streams rows to w as CSV in a form LOAD DATA INFILE can parse,
// then closes w.
//
// Any failure (reading column types, iterating, scanning, converting or
// writing) is handed to the reading side through w.CloseWithError, so the
// consumer of the pipe sees it as a read error instead of a short but
// apparently complete stream.
//
// NB: binary (BLOB) data cannot be represented; use the regular copy when
// such columns are present.
func writeAsCsv(rows *sql.Rows, w *io.PipeWriter) {
	// CloseWithError(nil) behaves like Close, so one call covers both outcomes.
	w.CloseWithError(writeCsvRecords(rows, csv.NewWriter(w)))
}

// writeCsvRecords scans every remaining row of rows and writes it to out as
// one CSV record, flushing out at the end. It returns the first error met.
func writeCsvRecords(rows *sql.Rows, out *csv.Writer) error {
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return err
	}

	// One scan destination per column, allocated once and reused for every
	// row, typed as the driver recommends for that column.
	cells := make([]reflect.Value, len(columnTypes))
	dest := make([]interface{}, len(columnTypes))
	for i, ct := range columnTypes {
		cells[i] = reflect.New(ct.ScanType())
		dest[i] = cells[i].Interface()
	}

	record := make([]string, len(dest))
	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			return err
		}
		for i, cell := range cells {
			if record[i], err = csvField(cell.Elem().Interface()); err != nil {
				return err
			}
		}
		if err := out.Write(record); err != nil {
			return err
		}
	}
	// Next returns false both at the end of the data and on failure; only
	// Err tells the two apart, and it is meaningful only after the loop.
	if err := rows.Err(); err != nil {
		return err
	}

	out.Flush()
	return out.Error()
}

// csvField renders one scanned value as the text of a CSV field.
//
// Values implementing driver.Valuer (such as the sql.Null* wrappers) are
// unwrapped first. NULL is written as \N, byte slices are written as their
// text rather than as a list of numbers, and times use the
// "YYYY-MM-DD hh:mm:ss[.ffffff]" layout. Backslashes are doubled because
// backslash is the default escape character of LOAD DATA.
func csvField(v interface{}) (string, error) {
	if valuer, ok := v.(driver.Valuer); ok {
		unwrapped, err := valuer.Value()
		if err != nil {
			return "", err
		}
		v = unwrapped
	}

	var text string
	switch x := v.(type) {
	case nil:
		return `\N`, nil
	case sql.RawBytes:
		// Scanning a NULL leaves a nil slice behind.
		if x == nil {
			return `\N`, nil
		}
		text = string(x)
	case []byte:
		if x == nil {
			return `\N`, nil
		}
		text = string(x)
	case string:
		text = x
	case time.Time:
		text = x.Format("2006-01-02 15:04:05.999999")
	default:
		text = fmt.Sprintf("%v", x)
	}
	return strings.ReplaceAll(text, `\`, `\\`), nil
}

drivers/mysql/mysql.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func init() {
4545
NewMetadataWriter: func(db drivers.DB, w io.Writer, opts ...metadata.ReaderOption) metadata.Writer {
4646
return metadata.NewDefaultWriter(mymeta.NewReader(db, opts...))(db, w)
4747
},
48-
Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
48+
Copy: copyRows,
4949
NewCompleter: mymeta.NewCompleter,
5050
}, "memsql", "vitess", "tidb")
5151
}

0 commit comments

Comments
 (0)