@@ -11,22 +11,67 @@ import (
11
11
12
12
uuid "github.com/satori/go.uuid"
13
13
14
- "github.com/aws/aws-sdk-go/aws"
15
- "github.com/aws/aws-sdk-go/aws/session"
16
- "github.com/aws/aws-sdk-go/service/athena"
17
- "github.com/aws/aws-sdk-go/service/athena/athenaiface"
14
+ "github.com/aws/aws-sdk-go-v2/aws"
15
+ "github.com/aws/aws-sdk-go-v2/service/athena"
16
+ "github.com/aws/aws-sdk-go-v2/service/athena/types"
18
17
)
19
18
19
+ // Query type patterns
20
+ var (
21
+ ddlQueryPattern = regexp .MustCompile (`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)` )
22
+ selectQueryPattern = regexp .MustCompile (`(?i)^SELECT` )
23
+ ctasQueryPattern = regexp .MustCompile (`(?i)^CREATE.+AS\s+SELECT` )
24
+ )
25
+
26
+ // queryType represents the type of SQL query
27
+ type queryType int
28
+
29
+ const (
30
+ queryTypeUnknown queryType = iota
31
+ queryTypeDDL
32
+ queryTypeSelect
33
+ queryTypeCTAS
34
+ )
35
+
36
+ // getQueryType determines the type of the query
37
+ func getQueryType (query string ) queryType {
38
+ switch {
39
+ case ddlQueryPattern .MatchString (query ):
40
+ return queryTypeDDL
41
+ case ctasQueryPattern .MatchString (query ):
42
+ return queryTypeCTAS
43
+ case selectQueryPattern .MatchString (query ):
44
+ return queryTypeSelect
45
+ default :
46
+ return queryTypeUnknown
47
+ }
48
+ }
49
+
50
+ // isDDLQuery determines if the query is a DDL statement
51
+ func isDDLQuery (query string ) bool {
52
+ return getQueryType (query ) == queryTypeDDL
53
+ }
54
+
55
+ // isSelectQuery determines if the query is a SELECT statement
56
+ func isSelectQuery (query string ) bool {
57
+ return getQueryType (query ) == queryTypeSelect
58
+ }
59
+
60
+ // isCTASQuery determines if the query is a CREATE TABLE AS SELECT statement
61
+ func isCTASQuery (query string ) bool {
62
+ return getQueryType (query ) == queryTypeCTAS
63
+ }
64
+
20
65
type conn struct {
21
- athena athenaiface. AthenaAPI
66
+ athena * athena. Client
22
67
db string
23
68
OutputLocation string
24
69
workgroup string
25
70
26
71
pollFrequency time.Duration
27
72
28
73
resultMode ResultMode
29
- session * session. Session
74
+ config aws. Config
30
75
timeout uint
31
76
catalog string
32
77
}
@@ -54,6 +99,9 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
54
99
isSelect := isSelectQuery (query )
55
100
resultMode := c .resultMode
56
101
if rmode , ok := getResultMode (ctx ); ok {
102
+ if ! isValidResultMode (rmode ) {
103
+ return nil , ErrInvalidResultMode
104
+ }
57
105
resultMode = rmode
58
106
}
59
107
if ! isSelect {
@@ -91,7 +139,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
91
139
afterDownload = c .dropCTASTable (ctx , ctasTable )
92
140
}
93
141
94
- queryID , err := c .startQuery (query )
142
+ queryID , err := c .startQuery (ctx , query )
95
143
if err != nil {
96
144
return nil , err
97
145
}
@@ -105,7 +153,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
105
153
QueryID : queryID ,
106
154
SkipHeader : ! isDDLQuery (query ),
107
155
ResultMode : resultMode ,
108
- Session : c . session ,
156
+ Config : c . config ,
109
157
OutputLocation : c .OutputLocation ,
110
158
Timeout : timeout ,
111
159
AfterDownload : afterDownload ,
@@ -119,7 +167,7 @@ func (c *conn) dropCTASTable(ctx context.Context, table string) func() error {
119
167
return func () error {
120
168
query := fmt .Sprintf ("DROP TABLE %s" , table )
121
169
122
- queryID , err := c .startQuery (query )
170
+ queryID , err := c .startQuery (ctx , query )
123
171
if err != nil {
124
172
return err
125
173
}
@@ -129,13 +177,13 @@ func (c *conn) dropCTASTable(ctx context.Context, table string) func() error {
129
177
}
130
178
131
179
// startQuery starts an Athena query and returns its ID.
132
- func (c * conn ) startQuery (query string ) (string , error ) {
133
- resp , err := c .athena .StartQueryExecution (& athena.StartQueryExecutionInput {
180
+ func (c * conn ) startQuery (ctx context. Context , query string ) (string , error ) {
181
+ resp , err := c .athena .StartQueryExecution (ctx , & athena.StartQueryExecutionInput {
134
182
QueryString : aws .String (query ),
135
- QueryExecutionContext : & athena .QueryExecutionContext {
183
+ QueryExecutionContext : & types .QueryExecutionContext {
136
184
Database : aws .String (c .db ),
137
185
},
138
- ResultConfiguration : & athena .ResultConfiguration {
186
+ ResultConfiguration : & types .ResultConfiguration {
139
187
OutputLocation : aws .String (c .OutputLocation ),
140
188
},
141
189
WorkGroup : aws .String (c .workgroup ),
@@ -150,28 +198,28 @@ func (c *conn) startQuery(query string) (string, error) {
150
198
// waitOnQuery blocks until a query finishes, returning an error if it failed.
151
199
func (c * conn ) waitOnQuery (ctx context.Context , queryID string ) error {
152
200
for {
153
- statusResp , err := c .athena .GetQueryExecutionWithContext (ctx , & athena.GetQueryExecutionInput {
201
+ statusResp , err := c .athena .GetQueryExecution (ctx , & athena.GetQueryExecutionInput {
154
202
QueryExecutionId : aws .String (queryID ),
155
203
})
156
204
if err != nil {
157
205
return err
158
206
}
159
207
160
- switch * statusResp .QueryExecution .Status .State {
161
- case athena .QueryExecutionStateCancelled :
208
+ switch statusResp .QueryExecution .Status .State {
209
+ case types .QueryExecutionStateCancelled :
162
210
return context .Canceled
163
- case athena .QueryExecutionStateFailed :
211
+ case types .QueryExecutionStateFailed :
164
212
reason := * statusResp .QueryExecution .Status .StateChangeReason
165
213
return errors .New (reason )
166
- case athena .QueryExecutionStateSucceeded :
214
+ case types .QueryExecutionStateSucceeded :
167
215
return nil
168
- case athena .QueryExecutionStateQueued :
169
- case athena .QueryExecutionStateRunning :
216
+ case types .QueryExecutionStateQueued :
217
+ case types .QueryExecutionStateRunning :
170
218
}
171
219
172
220
select {
173
221
case <- ctx .Done ():
174
- c .athena .StopQueryExecution (& athena.StopQueryExecutionInput {
222
+ c .athena .StopQueryExecution (ctx , & athena.StopQueryExecutionInput {
175
223
QueryExecutionId : aws .String (queryID ),
176
224
})
177
225
@@ -229,7 +277,7 @@ func (c *conn) prepareContext(ctx context.Context, query string) (driver.Stmt, e
229
277
prepareKey := fmt .Sprintf ("tmp_prepare_%v" , strings .Replace (uuid .NewV4 ().String (), "-" , "" , - 1 ))
230
278
newQuery := fmt .Sprintf ("PREPARE %s FROM %s" , prepareKey , query )
231
279
232
- queryID , err := c .startQuery (newQuery )
280
+ queryID , err := c .startQuery (ctx , newQuery )
233
281
if err != nil {
234
282
return nil , err
235
283
}
@@ -273,22 +321,16 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
273
321
var _ driver.Queryer = (* conn )(nil )
274
322
var _ driver.Execer = (* conn )(nil )
275
323
276
- // supported DDL statements by Athena
277
- // https://docs.aws.amazon.com/athena/latest/ug/language-reference.html
278
- var ddlQueryRegex = regexp .MustCompile (`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)` )
279
-
280
- func isDDLQuery (query string ) bool {
281
- return ddlQueryRegex .Match ([]byte (query ))
282
- }
283
-
284
- func isSelectQuery (query string ) bool {
285
- return regexp .MustCompile (`(?i)^SELECT` ).Match ([]byte (query ))
286
- }
287
-
288
- func isCTASQuery (query string ) bool {
289
- return regexp .MustCompile (`(?i)^CREATE.+AS\s+SELECT` ).Match ([]byte (query ))
290
- }
291
-
292
324
func isCreatingCTASTable (isSelect bool , resultMode ResultMode ) bool {
293
325
return isSelect && resultMode == ResultModeGzipDL
294
326
}
327
+
328
+ // isValidResultMode checks if the given result mode is valid
329
+ func isValidResultMode (mode ResultMode ) bool {
330
+ switch mode {
331
+ case ResultModeAPI , ResultModeDL , ResultModeGzipDL :
332
+ return true
333
+ default :
334
+ return false
335
+ }
336
+ }
0 commit comments