-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtx.go
172 lines (143 loc) · 4.6 KB
/
tx.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
/*
Copyright © 2024 Acronis International GmbH.
Released under MIT license.
*/
package goquutil
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/acronis/go-appkit/httpserver/middleware"
golibslog "github.com/acronis/go-appkit/log"
"github.com/doug-martin/goqu/v9"
)
// Query hooks let callers observe or adjust every statement executed through
// the transaction Querier that DoInTx hands to its worker. Both hooks are nil
// (disabled) by default and are consulted on each Exec/Query/QueryRow call.
type (
	// PreQueryFuncT is the signature of PreQueryHook: it receives the query
	// context, the SQL text and its arguments, and returns the SQL text that
	// will actually be executed.
	PreQueryFuncT func(ctx context.Context, query string, args ...interface{}) string

	// PostQueryFuncT is the signature of PostQueryHook: it receives the query
	// context, the (UTC) time the query was started, the error reported for
	// the call, and the SQL text and arguments that were executed.
	PostQueryFuncT func(ctx context.Context, startedAt time.Time, err error, query string, args ...interface{})
)

var (
	// PreQueryHook, when non-nil, runs before each query and may rewrite it.
	PreQueryHook PreQueryFuncT
	// PostQueryHook, when non-nil, runs after each query has been executed.
	PostQueryHook PostQueryFuncT
)
// ContextProvider is implemented by values that carry the context.Context of
// the operation they belong to. Context returns that context so code holding
// the value can propagate cancellation, deadlines and request-scoped data.
type ContextProvider interface {
	// Context returns the context associated with the current operation.
	Context() context.Context
}
// cancellableTxQuerier is the Querier handed to DoInTx workers. Every
// statement goes through the *Context variants of tx's methods using ctx, and
// the package-level query hooks are invoked around each call.
type cancellableTxQuerier struct {
	ctx context.Context  // context passed to every ExecContext/QueryContext/QueryRowContext call
	tx  *goqu.TxDatabase // open transaction the statements run in
}

// newCancellableTxQuerier binds tx to ctx and exposes the pair as a Querier.
func newCancellableTxQuerier(ctx context.Context, tx *goqu.TxDatabase) Querier {
	querier := new(cancellableTxQuerier)
	querier.ctx, querier.tx = ctx, tx
	return querier
}
// Exec runs query inside the transaction via tx.ExecContext, using the
// querier's context. A non-nil PreQueryHook may replace the query text first;
// a non-nil PostQueryHook is then given the UTC start time and the error
// returned by the call.
func (q *cancellableTxQuerier) Exec(query string, args ...interface{}) (sql.Result, error) {
	if hook := PreQueryHook; hook != nil {
		query = hook(q.ctx, query, args...)
	}
	startedAt := time.Now().UTC()
	result, execErr := q.tx.ExecContext(q.ctx, query, args...)
	if hook := PostQueryHook; hook != nil {
		hook(q.ctx, startedAt, execErr, query, args...)
	}
	return result, execErr
}
// Query runs query inside the transaction via tx.QueryContext, using the
// querier's context, and returns the resulting rows (as with database/sql, the
// caller is expected to close them). A non-nil PreQueryHook may replace the
// query text first; a non-nil PostQueryHook is then given the UTC start time
// and the error returned by the call.
func (q *cancellableTxQuerier) Query(query string, args ...interface{}) (*sql.Rows, error) {
	if hook := PreQueryHook; hook != nil {
		query = hook(q.ctx, query, args...)
	}
	startedAt := time.Now().UTC()
	rows, queryErr := q.tx.QueryContext(q.ctx, query, args...)
	if hook := PostQueryHook; hook != nil {
		hook(q.ctx, startedAt, queryErr, query, args...)
	}
	return rows, queryErr
}
// QueryRow runs a single-row query inside the transaction via
// tx.QueryRowContext, using the querier's context. A non-nil PreQueryHook may
// replace the query text first; a non-nil PostQueryHook is then given the UTC
// start time and the query error.
//
// *sql.Row defers its error until Scan is called, so the error handed to
// PostQueryHook is obtained from Row.Err(). This keeps failed single-row
// queries visible to the hook, the same way Exec and Query report theirs.
//
// NOTE: Time.UTC drops the monotonic clock reading, so any duration a hook
// derives from startedAt is based on wall-clock time.
func (q *cancellableTxQuerier) QueryRow(query string, args ...interface{}) *sql.Row {
	if PreQueryHook != nil {
		query = PreQueryHook(q.ctx, query, args...)
	}
	start := time.Now().UTC()
	res := q.tx.QueryRowContext(q.ctx, query, args...)
	if PostQueryHook != nil {
		// Row.Err reports only the error from running the query; sql.ErrNoRows
		// is still surfaced solely by Scan, so "no row" is not treated as a failure.
		PostQueryHook(q.ctx, start, res.Err(), query, args...)
	}
	return res
}
// Context returns the context this querier runs its statements with, which
// also makes *cancellableTxQuerier satisfy ContextProvider.
func (q *cancellableTxQuerier) Context() context.Context { return q.ctx }
// DB is a wrapper for goqu.Database that runs work inside transactions via
// DoInTx. Create it with NewDB and optionally configure it with WithTxOpts and
// WithLogging.
type DB struct {
	db                          *goqu.Database        // database on which transactions are begun
	ctx                         context.Context       // used for BeginTx and passed to the Querier given to DoInTx workers
	txOpts                      *sql.TxOptions        // passed to BeginTx; set by WithTxOpts
	logger                      golibslog.FieldLogger // nil disables DoInTx timing logs; set by WithLogging
	loggingCtx                  string                // label embedded in the timing log messages
	loggingTimeThresholdBeginTx time.Duration         // BeginTx taking longer than this (in whole ms) is logged at Info instead of Debug
}
// NewDB wraps db so that transactions opened through the returned *DB are
// begun with ctx. Transaction options and logging stay unset until
// WithTxOpts / WithLogging are called.
func NewDB(ctx context.Context, db *goqu.Database) *DB {
	wrapped := &DB{
		ctx: ctx,
		db:  db,
	}
	return wrapped
}
// DoInTx begins a transaction with d's context and options, then calls worker
// with a Querier bound to that transaction and context. The transaction is
// finished through goqu's TxDatabase.Wrap, which, per goqu's documentation,
// commits when worker succeeds and rolls back when it returns an error.
// DoInTx returns the BeginTx error unchanged; otherwise it returns whatever
// Wrap returns.
//
// When a logger has been configured via WithLogging, DoInTx logs two things:
//   - how long BeginTx took, at Info level if that exceeds the configured
//     threshold (in whole milliseconds) and at Debug level otherwise;
//   - how long finishing the transaction took after worker returned, at
//     Debug level.
//
// If d.ctx carries request logging params, both durations are also recorded
// there as time slots.
func (d *DB) DoInTx(worker func(q Querier) error) error {
	mark := time.Now()
	tx, err := d.db.BeginTx(d.ctx, d.txOpts)
	if err != nil {
		return err
	}

	// recordTimeSlot attaches a millisecond duration to the logging params
	// carried by d.ctx, if there are any.
	recordTimeSlot := func(slot string, ms int64) {
		if d.ctx == nil {
			return
		}
		if params := middleware.GetLoggingParamsFromContext(d.ctx); params != nil {
			params.AddTimeSlotInt(slot, ms)
		}
	}

	if d.logger != nil {
		openMs := time.Since(mark).Milliseconds()
		level := golibslog.LevelDebug
		if openMs > d.loggingTimeThresholdBeginTx.Milliseconds() {
			// Slow acquisition of a connection/transaction is surfaced at Info.
			level = golibslog.LevelInfo
		}
		d.logger.AtLevel(level, func(logFunc golibslog.LogFunc) {
			logFunc(
				fmt.Sprintf("opened DB transaction (%s) in %dms", d.loggingCtx, openMs),
				golibslog.Int64("duration_ms", openMs),
			)
		})
		recordTimeSlot("open_db_transaction_ms", openMs)
	}

	err = tx.Wrap(func() error {
		workerErr := worker(newCancellableTxQuerier(d.ctx, tx))
		// Restart the clock so the "closed" measurement covers only what Wrap
		// does after this function returns, not the worker itself.
		mark = time.Now()
		return workerErr
	})

	if d.logger != nil {
		closeMs := time.Since(mark).Milliseconds()
		d.logger.Debug(
			fmt.Sprintf("closed DB transaction (%s) in %dms", d.loggingCtx, closeMs),
			golibslog.Int64("duration_ms", closeMs),
		)
		recordTimeSlot("closed_db_transaction_ms", closeMs)
	}
	return err
}
// WithTxOpts stores the *sql.TxOptions that DoInTx passes to BeginTx and
// returns d so configuration calls can be chained. d is modified in place;
// no copy is made.
func (d *DB) WithTxOpts(txOpts *sql.TxOptions) *DB {
	d.txOpts = txOpts
	return d
}
// WithLogging enables timing logs for DoInTx and returns d for chaining. Two
// durations are logged: the time spent opening the transaction (obtaining a
// DB connection from the pool) and the time spent closing it after the worker
// returns.
// logger receives the messages, and loggingCtx is embedded in them to
// identify the caller. loggingTimeThresholdBeginTx is the BeginTx duration,
// compared in whole milliseconds, above which the "opened" message is raised
// from Debug to Info. d is modified in place.
func (d *DB) WithLogging(logger golibslog.FieldLogger, loggingCtx string, loggingTimeThresholdBeginTx time.Duration) *DB {
	d.logger, d.loggingCtx, d.loggingTimeThresholdBeginTx = logger, loggingCtx, loggingTimeThresholdBeginTx
	return d
}