Skip to content

Commit 4ff2231

Browse files
committed
serializer factory + rm unpublicly content
1 parent bf16bab commit 4ff2231

File tree

2 files changed

+67
-23
lines changed

2 files changed

+67
-23
lines changed

thrift/sql/serializing.go

+16-11
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,8 @@ import (
1010
"github.com/upfluence/pkg/thrift/thriftutil"
1111
)
1212

13-
var (
14-
binaryThriftSerializer = serializer.NewTSerializer(
15-
thriftutil.BinaryProtocolFactory,
16-
)
17-
18-
binaryThriftDeserializer = serializer.NewTDeserializer(
19-
thriftutil.BinaryProtocolFactory,
20-
)
13+
var defaultSerializerFactory = serializer.NewTSerializerFactory(
14+
thriftutil.BinaryProtocolFactory,
2115
)
2216

2317
type TStructPtr[T any] interface {
@@ -26,7 +20,16 @@ type TStructPtr[T any] interface {
2620
}
2721

2822
type NullThrift[T any, PT TStructPtr[T]] struct {
29-
Data PT
23+
Data PT
24+
SerializerFactory *serializer.TSerializerFactory
25+
}
26+
27+
func (t NullThrift[T, PT]) serializerFactory() *serializer.TSerializerFactory {
28+
if t.SerializerFactory == nil {
29+
return defaultSerializerFactory
30+
}
31+
32+
return t.SerializerFactory
3033
}
3134

3235
func (t *NullThrift[T, PT]) Scan(value any) error {
@@ -46,15 +49,17 @@ func (t *NullThrift[T, PT]) Scan(value any) error {
4649
t.Data = new(T)
4750
}
4851

49-
return errors.WithStack(binaryThriftDeserializer.Read(t.Data, data))
52+
return errors.WithStack(
53+
t.serializerFactory().GetDeserializer().Read(t.Data, data),
54+
)
5055
}
5156

5257
func (t NullThrift[T, PT]) Value() (driver.Value, error) {
5358
if t.Data == nil {
5459
return nil, nil // nolint:nilnil
5560
}
5661

57-
data, err := binaryThriftSerializer.Write(t.Data)
62+
data, err := t.serializerFactory().GetSerializer().Write(t.Data)
5863

5964
return data, errors.WithStack(err)
6065
}

thrift/sql/serializing_test.go

+51-12
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,76 @@
11
package sql
22

33
import (
4+
"fmt"
45
"testing"
56

67
"github.com/stretchr/testify/assert"
78
"github.com/stretchr/testify/require"
8-
"github.com/upfluence/base/provider/credential"
9+
"github.com/upfluence/errors"
10+
"github.com/upfluence/thrift/lib/go/thrift"
11+
12+
"github.com/upfluence/pkg/thrift/serializer"
13+
"github.com/upfluence/pkg/thrift/thriftutil"
914
)
1015

11-
func TestNullableThrift(t *testing.T) {
16+
type fakeTStruct struct {
17+
value int64
18+
}
19+
20+
func (t fakeTStruct) Write(p thrift.TProtocol) error {
21+
return errors.WithStack(p.WriteI64(t.value))
22+
}
23+
24+
func (t *fakeTStruct) Read(p thrift.TProtocol) error {
25+
val, err := p.ReadI64()
26+
27+
if err != nil {
28+
return errors.WithStack(err)
29+
}
30+
31+
t.value = val
32+
33+
return nil
34+
}
35+
36+
func (t *fakeTStruct) String() string {
37+
return fmt.Sprint(t.value)
38+
}
39+
40+
func TestNullableThrift_Scan(t *testing.T) {
1241
for _, tt := range []struct {
13-
name string
14-
wantValue *credential.CredentialReference
42+
name string
43+
fakeValue *fakeTStruct
44+
serializerFactory *serializer.TSerializerFactory
1545
}{
1646
{
1747
name: "nil value",
18-
wantValue: nil,
48+
fakeValue: nil,
1949
},
2050
{
2151
name: "with value",
22-
wantValue: &credential.CredentialReference{
23-
Type: credential.CredentialType_StripeConnectedAccount,
24-
Id: 42,
52+
fakeValue: &fakeTStruct{
53+
value: 42,
54+
},
55+
},
56+
{
57+
name: "with custom serializer factory",
58+
fakeValue: &fakeTStruct{
59+
value: 42,
2560
},
61+
serializerFactory: serializer.NewTSerializerFactory(
62+
thriftutil.JSONProtocolFactory,
63+
),
2664
},
2765
} {
2866
t.Run(tt.name, func(t *testing.T) {
2967
var (
3068
s = NullThrift[
31-
credential.CredentialReference,
32-
*credential.CredentialReference,
69+
fakeTStruct,
70+
*fakeTStruct,
3371
]{
34-
Data: tt.wantValue,
72+
Data: tt.fakeValue,
73+
SerializerFactory: tt.serializerFactory,
3574
}
3675

3776
data, err = s.Value()
@@ -42,7 +81,7 @@ func TestNullableThrift(t *testing.T) {
4281
s.Data = nil
4382

4483
require.NoError(t, s.Scan(data))
45-
assert.Equal(t, tt.wantValue, s.Data)
84+
assert.Equal(t, tt.fakeValue, s.Data)
4685
})
4786
}
4887
}

0 commit comments

Comments
 (0)