Skip to content

Commit

Permalink
added tests for decimal (#790)
Browse files Browse the repository at this point in the history
* added tests for decimal

* added regression test for arrays

* fixed build tag

* fixed nil error

* fixed bug with unknown type in sum
  • Loading branch information
brennanjl authored Jun 4, 2024
1 parent 53e8d38 commit 20c53ed
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 4 deletions.
4 changes: 4 additions & 0 deletions core/types/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,10 @@ func (c *DataType) Equals(other *DataType) bool {
}

func (c *DataType) IsNumeric() bool {
if c.IsArray {
return false
}

return c.Name == intStr || c.Name == DecimalStr || c.Name == uint256Str || c.Name == unknownStr
}

Expand Down
48 changes: 48 additions & 0 deletions internal/engine/execution/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,54 @@ func (p *preparedProcedure) shapeReturn(result *sql.ResultSet) error {

for i, col := range p.returns.Fields {
result.Columns[i] = col.Name

// if the column is a decimal or a decimal array, we need to convert the values to
// the specified scale and precision
if col.Type.Name == types.DecimalStr {
// if it is an array, we need to convert each value in the array
if col.Type.IsArray {
for _, row := range result.Rows {
if row[i] == nil {
continue
}

arr, ok := row[i].([]any)
if !ok {
return fmt.Errorf("shapeReturn: expected decimal array, got %T", row[i])
}

for _, v := range arr {
if v == nil {
continue
}
dec, ok := v.(*decimal.Decimal)
if !ok {
return fmt.Errorf("shapeReturn: expected decimal, got %T", dec)
}
err := dec.SetPrecisionAndScale(col.Type.Metadata[0], col.Type.Metadata[1])
if err != nil {
return err
}
}
}
} else {
for _, row := range result.Rows {
if row[i] == nil {
continue
}

dec, ok := row[i].(*decimal.Decimal)
if !ok {
return fmt.Errorf("shapeReturn: expected decimal, got %T", row[i])
}

err := dec.SetPrecisionAndScale(col.Type.Metadata[0], col.Type.Metadata[1])
if err != nil {
return err
}
}
}
}
}

return nil
Expand Down
1 change: 0 additions & 1 deletion internal/engine/execution/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ func deleteSchema(ctx context.Context, tx sql.TxMaker, dbid string) error {
}

// setContextualVars sets the contextual variables for the given postgres session.
// TODO: use this function for actions too.
func setContextualVars(ctx context.Context, db sql.DB, data *common.ExecutionData) error {
// for contextual parameters, we use postgres's current_setting()
// feature for setting session variables. For example, @caller
Expand Down
39 changes: 39 additions & 0 deletions internal/engine/integration/procedure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (
"github.com/kwilteam/kwil-db/common/sql"
"github.com/kwilteam/kwil-db/core/crypto"
"github.com/kwilteam/kwil-db/core/types"
"github.com/kwilteam/kwil-db/core/types/decimal"
"github.com/kwilteam/kwil-db/core/utils/order"
"github.com/kwilteam/kwil-db/internal/engine/execution"
"github.com/kwilteam/kwil-db/parse"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -256,6 +258,25 @@ func Test_Procedures(t *testing.T) {
}`,
outputs: [][]any{{int64(1), int64(3)}},
},
{
name: "sum",
procedure: `procedure sum() public view returns (sum decimal(1000,0)) {
for $row in select sum(user_num) as s from users {
return $row.s;
}
}`,
outputs: [][]any{{mustDecimal("6", 1000, 0)}},
},
{
name: "decimal array",
procedure: `procedure decimal_array() public view returns (decimals decimal(2,1)[]) {
$a := 2.5;
$b := 3.5;
$c := $a/$b;
return [$a, $b, $c];
}`,
outputs: [][]any{{[]any{mustDecimal("2.5", 2, 1), mustDecimal("3.5", 2, 1), mustDecimal("0.7", 2, 1)}}},
},
}

for _, test := range tests {
Expand Down Expand Up @@ -301,13 +322,31 @@ func Test_Procedures(t *testing.T) {
for i, output := range test.outputs {
require.Len(t, res.Rows[i], len(output))
for j, val := range output {
if dec, ok := val.(*decimal.Decimal); ok {
received := res.Rows[i][j].(*decimal.Decimal)

assert.Equal(t, dec.String(), received.String())
assert.Equal(t, dec.Precision(), received.Precision())
assert.Equal(t, dec.Scale(), received.Scale())
continue
}

require.Equal(t, val, res.Rows[i][j])
}
}
})
}
}

func mustDecimal(val string, precision, scale uint16) *decimal.Decimal {
d, err := decimal.NewExplicit(val, precision, scale)
if err != nil {
panic(err)
}

return d
}

func Test_ForeignProcedures(t *testing.T) {
type testcase struct {
name string
Expand Down
38 changes: 35 additions & 3 deletions parse/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,15 +563,36 @@ var (
},
"sum": {
ValidateArgs: func(args []*types.DataType) (*types.DataType, error) {
// per https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-TABLE
// the result of sum will be made a decimal(1000, 0)
if len(args) != 1 {
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].EqualsStrict(types.IntType) {
return nil, wrapErrArgumentType(types.IntType, args[0])
if !args[0].IsNumeric() {
return nil, fmt.Errorf("expected argument to be numeric, got %s", args[0].String())
}

return types.IntType, nil
// we check if it is an unknown type before the switch,
// as unknown will be true for all EqualsStrict checks
if args[0] == types.UnknownType {
return types.UnknownType, nil
}

var retType *types.DataType
switch {
case args[0].EqualsStrict(types.IntType):
retType = decimal1000.Copy()
case args[0].Name == types.DecimalStr:
retType = args[0].Copy()
retType.Metadata[0] = 1000 // max precision
case args[0].EqualsStrict(types.Uint256Type):
retType = decimal1000.Copy()
default:
panic(fmt.Sprintf("unexpected numeric type: %s", retType.String()))
}

return retType, nil
},
IsAggregate: true,
PGFormat: func(inputs []string, distinct bool, star bool) (string, error) {
Expand Down Expand Up @@ -653,6 +674,17 @@ func defaultFormat(name string) FormatFunc {
}
}

// decimal1000 is a decimal type with a precision of 1000.
var decimal1000 *types.DataType

func init() {
var err error
decimal1000, err = types.NewDecimalType(1000, 0)
if err != nil {
panic(fmt.Sprintf("failed to create decimal type: 1000, 0: %v", err))
}
}

func errDistinct(funcName string) error {
return fmt.Errorf(`%w: cannot use DISTINCT with function "%s"`, ErrFunctionSignature, funcName)
}
Expand Down
90 changes: 90 additions & 0 deletions parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,88 @@ func Test_Procedure(t *testing.T) {
},
},
},
{
name: "sum types - failure",
proc: `
$sum := 0;
for $row in select sum(id) as id from users {
$sum := $sum + $row.id;
}
`,
// this should error, since sum returns numeric
err: parse.ErrType,
},
{
name: "sum types - success",
proc: `
$sum decimal(1000,0);
for $row in select sum(id) as id from users {
$sum := $sum + $row.id;
}
`,
want: &parse.ProcedureParseResult{
Variables: map[string]*types.DataType{
"$sum": mustNewDecimal(1000, 0),
},
CompoundVariables: map[string]struct{}{
"$row": {},
},
AST: []parse.ProcedureStmt{
&parse.ProcedureStmtDeclaration{
Variable: exprVar("$sum"),
Type: mustNewDecimal(1000, 0),
},
&parse.ProcedureStmtForLoop{
Receiver: exprVar("$row"),
LoopTerm: &parse.LoopTermSQL{
Statement: &parse.SQLStatement{
SQL: &parse.SelectStatement{
SelectCores: []*parse.SelectCore{
{
Columns: []parse.ResultColumn{
&parse.ResultColumnExpression{
Expression: &parse.ExpressionFunctionCall{
Name: "sum",
Args: []parse.Expression{
exprColumn("", "id"),
},
},
Alias: "id",
},
},
From: &parse.RelationTable{
Table: "users",
},
},
},
// If there is an aggregate clause with no group by, then no ordering is applied.
},
},
},
Body: []parse.ProcedureStmt{
&parse.ProcedureStmtAssign{
Variable: exprVar("$sum"),
Value: &parse.ExpressionArithmetic{
Left: exprVar("$sum"),
Operator: parse.ArithmeticOperatorAdd,
Right: &parse.ExpressionFieldAccess{Record: exprVar("$row"), Field: "id"},
},
},
},
},
},
},
},
{
// this is a regression test for a previous bug
name: "adding arrays",
proc: `
$arr1 := [1,2,3];
$arr2 := [4,5,6];
$arr3 := $arr1 + $arr2;
`,
err: parse.ErrType,
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -1644,6 +1726,14 @@ func exprVar(n string) *parse.ExpressionVariable {
}
}

func mustNewDecimal(precision, scale uint16) *types.DataType {
dt, err := types.NewDecimalType(precision, scale)
if err != nil {
panic(err)
}
return dt
}

// exprLit makes an ExpressionLiteral.
// it can only make strings and ints
func exprLit(v any) *parse.ExpressionLiteral {
Expand Down

0 comments on commit 20c53ed

Please sign in to comment.