Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added tests for decimal #790

Merged
merged 5 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/types/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,10 @@ func (c *DataType) Equals(other *DataType) bool {
}

func (c *DataType) IsNumeric() bool {
if c.IsArray {
return false
}

return c.Name == intStr || c.Name == DecimalStr || c.Name == uint256Str || c.Name == unknownStr
}

Expand Down
48 changes: 48 additions & 0 deletions internal/engine/execution/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,54 @@ func (p *preparedProcedure) shapeReturn(result *sql.ResultSet) error {

for i, col := range p.returns.Fields {
result.Columns[i] = col.Name

// if the column is a decimal or a decimal array, we need to convert the values to
// the specified scale and precision
if col.Type.Name == types.DecimalStr {
// if it is an array, we need to convert each value in the array
if col.Type.IsArray {
for _, row := range result.Rows {
if row[i] == nil {
continue
}

arr, ok := row[i].([]any)
if !ok {
return fmt.Errorf("shapeReturn: expected decimal array, got %T", row[i])
}

for _, v := range arr {
if v == nil {
continue
}
dec, ok := v.(*decimal.Decimal)
if !ok {
return fmt.Errorf("shapeReturn: expected decimal, got %T", dec)
}
err := dec.SetPrecisionAndScale(col.Type.Metadata[0], col.Type.Metadata[1])
if err != nil {
return err
}
}
}
} else {
for _, row := range result.Rows {
if row[i] == nil {
continue
}

dec, ok := row[i].(*decimal.Decimal)
if !ok {
return fmt.Errorf("shapeReturn: expected decimal, got %T", row[i])
}

err := dec.SetPrecisionAndScale(col.Type.Metadata[0], col.Type.Metadata[1])
if err != nil {
return err
}
}
}
}
}

return nil
Expand Down
1 change: 0 additions & 1 deletion internal/engine/execution/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ func deleteSchema(ctx context.Context, tx sql.TxMaker, dbid string) error {
}

// setContextualVars sets the contextual variables for the given postgres session.
// TODO: use this function for actions too.
func setContextualVars(ctx context.Context, db sql.DB, data *common.ExecutionData) error {
// for contextual parameters, we use postgres's current_setting()
// feature for setting session variables. For example, @caller
Expand Down
39 changes: 39 additions & 0 deletions internal/engine/integration/procedure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (
"github.com/kwilteam/kwil-db/common/sql"
"github.com/kwilteam/kwil-db/core/crypto"
"github.com/kwilteam/kwil-db/core/types"
"github.com/kwilteam/kwil-db/core/types/decimal"
"github.com/kwilteam/kwil-db/core/utils/order"
"github.com/kwilteam/kwil-db/internal/engine/execution"
"github.com/kwilteam/kwil-db/parse"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -256,6 +258,25 @@ func Test_Procedures(t *testing.T) {
}`,
outputs: [][]any{{int64(1), int64(3)}},
},
{
name: "sum",
procedure: `procedure sum() public view returns (sum decimal(1000,0)) {
for $row in select sum(user_num) as s from users {
return $row.s;
}
}`,
outputs: [][]any{{mustDecimal("6", 1000, 0)}},
},
{
name: "decimal array",
procedure: `procedure decimal_array() public view returns (decimals decimal(2,1)[]) {
$a := 2.5;
$b := 3.5;
$c := $a/$b;
return [$a, $b, $c];
}`,
outputs: [][]any{{[]any{mustDecimal("2.5", 2, 1), mustDecimal("3.5", 2, 1), mustDecimal("0.7", 2, 1)}}},
},
}

for _, test := range tests {
Expand Down Expand Up @@ -301,13 +322,31 @@ func Test_Procedures(t *testing.T) {
for i, output := range test.outputs {
require.Len(t, res.Rows[i], len(output))
for j, val := range output {
if dec, ok := val.(*decimal.Decimal); ok {
received := res.Rows[i][j].(*decimal.Decimal)

assert.Equal(t, dec.String(), received.String())
assert.Equal(t, dec.Precision(), received.Precision())
assert.Equal(t, dec.Scale(), received.Scale())
continue
}

require.Equal(t, val, res.Rows[i][j])
}
}
})
}
}

func mustDecimal(val string, precision, scale uint16) *decimal.Decimal {
d, err := decimal.NewExplicit(val, precision, scale)
if err != nil {
panic(err)
}

return d
}

func Test_ForeignProcedures(t *testing.T) {
type testcase struct {
name string
Expand Down
34 changes: 31 additions & 3 deletions parse/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,15 +563,32 @@ var (
},
"sum": {
ValidateArgs: func(args []*types.DataType) (*types.DataType, error) {
// per https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-TABLE
// the result of sum will be made a decimal(1000, 0)
if len(args) != 1 {
return nil, wrapErrArgumentNumber(1, len(args))
}

if !args[0].EqualsStrict(types.IntType) {
return nil, wrapErrArgumentType(types.IntType, args[0])
if !args[0].IsNumeric() {
return nil, fmt.Errorf("expected argument to be numeric, got %s", args[0].String())
}

return types.IntType, nil
var retType *types.DataType
switch {
case args[0].EqualsStrict(types.IntType):
retType = decimal1000.Copy()
case args[0].Name == types.DecimalStr:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you're not using EqualsStrict because the metadata (precision and scale) isn't relevant?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly. Regardless of what the precision and scale is, if it is any sort of decimal we want to set it to precision 1000

retType = args[0].Copy()
retType.Metadata[0] = 1000 // max precision
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any need to modify scale? Error if not already 0? I'm just guessing about the assumptions here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. In postgres, sum will always maintain the input's scale. If we were multiplying or dividing, I would be a bit more concerned about scale, but for sum it really seems like we should just give as much precision as possible and keep the scale the same

case args[0].EqualsStrict(types.Uint256Type):
retType = decimal1000.Copy()
case args[0].EqualsStrict(types.UnknownType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you give an example what's a unknown numeric type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unknown is a special case that can be used for any type. I refactored this part though, as noted by Jon's bug above.

retType = types.UnknownType.Copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EqualsStrict has this:

	// if unknown, return true. unknown is a special case used
	// internally when type checking is disabled.
	if c.Name == unknownStr || other.Name == unknownStr {
		return true
	}

So this can be true with any type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh jeez, that's a good catch. Will fix this

default:
panic(fmt.Sprintf("unexpected numeric type: %s", retType.String()))
}

return retType, nil
},
IsAggregate: true,
PGFormat: func(inputs []string, distinct bool, star bool) (string, error) {
Expand Down Expand Up @@ -653,6 +670,17 @@ func defaultFormat(name string) FormatFunc {
}
}

// decimal1000 is a decimal type with a precision of 1000.
var decimal1000 *types.DataType

func init() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we should always put init at the very top, but I couldn't find a guide for this. nbd

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I thought about it, but honestly I think it's more clear next to the decimal since it is what we are initializing.

var err error
decimal1000, err = types.NewDecimalType(1000, 0)
if err != nil {
panic(fmt.Sprintf("failed to create decimal type: 1000, 0: %v", err))
}
}

func errDistinct(funcName string) error {
return fmt.Errorf(`%w: cannot use DISTINCT with function "%s"`, ErrFunctionSignature, funcName)
}
Expand Down
90 changes: 90 additions & 0 deletions parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,88 @@ func Test_Procedure(t *testing.T) {
},
},
},
{
name: "sum types - failure",
proc: `
$sum := 0;
for $row in select sum(id) as id from users {
$sum := $sum + $row.id;
}
`,
// this should error, since sum returns numeric
err: parse.ErrType,
},
{
name: "sum types - success",
proc: `
$sum decimal(1000,0);
for $row in select sum(id) as id from users {
$sum := $sum + $row.id;
}
`,
want: &parse.ProcedureParseResult{
Variables: map[string]*types.DataType{
"$sum": mustNewDecimal(1000, 0),
},
CompoundVariables: map[string]struct{}{
"$row": {},
},
AST: []parse.ProcedureStmt{
&parse.ProcedureStmtDeclaration{
Variable: exprVar("$sum"),
Type: mustNewDecimal(1000, 0),
},
&parse.ProcedureStmtForLoop{
Receiver: exprVar("$row"),
LoopTerm: &parse.LoopTermSQL{
Statement: &parse.SQLStatement{
SQL: &parse.SelectStatement{
SelectCores: []*parse.SelectCore{
{
Columns: []parse.ResultColumn{
&parse.ResultColumnExpression{
Expression: &parse.ExpressionFunctionCall{
Name: "sum",
Args: []parse.Expression{
exprColumn("", "id"),
},
},
Alias: "id",
},
},
From: &parse.RelationTable{
Table: "users",
},
},
},
// If there is an aggregate clause with no group by, then no ordering is applied.
},
},
},
Body: []parse.ProcedureStmt{
&parse.ProcedureStmtAssign{
Variable: exprVar("$sum"),
Value: &parse.ExpressionArithmetic{
Left: exprVar("$sum"),
Operator: parse.ArithmeticOperatorAdd,
Right: &parse.ExpressionFieldAccess{Record: exprVar("$row"), Field: "id"},
},
},
},
},
},
},
},
{
// this is a regression test for a previous bug
name: "adding arrays",
proc: `
$arr1 := [1,2,3];
$arr2 := [4,5,6];
$arr3 := $arr1 + $arr2;
`,
err: parse.ErrType,
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -1644,6 +1726,14 @@ func exprVar(n string) *parse.ExpressionVariable {
}
}

func mustNewDecimal(precision, scale uint16) *types.DataType {
dt, err := types.NewDecimalType(precision, scale)
if err != nil {
panic(err)
}
return dt
}

// exprLit makes an ExpressionLiteral.
// it can only make strings and ints
func exprLit(v any) *parse.ExpressionLiteral {
Expand Down
Loading