Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brennanjl committed Feb 19, 2025
1 parent ba57c18 commit a01729f
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 28 deletions.
29 changes: 15 additions & 14 deletions node/engine/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@ const (
var (
// Errors that suggest a bug in a user's executing code. These are type errors,
// issues in arithmetic, array indexing, etc.
ErrType = errors.New("type error")
ErrReturnShape = errors.New("unexpected action/function return shape")
ErrUnknownVariable = errors.New("unknown variable")
ErrInvalidVariable = errors.New("invalid variable name")
ErrLoop = errors.New("loop error")
ErrArithmetic = errors.New("arithmetic error")
ErrComparison = errors.New("comparison error")
ErrCast = errors.New("type cast error")
ErrUnary = errors.New("unary operation error")
ErrIndexOutOfBounds = errors.New("index out of bounds")
ErrArrayDimensionality = errors.New("array dimensionality error")
ErrInvalidNull = errors.New("invalid null value")
ErrArrayTooSmall = errors.New("array too small")
ErrExtensionInvocation = errors.New("extension invocation error")
ErrType = errors.New("type error")
ErrReturnShape = errors.New("unexpected action/function return shape")
ErrUnknownVariable = errors.New("unknown variable")
ErrInvalidVariable = errors.New("invalid variable name")
ErrLoop = errors.New("loop error")
ErrArithmetic = errors.New("arithmetic error")
ErrComparison = errors.New("comparison error")
ErrCast = errors.New("type cast error")
ErrUnary = errors.New("unary operation error")
ErrIndexOutOfBounds = errors.New("index out of bounds")
ErrArrayDimensionality = errors.New("array dimensionality error")
ErrInvalidNull = errors.New("invalid null value")
ErrArrayTooSmall = errors.New("array too small")
ErrExtensionImplementation = errors.New("extension implementation error")
ErrActionInvocation = errors.New("action invocation error")

// Errors that signal the existence or non-existence of an object.
ErrUnknownAction = errors.New("unknown action")
Expand Down
14 changes: 7 additions & 7 deletions node/engine/interpreter/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func initializeExtension(ctx context.Context, svc *common.Service, db sql.DB, i
}

if len(args) != len(method.Parameters) {
return fmt.Errorf(`%w: extension method "%s" expected %d arguments, but got %d`, engine.ErrExtensionInvocation, lowerName, len(method.Parameters), len(args))
return fmt.Errorf(`%w: extension method "%s" expected %d arguments, but got %d`, engine.ErrExtensionImplementation, lowerName, len(method.Parameters), len(args))
}

argVals := make([]any, len(args))
Expand All @@ -59,13 +59,13 @@ func initializeExtension(ctx context.Context, svc *common.Service, db sql.DB, i

// ensure the argument types match
if !method.Parameters[i].Type.Equals(arg.Type()) {
return fmt.Errorf(`%w: extension method "%s" expected argument %d to be of type %s, but got %s`, engine.ErrExtensionInvocation, lowerName, i, method.Parameters[i].Type, arg.Type())
return fmt.Errorf(`%w: extension method "%s" expected argument %d to be of type %s, but got %s`, engine.ErrExtensionImplementation, lowerName, i, method.Parameters[i].Type, arg.Type())
}

// the above will be ok if the argument is nil
// we therefore check for nullability here
if !method.Parameters[i].Nullable && arg.Null() {
return fmt.Errorf(`%w: extension method "%s" expected argument %d to be non-null, but got null`, engine.ErrExtensionInvocation, lowerName, i)
return fmt.Errorf(`%w: extension method "%s" expected argument %d to be non-null, but got null`, engine.ErrExtensionImplementation, lowerName, i)
}
}

Expand All @@ -74,14 +74,14 @@ func initializeExtension(ctx context.Context, svc *common.Service, db sql.DB, i
return method.Handler(exec2.engineCtx, exec2.app(), argVals, func(a []any) error {
// if no return is specified for this method, then the callback should never be called
if method.Returns == nil {
return fmt.Errorf(`%w: method "%s"."%s" returned no value, but expected one`, engine.ErrExtensionInvocation, alias, lowerName)
return fmt.Errorf(`%w: method "%s"."%s" returned no value, but expected one`, engine.ErrExtensionImplementation, alias, lowerName)
}

colNames := make([]string, len(a))
returnVals := make([]value, len(a))

if len(method.Returns.Fields) != len(a) {
return fmt.Errorf("%w: method %s returned %d values, but expected %d", engine.ErrExtensionInvocation, lowerName, len(a), len(method.Returns.Fields))
return fmt.Errorf("%w: method %s returned %d values, but expected %d", engine.ErrExtensionImplementation, lowerName, len(a), len(method.Returns.Fields))
}

for i, v := range a {
Expand All @@ -90,11 +90,11 @@ func initializeExtension(ctx context.Context, svc *common.Service, db sql.DB, i
return err
}
if !ok {
return fmt.Errorf(`%w: method "%s"."%s" returned a value of type %s, but expected %s. column: "%s"`, engine.ErrExtensionInvocation, alias, lowerName, newVal.Type(), method.Returns.Fields[i].Type, method.Returns.Fields[i].Name)
return fmt.Errorf(`%w: method "%s"."%s" returned a value of type %s, but expected %s. column: "%s"`, engine.ErrExtensionImplementation, alias, lowerName, newVal.Type(), method.Returns.Fields[i].Type, method.Returns.Fields[i].Name)
}

if !method.Returns.Fields[i].Nullable && newVal.Null() {
return fmt.Errorf(`%w: method "%s"."%s" returned a null value for a non-nullable column. column: "%s"`, engine.ErrExtensionInvocation, alias, lowerName, method.Returns.Fields[i].Name)
return fmt.Errorf(`%w: method "%s"."%s" returned a null value for a non-nullable column. column: "%s"`, engine.ErrExtensionImplementation, alias, lowerName, method.Returns.Fields[i].Name)
}

returnVals[i] = newVal
Expand Down
2 changes: 1 addition & 1 deletion node/engine/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ func (i *baseInterpreter) call(ctx *common.EngineContext, db sql.DB, namespace,
if exec.ExpectedArgs != nil {
expect := *exec.ExpectedArgs
if len(expect) != len(args) {
return nil, fmt.Errorf(`action "%s" expected %d arguments, but got %d`, action, len(expect), len(args))
return nil, fmt.Errorf(`%w: action "%s" expected %d arguments, but got %d`, engine.ErrActionInvocation, action, len(expect), len(args))
}

for i, arg := range args {
Expand Down
12 changes: 6 additions & 6 deletions node/engine/interpreter/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2792,11 +2792,11 @@ func Test_ExtensionTypeChecks(t *testing.T) {
require.NoError(t, err)

_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "accept_not_null", []any{nil}, nil)
require.ErrorIs(t, err, engine.ErrExtensionInvocation)
require.ErrorIs(t, err, engine.ErrExtensionImplementation)

// call with an int
_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "accept_not_null", []any{1}, nil)
require.ErrorIs(t, err, engine.ErrExtensionInvocation)
require.ErrorIs(t, err, engine.ErrType)

// 2. takes 1 param (can be null), returns nothing
_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "accept_null", []any{"hello"}, nil)
Expand All @@ -2810,7 +2810,7 @@ func Test_ExtensionTypeChecks(t *testing.T) {
require.NoError(t, err)

_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "return_not_null", []any{nil}, nil)
require.ErrorIs(t, err, engine.ErrExtensionInvocation)
require.ErrorIs(t, err, engine.ErrExtensionImplementation)

// 4. takes 1 param, returns the same type (return can be null)
_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "return_null", []any{"hello"}, exact("hello"))
Expand All @@ -2822,15 +2822,15 @@ func Test_ExtensionTypeChecks(t *testing.T) {
// Other tests:
// returns wrong type
_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "returns_wrong_type", nil, nil)
require.ErrorIs(t, err, engine.ErrExtensionInvocation)
require.ErrorIs(t, err, engine.ErrExtensionImplementation)

// returns wrong count
_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "returns_wrong_count", nil, nil)
require.ErrorIs(t, err, engine.ErrExtensionInvocation)
require.ErrorIs(t, err, engine.ErrExtensionImplementation)

// wrong count for parameters
_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "accept_not_null", []any{"hello", "world"}, nil)
require.ErrorIs(t, err, engine.ErrExtensionInvocation)
require.ErrorIs(t, err, engine.ErrActionInvocation)

// empty array works ok
_, err = interp.Call(newEngineCtx(defaultCaller), tx, "types_ext", "returns empty decimal array", nil, exact([]*types.Decimal{}))
Expand Down

0 comments on commit a01729f

Please sign in to comment.