@@ -19,6 +19,7 @@ import (
19
19
20
20
"github.com/cockroachdb/errors"
21
21
"github.com/dolthub/go-mysql-server/sql"
22
+ "github.com/dolthub/go-mysql-server/sql/expression"
22
23
"github.com/dolthub/go-mysql-server/sql/plan"
23
24
24
25
"github.com/dolthub/doltgresql/server/functions/framework"
@@ -39,16 +40,16 @@ type AnyExpr struct {
39
40
// subqueryAnyExpr represents the resolved comparison functions for a plan.Subquery.
40
41
type subqueryAnyExpr struct {
41
42
rightSub * plan.Subquery
42
- staticLiteral * Literal
43
- arrayLiterals []* Literal
43
+ staticLiteral * expression. Literal
44
+ arrayLiterals []* expression. Literal
44
45
compFuncs []framework.Function
45
46
}
46
47
47
48
// expressionAnyExpr represents the resolved comparison function for a sql.Expression.
48
49
type expressionAnyExpr struct {
49
50
rightExpr sql.Expression
50
- staticLiteral * Literal
51
- arrayLiteral * Literal
51
+ staticLiteral * expression. Literal
52
+ arrayLiteral * expression. Literal
52
53
compFunc framework.Function
53
54
}
54
55
@@ -130,7 +131,7 @@ func (a *subqueryAnyExpr) eval(ctx *sql.Context, subOperator string, row sql.Row
130
131
}
131
132
132
133
for i := len (a .arrayLiterals ); i < len (rightValues ); i ++ {
133
- arrayLiteral := & Literal { typ : a .arrayLiterals [0 ].typ }
134
+ arrayLiteral := expression . NewLiteral ( nil , a .arrayLiterals [0 ].Type ())
134
135
a .arrayLiterals = append (a .arrayLiterals , arrayLiteral )
135
136
compFunc := framework .GetBinaryFunction (op ).Compile ("internal_any_comparison" , a .staticLiteral , a .arrayLiterals [i ])
136
137
a .compFuncs = append (a .compFuncs , compFunc )
@@ -142,9 +143,10 @@ func (a *subqueryAnyExpr) eval(ctx *sql.Context, subOperator string, row sql.Row
142
143
}
143
144
144
145
// Next we'll assign our evaluated values to the expressions that the comparison functions reference
145
- a .staticLiteral .value = left
146
+ // Note that the compiled function has a reference to the staticLiteral and arrayLiterals, so we must alter them in place
147
+ a .staticLiteral .Val = left
146
148
for i , rightValue := range rightValues {
147
- a .arrayLiterals [i ].value = rightValue
149
+ a .arrayLiterals [i ].Val = rightValue
148
150
}
149
151
// Now we can loop over all comparison functions, as they'll reference their respective values
150
152
for _ , compFunc := range a .compFuncs {
@@ -192,9 +194,10 @@ func (a *expressionAnyExpr) eval(ctx *sql.Context, row sql.Row, left interface{}
192
194
}
193
195
194
196
// Next we'll assign our evaluated values to the expressions that the comparison function reference
195
- a .staticLiteral .value = left
197
+ // Note that the compiled function has a reference to the staticLiteral and arrayLiteral, so we must alter them in place
198
+ a .staticLiteral .Val = left
196
199
for _ , rightValue := range rightValues {
197
- a .arrayLiteral .value = rightValue
200
+ a .arrayLiteral .Val = rightValue
198
201
result , err := a .compFunc .Eval (ctx , row )
199
202
if err != nil {
200
203
return nil , err
@@ -293,12 +296,12 @@ func anySubqueryWithChildren(anyExpr *AnyExpr, sub *plan.Subquery) (sql.Expressi
293
296
294
297
if leftType , ok := anyExpr .leftExpr .Type ().(* pgtypes.DoltgresType ); ok {
295
298
// Resolve comparison functions once and reuse the functions in Eval.
296
- staticLiteral := & Literal { typ : leftType }
297
- arrayLiterals := make ([]* Literal , len (subTypes ))
299
+ staticLiteral := expression . NewLiteral ( nil , leftType )
300
+ arrayLiterals := make ([]* expression. Literal , len (subTypes ))
298
301
// Each expression may be a different type (which is valid), so we need a comparison function for each expression.
299
302
compFuncs := make ([]framework.Function , len (subTypes ))
300
303
for i , rightType := range subTypes {
301
- arrayLiterals [i ] = & Literal { typ : rightType }
304
+ arrayLiterals [i ] = expression . NewLiteral ( nil , rightType )
302
305
compFuncs [i ] = framework .GetBinaryFunction (op ).Compile ("internal_any_comparison" , staticLiteral , arrayLiterals [i ])
303
306
if compFuncs [i ] == nil {
304
307
return nil , errors .Errorf ("operator does not exist: %s = %s" , leftType .String (), rightType .String ())
@@ -334,8 +337,8 @@ func anyExpressionWithChildren(anyExpr *AnyExpr) (sql.Expression, error) {
334
337
335
338
if leftType , ok := anyExpr .leftExpr .Type ().(* pgtypes.DoltgresType ); ok {
336
339
// Resolve comparison function once and reuse the function in Eval.
337
- staticLiteral := & Literal { typ : leftType }
338
- arrayLiteral := & Literal { typ : rightType }
340
+ staticLiteral := expression . NewLiteral ( nil , leftType )
341
+ arrayLiteral := expression . NewLiteral ( nil , rightType )
339
342
compFunc := framework .GetBinaryFunction (op ).Compile ("internal_any_comparison" , staticLiteral , arrayLiteral )
340
343
if compFunc == nil {
341
344
return nil , errors .Errorf ("operator does not exist: %s = %s" , leftType .String (), rightType .String ())
0 commit comments