Skip to content

Commit 0388381

Browse files
middleware/prefix: add middleware to prefix all keys
1 parent ca97e5c commit 0388381

File tree

7 files changed

+391
-99
lines changed

7 files changed

+391
-99
lines changed

backend/db.go

+4-97
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@ package backend
33
import (
44
"context"
55
"errors"
6-
"fmt"
7-
"reflect"
86

97
goredis "github.com/redis/go-redis/v9"
108

119
"github.com/upfluence/redis"
10+
internal "github.com/upfluence/redis/internal/scanner"
1211
)
1312

1413
type db struct {
@@ -37,10 +36,11 @@ func (d *db) Do(ctx context.Context, cmd string, vs ...interface{}) redis.Scanne
3736

3837
if vv, ok := v.(redis.Valuer); ok {
3938
var err error
39+
4040
rv, err = vv.Value()
4141

4242
if err != nil {
43-
return errScanner{err: err}
43+
return &internal.ErrScanner{Err: err}
4444
}
4545
}
4646

@@ -69,98 +69,5 @@ func (s *scanner) Scan(vs ...interface{}) error {
6969
return err
7070
}
7171

72-
switch ssrc := src.(type) {
73-
case []any:
74-
if len(vs) == 1 {
75-
if sc, ok := vs[0].(redis.ValueScanner); ok {
76-
return sc.Scan(src)
77-
}
78-
79-
rv := reflect.ValueOf(vs[0])
80-
81-
if rv.Kind() != reflect.Pointer {
82-
return errNilPtr
83-
}
84-
85-
rve := rv.Elem()
86-
87-
if rve.Kind() == reflect.Slice {
88-
for _, src := range ssrc {
89-
rv := reflect.New(rve.Type().Elem())
90-
91-
if err := convertAssign(rv.Interface(), src); err != nil {
92-
return err
93-
}
94-
95-
rve = reflect.Append(rve, rv.Elem())
96-
}
97-
98-
rv.Elem().Set(rve)
99-
100-
return nil
101-
}
102-
103-
}
104-
105-
if len(vs) != len(ssrc) {
106-
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into multiple values", src)
107-
}
108-
109-
for i, dst := range vs {
110-
if err := convertAssign(dst, ssrc[i]); err != nil {
111-
return err
112-
}
113-
}
114-
115-
return nil
116-
case map[any]any:
117-
if len(vs) > 1 {
118-
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into multiple values", src)
119-
}
120-
121-
if sc, ok := vs[0].(redis.ValueScanner); ok {
122-
return sc.Scan(src)
123-
}
124-
125-
rv := reflect.ValueOf(vs[0])
126-
127-
if rv.Kind() != reflect.Pointer {
128-
return errNilPtr
129-
}
130-
131-
rv = rv.Elem()
132-
133-
if rv.Kind() != reflect.Map {
134-
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into %T", src, vs[0])
135-
}
136-
137-
for k, v := range ssrc {
138-
rk := reflect.New(rv.Type().Key())
139-
re := reflect.New(rv.Type().Elem())
140-
141-
if err := convertAssign(rk.Interface(), k); err != nil {
142-
return err
143-
}
144-
145-
if err := convertAssign(re.Interface(), v); err != nil {
146-
return err
147-
}
148-
149-
rv.SetMapIndex(rk.Elem(), re.Elem())
150-
}
151-
152-
return nil
153-
}
154-
155-
if len(vs) > 1 {
156-
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into multiple values", src)
157-
}
158-
159-
return convertAssign(vs[0], src)
72+
return internal.Assign(src, vs)
16073
}
161-
162-
type errScanner struct {
163-
err error
164-
}
165-
166-
func (es errScanner) Scan(_ ...interface{}) error { return es.err }

backend/db_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func TestIntegration(t *testing.T) {
5151
assert.Equal(t, "1", foo)
5252

5353
sfoo = nil
54-
err = db.Do(ctx, "ZMPOP", 1, "foob", "MIN").Scan(&sfoo)
54+
err = db.Do(ctx, "ZRANDMEMBER", "foob").Scan(&sfoo)
5555

5656
require.ErrorIs(t, err, redis.Empty)
5757
assert.Len(t, sfoo, 0)

backend/convert.go internal/scanner/convert.go

+104-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package backend
1+
package scanner
22

33
import (
44
"bytes"
@@ -236,3 +236,106 @@ func asBytes(rv reflect.Value) ([]byte, bool) {
236236

237237
return buf, false
238238
}
239+
240+
func Assign(src any, dsts []any) error {
241+
switch ssrc := src.(type) {
242+
case []any:
243+
if len(dsts) == 1 {
244+
if sc, ok := dsts[0].(redis.ValueScanner); ok {
245+
return sc.Scan(src)
246+
}
247+
248+
rv := reflect.ValueOf(dsts[0])
249+
250+
if rv.Kind() != reflect.Pointer {
251+
return errNilPtr
252+
}
253+
254+
rve := rv.Elem()
255+
256+
if rve.Kind() == reflect.Slice {
257+
for _, src := range ssrc {
258+
rv := reflect.New(rve.Type().Elem())
259+
260+
if err := convertAssign(rv.Interface(), src); err != nil {
261+
return err
262+
}
263+
264+
rve = reflect.Append(rve, rv.Elem())
265+
}
266+
267+
rv.Elem().Set(rve)
268+
269+
return nil
270+
}
271+
272+
}
273+
274+
if len(dsts) != len(ssrc) {
275+
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into multiple values", src)
276+
}
277+
278+
for i, dst := range dsts {
279+
if err := convertAssign(dst, ssrc[i]); err != nil {
280+
return err
281+
}
282+
}
283+
284+
return nil
285+
case map[any]any:
286+
if len(dsts) > 1 {
287+
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into multiple values", src)
288+
}
289+
290+
if sc, ok := dsts[0].(redis.ValueScanner); ok {
291+
return sc.Scan(src)
292+
}
293+
294+
rv := reflect.ValueOf(dsts[0])
295+
296+
if rv.Kind() != reflect.Pointer {
297+
return errNilPtr
298+
}
299+
300+
rv = rv.Elem()
301+
302+
if rv.Kind() != reflect.Map {
303+
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into %T", src, dsts[0])
304+
}
305+
306+
for k, v := range ssrc {
307+
rk := reflect.New(rv.Type().Key())
308+
re := reflect.New(rv.Type().Elem())
309+
310+
if err := convertAssign(rk.Interface(), k); err != nil {
311+
return err
312+
}
313+
314+
if err := convertAssign(re.Interface(), v); err != nil {
315+
return err
316+
}
317+
318+
rv.SetMapIndex(rk.Elem(), re.Elem())
319+
}
320+
321+
return nil
322+
}
323+
324+
if len(dsts) > 1 {
325+
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into multiple values", src)
326+
}
327+
328+
return convertAssign(dsts[0], src)
329+
}
330+
331+
type StaticScanner struct {
332+
Val any
333+
}
334+
335+
func (ss *StaticScanner) Scan(vs ...interface{}) error {
336+
if len(vs) == 0 {
337+
return nil
338+
}
339+
340+
return Assign(ss.Val, vs)
341+
}

internal/scanner/err_scanner.go

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package scanner
2+
3+
type ErrScanner struct {
4+
Err error
5+
}
6+
7+
func (es *ErrScanner) Scan(_ ...interface{}) error { return es.Err }

middleware/prefix/db.go

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package prefix
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
8+
"github.com/upfluence/redis"
9+
"github.com/upfluence/redis/internal/scanner"
10+
)
11+
12+
type Option func(*factory)
13+
14+
func WithExecutor(cmd string, exc Executor) Option {
15+
return func(mf *factory) { mf.vs[cmd] = exc }
16+
}
17+
18+
type factory struct {
19+
prefix string
20+
21+
vs map[string]Executor
22+
}
23+
24+
func NewFactory(prefix string, opts ...Option) redis.MiddlewareFactory {
25+
if prefix != "" && !strings.HasSuffix(prefix, ":") {
26+
prefix = prefix + ":"
27+
}
28+
29+
f := factory{prefix: prefix, vs: allExecutors}
30+
31+
for _, opt := range opts {
32+
opt(&f)
33+
}
34+
35+
return &f
36+
}
37+
38+
func (f *factory) Wrap(next redis.DB) redis.DB {
39+
if f.prefix == "" {
40+
return next
41+
}
42+
43+
return &db{next: next, prefix: f.prefix, vs: f.vs}
44+
}
45+
46+
type db struct {
47+
next redis.DB
48+
prefix string
49+
vs map[string]Executor
50+
}
51+
52+
func (db *db) Close() error { return db.next.Close() }
53+
54+
func (db *db) Do(ctx context.Context, cmd string, vs ...interface{}) redis.Scanner {
55+
e, ok := db.vs[cmd]
56+
57+
if !ok {
58+
return &scanner.ErrScanner{
59+
Err: fmt.Errorf("prefix wrapping for cmd %q not implemented", cmd),
60+
}
61+
}
62+
63+
return e.Execute(ctx, db.next, db.prefix, cmd, vs)
64+
}

0 commit comments

Comments
 (0)