@@ -4,10 +4,15 @@ import (
4
4
"context"
5
5
"database/sql/driver"
6
6
"errors"
7
+ "fmt"
7
8
"regexp"
9
+ "strings"
8
10
"time"
9
11
12
+ uuid "github.com/satori/go.uuid"
13
+
10
14
"github.com/aws/aws-sdk-go/aws"
15
+ "github.com/aws/aws-sdk-go/aws/session"
11
16
"github.com/aws/aws-sdk-go/service/athena"
12
17
"github.com/aws/aws-sdk-go/service/athena/athenaiface"
13
18
)
@@ -19,6 +24,11 @@ type conn struct {
19
24
workgroup string
20
25
21
26
pollFrequency time.Duration
27
+
28
+ resultMode ResultMode
29
+ session * session.Session
30
+ timeout uint
31
+ catalog string
22
32
}
23
33
24
34
func (c * conn ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
@@ -40,6 +50,38 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
40
50
}
41
51
42
52
func (c * conn ) runQuery (ctx context.Context , query string ) (driver.Rows , error ) {
53
+ // result mode
54
+ isSelect := isSelectQuery (query )
55
+ resultMode := c .resultMode
56
+ if rmode , ok := getResultMode (ctx ); ok {
57
+ resultMode = rmode
58
+ }
59
+ if ! isSelect {
60
+ resultMode = ResultModeAPI
61
+ }
62
+
63
+ // timeout
64
+ timeout := c .timeout
65
+ if to , ok := getTimeout (ctx ); ok {
66
+ timeout = to
67
+ }
68
+
69
+ // catalog
70
+ catalog := c .catalog
71
+ if cat , ok := getCatalog (ctx ); ok {
72
+ catalog = cat
73
+ }
74
+
75
+ // mode ctas
76
+ var ctasTable string
77
+ var afterDownload func () error
78
+ if isSelect && resultMode == ResultModeGzipDL {
79
+ // Create AS Select
80
+ ctasTable = fmt .Sprintf ("tmp_ctas_%v" , strings .Replace (uuid .NewV4 ().String (), "-" , "" , - 1 ))
81
+ query = fmt .Sprintf ("CREATE TABLE %s WITH (format='TEXTFILE') AS %s" , ctasTable , query )
82
+ afterDownload = c .dropCTASTable (ctx , ctasTable )
83
+ }
84
+
43
85
queryID , err := c .startQuery (query )
44
86
if err != nil {
45
87
return nil , err
@@ -50,12 +92,33 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
50
92
}
51
93
52
94
return newRows (rowsConfig {
53
- Athena : c .athena ,
54
- QueryID : queryID ,
55
- SkipHeader : ! isDDLQuery (query ),
95
+ Athena : c .athena ,
96
+ QueryID : queryID ,
97
+ SkipHeader : ! isDDLQuery (query ),
98
+ ResultMode : resultMode ,
99
+ Session : c .session ,
100
+ OutputLocation : c .OutputLocation ,
101
+ Timeout : timeout ,
102
+ AfterDownload : afterDownload ,
103
+ CTASTable : ctasTable ,
104
+ DB : c .db ,
105
+ Catalog : catalog ,
56
106
})
57
107
}
58
108
109
+ func (c * conn ) dropCTASTable (ctx context.Context , table string ) func () error {
110
+ return func () error {
111
+ query := fmt .Sprintf ("DROP TABLE %s" , table )
112
+
113
+ queryID , err := c .startQuery (query )
114
+ if err != nil {
115
+ return err
116
+ }
117
+
118
+ return c .waitOnQuery (ctx , queryID )
119
+ }
120
+ }
121
+
59
122
// startQuery starts an Athena query and returns its ID.
60
123
func (c * conn ) startQuery (query string ) (string , error ) {
61
124
resp , err := c .athena .StartQueryExecution (& athena.StartQueryExecutionInput {
@@ -146,3 +209,11 @@ var ddlQueryRegex = regexp.MustCompile(`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SH
146
209
func isDDLQuery (query string ) bool {
147
210
return ddlQueryRegex .Match ([]byte (query ))
148
211
}
212
+
213
+ func isSelectQuery (query string ) bool {
214
+ return regexp .MustCompile (`(?i)^SELECT` ).Match ([]byte (query ))
215
+ }
216
+
217
+ func isCTASQuery (query string ) bool {
218
+ return regexp .MustCompile (`(?i)^CREATE.+AS\s+SELECT` ).Match ([]byte (query ))
219
+ }
0 commit comments