Skip to content

Commit b5bff4c

Browse files
authored
Merge pull request #25 from muroon/aws-sdk-go-v2
Upgrade to aws-sdk-go-v2
2 parents d30758b + 8ca7ddd commit b5bff4c

15 files changed

+438
-262
lines changed

README.md

+13
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,19 @@ Note
6060
- Detailed explanation is described [here](doc/result_mode.md).
6161
- [Usages of Result Mode](doc/result_mode.md#usages).
6262

63+
## Performance
64+
65+
Response time for fetching 1,000,000 records
66+
67+
|package|response time|
68+
|--|--|
69+
|segmentio/go-athena|2m 33.4132205s|
70+
|speee/go-athena API mode|2m 26.475804292s|
71+
|speee/go-athena DL mode|20.719727417s|
72+
|speee/go-athena GZIP mode|17.661648209s|
73+
74+
Detailed explanation is described [here](doc/result_mode.md#response-time-for-each-mode).
75+
6376
## Prepared Statements
6477

6578
You can use [Athena Prepared Statements](https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html).

conn.go

+80-38
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,67 @@ import (
1111

1212
uuid "github.com/satori/go.uuid"
1313

14-
"github.com/aws/aws-sdk-go/aws"
15-
"github.com/aws/aws-sdk-go/aws/session"
16-
"github.com/aws/aws-sdk-go/service/athena"
17-
"github.com/aws/aws-sdk-go/service/athena/athenaiface"
14+
"github.com/aws/aws-sdk-go-v2/aws"
15+
"github.com/aws/aws-sdk-go-v2/service/athena"
16+
"github.com/aws/aws-sdk-go-v2/service/athena/types"
1817
)
1918

19+
// Query type patterns
20+
var (
21+
ddlQueryPattern = regexp.MustCompile(`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)`)
22+
selectQueryPattern = regexp.MustCompile(`(?i)^SELECT`)
23+
ctasQueryPattern = regexp.MustCompile(`(?i)^CREATE.+AS\s+SELECT`)
24+
)
25+
26+
// queryType represents the type of SQL query
27+
type queryType int
28+
29+
const (
30+
queryTypeUnknown queryType = iota
31+
queryTypeDDL
32+
queryTypeSelect
33+
queryTypeCTAS
34+
)
35+
36+
// getQueryType determines the type of the query
37+
func getQueryType(query string) queryType {
38+
switch {
39+
case ddlQueryPattern.MatchString(query):
40+
return queryTypeDDL
41+
case ctasQueryPattern.MatchString(query):
42+
return queryTypeCTAS
43+
case selectQueryPattern.MatchString(query):
44+
return queryTypeSelect
45+
default:
46+
return queryTypeUnknown
47+
}
48+
}
49+
50+
// isDDLQuery determines if the query is a DDL statement
51+
func isDDLQuery(query string) bool {
52+
return getQueryType(query) == queryTypeDDL
53+
}
54+
55+
// isSelectQuery determines if the query is a SELECT statement
56+
func isSelectQuery(query string) bool {
57+
return getQueryType(query) == queryTypeSelect
58+
}
59+
60+
// isCTASQuery determines if the query is a CREATE TABLE AS SELECT statement
61+
func isCTASQuery(query string) bool {
62+
return getQueryType(query) == queryTypeCTAS
63+
}
64+
2065
type conn struct {
21-
athena athenaiface.AthenaAPI
66+
athena *athena.Client
2267
db string
2368
OutputLocation string
2469
workgroup string
2570

2671
pollFrequency time.Duration
2772

2873
resultMode ResultMode
29-
session *session.Session
74+
config aws.Config
3075
timeout uint
3176
catalog string
3277
}
@@ -54,6 +99,9 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
5499
isSelect := isSelectQuery(query)
55100
resultMode := c.resultMode
56101
if rmode, ok := getResultMode(ctx); ok {
102+
if !isValidResultMode(rmode) {
103+
return nil, ErrInvalidResultMode
104+
}
57105
resultMode = rmode
58106
}
59107
if !isSelect {
@@ -91,7 +139,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
91139
afterDownload = c.dropCTASTable(ctx, ctasTable)
92140
}
93141

94-
queryID, err := c.startQuery(query)
142+
queryID, err := c.startQuery(ctx, query)
95143
if err != nil {
96144
return nil, err
97145
}
@@ -105,7 +153,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
105153
QueryID: queryID,
106154
SkipHeader: !isDDLQuery(query),
107155
ResultMode: resultMode,
108-
Session: c.session,
156+
Config: c.config,
109157
OutputLocation: c.OutputLocation,
110158
Timeout: timeout,
111159
AfterDownload: afterDownload,
@@ -119,7 +167,7 @@ func (c *conn) dropCTASTable(ctx context.Context, table string) func() error {
119167
return func() error {
120168
query := fmt.Sprintf("DROP TABLE %s", table)
121169

122-
queryID, err := c.startQuery(query)
170+
queryID, err := c.startQuery(ctx, query)
123171
if err != nil {
124172
return err
125173
}
@@ -129,13 +177,13 @@ func (c *conn) dropCTASTable(ctx context.Context, table string) func() error {
129177
}
130178

131179
// startQuery starts an Athena query and returns its ID.
132-
func (c *conn) startQuery(query string) (string, error) {
133-
resp, err := c.athena.StartQueryExecution(&athena.StartQueryExecutionInput{
180+
func (c *conn) startQuery(ctx context.Context, query string) (string, error) {
181+
resp, err := c.athena.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{
134182
QueryString: aws.String(query),
135-
QueryExecutionContext: &athena.QueryExecutionContext{
183+
QueryExecutionContext: &types.QueryExecutionContext{
136184
Database: aws.String(c.db),
137185
},
138-
ResultConfiguration: &athena.ResultConfiguration{
186+
ResultConfiguration: &types.ResultConfiguration{
139187
OutputLocation: aws.String(c.OutputLocation),
140188
},
141189
WorkGroup: aws.String(c.workgroup),
@@ -150,28 +198,28 @@ func (c *conn) startQuery(query string) (string, error) {
150198
// waitOnQuery blocks until a query finishes, returning an error if it failed.
151199
func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
152200
for {
153-
statusResp, err := c.athena.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
201+
statusResp, err := c.athena.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{
154202
QueryExecutionId: aws.String(queryID),
155203
})
156204
if err != nil {
157205
return err
158206
}
159207

160-
switch *statusResp.QueryExecution.Status.State {
161-
case athena.QueryExecutionStateCancelled:
208+
switch statusResp.QueryExecution.Status.State {
209+
case types.QueryExecutionStateCancelled:
162210
return context.Canceled
163-
case athena.QueryExecutionStateFailed:
211+
case types.QueryExecutionStateFailed:
164212
reason := *statusResp.QueryExecution.Status.StateChangeReason
165213
return errors.New(reason)
166-
case athena.QueryExecutionStateSucceeded:
214+
case types.QueryExecutionStateSucceeded:
167215
return nil
168-
case athena.QueryExecutionStateQueued:
169-
case athena.QueryExecutionStateRunning:
216+
case types.QueryExecutionStateQueued:
217+
case types.QueryExecutionStateRunning:
170218
}
171219

172220
select {
173221
case <-ctx.Done():
174-
c.athena.StopQueryExecution(&athena.StopQueryExecutionInput{
222+
c.athena.StopQueryExecution(ctx, &athena.StopQueryExecutionInput{
175223
QueryExecutionId: aws.String(queryID),
176224
})
177225

@@ -229,7 +277,7 @@ func (c *conn) prepareContext(ctx context.Context, query string) (driver.Stmt, e
229277
prepareKey := fmt.Sprintf("tmp_prepare_%v", strings.Replace(uuid.NewV4().String(), "-", "", -1))
230278
newQuery := fmt.Sprintf("PREPARE %s FROM %s", prepareKey, query)
231279

232-
queryID, err := c.startQuery(newQuery)
280+
queryID, err := c.startQuery(ctx, newQuery)
233281
if err != nil {
234282
return nil, err
235283
}
@@ -273,22 +321,16 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
273321
var _ driver.Queryer = (*conn)(nil)
274322
var _ driver.Execer = (*conn)(nil)
275323

276-
// supported DDL statements by Athena
277-
// https://docs.aws.amazon.com/athena/latest/ug/language-reference.html
278-
var ddlQueryRegex = regexp.MustCompile(`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)`)
279-
280-
func isDDLQuery(query string) bool {
281-
return ddlQueryRegex.Match([]byte(query))
282-
}
283-
284-
func isSelectQuery(query string) bool {
285-
return regexp.MustCompile(`(?i)^SELECT`).Match([]byte(query))
286-
}
287-
288-
func isCTASQuery(query string) bool {
289-
return regexp.MustCompile(`(?i)^CREATE.+AS\s+SELECT`).Match([]byte(query))
290-
}
291-
292324
func isCreatingCTASTable(isSelect bool, resultMode ResultMode) bool {
293325
return isSelect && resultMode == ResultModeGzipDL
294326
}
327+
328+
// isValidResultMode checks if the given result mode is valid
329+
func isValidResultMode(mode ResultMode) bool {
330+
switch mode {
331+
case ResultModeAPI, ResultModeDL, ResultModeGzipDL:
332+
return true
333+
default:
334+
return false
335+
}
336+
}

context.go

+48-43
Original file line numberDiff line numberDiff line change
@@ -2,71 +2,76 @@ package athena
22

33
import "context"
44

5-
const contextPrefix string = "go-athena"
6-
7-
/*
8-
* Result Mode
9-
*/
10-
11-
const resultModeContextKey string = "result_mode_key"
5+
// contextKey is a type for context keys to ensure type safety
6+
type contextKey string
7+
8+
const contextPrefix = "go-athena"
9+
10+
// Context keys
11+
const (
12+
resultModeKey contextKey = contextKey(contextPrefix + "result_mode_key")
13+
timeoutKey contextKey = contextKey(contextPrefix + "timeout_key")
14+
catalogKey contextKey = contextKey(contextPrefix + "catalog_key")
15+
)
16+
17+
// ResultModeContextKey is deprecated, use resultModeKey instead
18+
var ResultModeContextKey = string(resultModeKey)
19+
20+
// TimeoutContextKey is deprecated, use timeoutKey instead
21+
var TimeoutContextKey = string(timeoutKey)
22+
23+
// CatalogContextKey is deprecated, use catalogKey instead
24+
var CatalogContextKey = string(catalogKey)
25+
26+
// contextValue safely retrieves a typed value from context
27+
func contextValue[T any](ctx context.Context, key contextKey) (T, bool) {
28+
v := ctx.Value(key)
29+
if v == nil {
30+
var zero T
31+
return zero, false
32+
}
33+
val, ok := v.(T)
34+
return val, ok
35+
}
1236

13-
// ResultModeContextKey context key of setting result mode
14-
var ResultModeContextKey string = contextPrefix + resultModeContextKey
37+
// SetResultMode sets the ResultMode in context
38+
func SetResultMode(ctx context.Context, mode ResultMode) context.Context {
39+
return context.WithValue(ctx, resultModeKey, mode)
40+
}
1541

16-
// SetAPIMode set APIMode to ResultMode from context
42+
// SetAPIMode sets APIMode to ResultMode in context
1743
func SetAPIMode(ctx context.Context) context.Context {
18-
return context.WithValue(ctx, ResultModeContextKey, ResultModeAPI)
44+
return SetResultMode(ctx, ResultModeAPI)
1945
}
2046

21-
// SetDLMode set DownloadMode to ResultMode from context
47+
// SetDLMode sets DownloadMode to ResultMode in context
2248
func SetDLMode(ctx context.Context) context.Context {
23-
return context.WithValue(ctx, ResultModeContextKey, ResultModeDL)
49+
return SetResultMode(ctx, ResultModeDL)
2450
}
2551

26-
// SetGzipDLMode set CTASMode to ResultMode from context
52+
// SetGzipDLMode sets GzipDLMode to ResultMode in context
2753
func SetGzipDLMode(ctx context.Context) context.Context {
28-
return context.WithValue(ctx, ResultModeContextKey, ResultModeGzipDL)
54+
return SetResultMode(ctx, ResultModeGzipDL)
2955
}
3056

3157
func getResultMode(ctx context.Context) (ResultMode, bool) {
32-
val, ok := ctx.Value(ResultModeContextKey).(ResultMode)
33-
return val, ok
58+
return contextValue[ResultMode](ctx, resultModeKey)
3459
}
3560

36-
/*
37-
* timeout
38-
*/
39-
40-
const timeoutContextKey string = "timeout_key"
41-
42-
// TimeoutContextKey context key of setting timeout
43-
var TimeoutContextKey string = contextPrefix + timeoutContextKey
44-
45-
// SetTimeout set timeout from context
61+
// SetTimeout sets timeout in context
4662
func SetTimeout(ctx context.Context, timeout uint) context.Context {
47-
return context.WithValue(ctx, TimeoutContextKey, timeout)
63+
return context.WithValue(ctx, timeoutKey, timeout)
4864
}
4965

5066
func getTimeout(ctx context.Context) (uint, bool) {
51-
val, ok := ctx.Value(TimeoutContextKey).(uint)
52-
return val, ok
67+
return contextValue[uint](ctx, timeoutKey)
5368
}
5469

55-
/*
56-
* catalog
57-
*/
58-
59-
const catalogContextKey string = "catalog_key"
60-
61-
// CatalogContextKey context key of setting catalog
62-
var CatalogContextKey string = contextPrefix + catalogContextKey
63-
64-
// SetCatalog set catalog from context
70+
// SetCatalog sets catalog in context
6571
func SetCatalog(ctx context.Context, catalog string) context.Context {
66-
return context.WithValue(ctx, CatalogContextKey, catalog)
72+
return context.WithValue(ctx, catalogKey, catalog)
6773
}
6874

6975
func getCatalog(ctx context.Context) (string, bool) {
70-
val, ok := ctx.Value(CatalogContextKey).(string)
71-
return val, ok
76+
return contextValue[string](ctx, catalogKey)
7277
}

0 commit comments

Comments
 (0)