1
1
package sql
2
2
3
3
import (
4
+ "fmt"
4
5
"testing"
5
6
6
7
"github.com/stretchr/testify/assert"
7
8
"github.com/stretchr/testify/require"
8
- "github.com/upfluence/base/provider/credential"
9
+ "github.com/upfluence/errors"
10
+ "github.com/upfluence/thrift/lib/go/thrift"
11
+
12
+ "github.com/upfluence/pkg/thrift/serializer"
13
+ "github.com/upfluence/pkg/thrift/thriftutil"
9
14
)
10
15
11
- func TestNullableThrift (t * testing.T ) {
16
+ type fakeTStruct struct {
17
+ value int64
18
+ }
19
+
20
+ func (t fakeTStruct ) Write (p thrift.TProtocol ) error {
21
+ return errors .WithStack (p .WriteI64 (t .value ))
22
+ }
23
+
24
+ func (t * fakeTStruct ) Read (p thrift.TProtocol ) error {
25
+ val , err := p .ReadI64 ()
26
+
27
+ if err != nil {
28
+ return errors .WithStack (err )
29
+ }
30
+
31
+ t .value = val
32
+
33
+ return nil
34
+ }
35
+
36
+ func (t * fakeTStruct ) String () string {
37
+ return fmt .Sprint (t .value )
38
+ }
39
+
40
+ func TestNullableThrift_Scan (t * testing.T ) {
12
41
for _ , tt := range []struct {
13
- name string
14
- wantValue * credential.CredentialReference
42
+ name string
43
+ fakeValue * fakeTStruct
44
+ serializerFactory * serializer.TSerializerFactory
15
45
}{
16
46
{
17
47
name : "nil value" ,
18
- wantValue : nil ,
48
+ fakeValue : nil ,
19
49
},
20
50
{
21
51
name : "with value" ,
22
- wantValue : & credential.CredentialReference {
23
- Type : credential .CredentialType_StripeConnectedAccount ,
24
- Id : 42 ,
52
+ fakeValue : & fakeTStruct {
53
+ value : 42 ,
54
+ },
55
+ },
56
+ {
57
+ name : "with custom serializer factory" ,
58
+ fakeValue : & fakeTStruct {
59
+ value : 42 ,
25
60
},
61
+ serializerFactory : serializer .NewTSerializerFactory (
62
+ thriftutil .JSONProtocolFactory ,
63
+ ),
26
64
},
27
65
} {
28
66
t .Run (tt .name , func (t * testing.T ) {
29
67
var (
30
68
s = NullThrift [
31
- credential. CredentialReference ,
32
- * credential. CredentialReference ,
69
+ fakeTStruct ,
70
+ * fakeTStruct ,
33
71
]{
34
- Data : tt .wantValue ,
72
+ Data : tt .fakeValue ,
73
+ SerializerFactory : tt .serializerFactory ,
35
74
}
36
75
37
76
data , err = s .Value ()
@@ -42,7 +81,7 @@ func TestNullableThrift(t *testing.T) {
42
81
s .Data = nil
43
82
44
83
require .NoError (t , s .Scan (data ))
45
- assert .Equal (t , tt .wantValue , s .Data )
84
+ assert .Equal (t , tt .fakeValue , s .Data )
46
85
})
47
86
}
48
87
}
0 commit comments