Skip to content

Commit 79b682d

Browse files
authored
Merge pull request #9 from danimal141/feature/support-athena-workgroup
Feature/support athena workgroup
2 parents 8bd8e4e + 119bf1c commit 79b682d

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

conn.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ type conn struct {
1616
athena athenaiface.AthenaAPI
1717
db string
1818
OutputLocation string
19+
workgroup string
1920

2021
pollFrequency time.Duration
2122
}
@@ -65,6 +66,7 @@ func (c *conn) startQuery(query string) (string, error) {
6566
ResultConfiguration: &athena.ResultConfiguration{
6667
OutputLocation: aws.String(c.OutputLocation),
6768
},
69+
WorkGroup: aws.String(c.workgroup),
6870
})
6971
if err != nil {
7072
return "", err

driver.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ func init() {
5858
// - `region` (optional)
5959
// Override AWS region. Useful if it is not set with environment variable.
6060
//
61+
// - `workgroup` (optional)
62+
// Athena's workgroup. This defaults to "primary".
63+
//
6164
// Credentials must be accessible via the SDK's Default Credential Provider Chain.
6265
// For more advanced AWS credentials/session/config management, please supply
6366
// a custom AWS session directly via `athena.Open()`.
@@ -80,6 +83,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
8083
db: cfg.Database,
8184
OutputLocation: cfg.OutputLocation,
8285
pollFrequency: cfg.PollFrequency,
86+
workgroup: cfg.WorkGroup,
8387
}, nil
8488
}
8589

@@ -99,6 +103,10 @@ func Open(cfg Config) (*sql.DB, error) {
99103
return nil, errors.New("session is required")
100104
}
101105

106+
if cfg.WorkGroup == "" {
107+
cfg.WorkGroup = "primary"
108+
}
109+
102110
// This hack was copied from jackc/pgx. Sorry :(
103111
// https://github.com/jackc/pgx/blob/70a284f4f33a9cc28fd1223f6b83fb00deecfe33/stdlib/sql.go#L130-L136
104112
openFromSessionMutex.Lock()
@@ -115,6 +123,7 @@ type Config struct {
115123
Session *session.Session
116124
Database string
117125
OutputLocation string
126+
WorkGroup string
118127

119128
PollFrequency time.Duration
120129
}
@@ -138,6 +147,10 @@ func configFromConnectionString(connStr string) (*Config, error) {
138147

139148
cfg.Database = args.Get("db")
140149
cfg.OutputLocation = args.Get("output_location")
150+
cfg.WorkGroup = args.Get("workgroup")
151+
if cfg.WorkGroup == "" {
152+
cfg.WorkGroup = "primary"
153+
}
141154

142155
frequencyStr := args.Get("poll_frequency")
143156
if frequencyStr != "" {

0 commit comments

Comments
 (0)