From faf879e8fe7c0420101420f0efeadc5c0a46077d Mon Sep 17 00:00:00 2001 From: Brennan Lamey <66885902+brennanjl@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:28:03 -0500 Subject: [PATCH] fix: in-line type assertions not comparable * engine,parse: fix in-line action error This change fixes two bugs: - Actions using in-line statements with type casts for ambiguous types would not generate SQL properly. This would only occur when performing an expression (such as a comparison). This is the main purpose of this PR. - Another very minor bug that caused a stack overflow when using the parse debugger tool was fixed. It was caused by having circular pointer references while removing error positions. * fix failing tests --- internal/engine/execution/global.go | 7 +++++- internal/engine/execution/procedure.go | 18 ++++++++++++++- internal/engine/generate/plpgsql.go | 10 +++++++-- internal/engine/integration/sql_test.go | 17 +++++++++++++++ parse/parse.go | 29 +++++++++++++++++++------ 5 files changed, 70 insertions(+), 11 deletions(-) diff --git a/internal/engine/execution/global.go b/internal/engine/execution/global.go index fb9216f9a..899c7bb27 100644 --- a/internal/engine/execution/global.go +++ b/internal/engine/execution/global.go @@ -376,7 +376,12 @@ func (g *GlobalContext) Execute(ctx context.Context, tx sql.DB, dbid, query stri args := orderAndCleanValueMap(values, params) args = append([]any{pg.QueryModeExec}, args...) - return tx.Execute(ctx, sqlStmt, args...) + result, err := tx.Execute(ctx, sqlStmt, args...) + if err != nil { + return nil, decorateExecuteErr(err, query) + } + + return result, nil } type dbQueryFn func(ctx context.Context, stmt string, args ...any) (*sql.ResultSet, error) diff --git a/internal/engine/execution/procedure.go b/internal/engine/execution/procedure.go index 0c77d5f4e..c2f0a1361 100644 --- a/internal/engine/execution/procedure.go +++ b/internal/engine/execution/procedure.go @@ -8,6 +8,7 @@ import ( "reflect" "strings" + "github.com/jackc/pgx/v5/pgconn" "github.com/kwilteam/kwil-db/common" sql "github.com/kwilteam/kwil-db/common/sql" "github.com/kwilteam/kwil-db/core/types" @@ -40,6 +41,7 @@ var ( ErrPrivateProcedure = errors.New("procedure is private") ErrMutativeProcedure = errors.New("procedure is mutative") ErrMaxStackDepth = errors.New("max call stack depth reached") + ErrCannotInferType = errors.New("cannot infer type") ) // instruction is an instruction that can be executed. @@ -313,6 +315,20 @@ type dmlStmt struct { OrderedParameters []string } +// decorateExecuteErr parses an execute error from postgres and tries to give a more helpful error message. +// this allows us to give a more helpful error message when users hit this, +// since the Postgres error message is not helpful, and this is a common error. +func decorateExecuteErr(err error, stmt string) error { + // this catches a common error case for in-line expressions, where the type cannot be inferred + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "42P08" || pgErr.Code == "42P18" { + return fmt.Errorf(`%w: could not dynamically determine the data type in statement "%s". try type casting using ::, e.g. $id::text`, + ErrCannotInferType, stmt) + } + + return err +} + var _ instructionFunc = (&dmlStmt{}).execute func (e *dmlStmt) execute(scope *precompiles.ProcedureContext, _ *GlobalContext, db sql.DB) error { @@ -321,7 +337,7 @@ func (e *dmlStmt) execute(scope *precompiles.ProcedureContext, _ *GlobalContext, // args := append([]any{pg.QueryModeExec}, params...) results, err := db.Execute(scope.Ctx, e.SQLStatement, append([]any{pg.QueryModeExec}, params...)...) if err != nil { - return err + return decorateExecuteErr(err, e.SQLStatement) } // we need to check for any pg numeric types returned, and convert them to int64 diff --git a/internal/engine/generate/plpgsql.go b/internal/engine/generate/plpgsql.go index ce4e23055..48507fbf7 100644 --- a/internal/engine/generate/plpgsql.go +++ b/internal/engine/generate/plpgsql.go @@ -2,6 +2,7 @@ package generate import ( "fmt" + "strconv" "strings" "github.com/kwilteam/kwil-db/core/types" @@ -126,7 +127,7 @@ func (s *sqlGenerator) VisitExpressionVariable(p0 *parse.ExpressionVariable) any // if it already exists, we write it as that index. for i, v := range s.orderedParams { if v == str { - return "$" + fmt.Sprint(i+1) + return "$" + strconv.Itoa(i+1) } } @@ -134,7 +135,12 @@ func (s *sqlGenerator) VisitExpressionVariable(p0 *parse.ExpressionVariable) any // Postgres uses $1, $2, etc. for numbered parameters. s.orderedParams = append(s.orderedParams, str) - return "$" + fmt.Sprint(len(s.orderedParams)) + + res := strings.Builder{} + res.WriteString("$") + res.WriteString(strconv.Itoa(len(s.orderedParams))) + typeCast(p0, &res) + return res.String() } str := strings.Builder{} diff --git a/internal/engine/integration/sql_test.go b/internal/engine/integration/sql_test.go index 0cb8bf8fe..86706835d 100644 --- a/internal/engine/integration/sql_test.go +++ b/internal/engine/integration/sql_test.go @@ -6,6 +6,7 @@ import ( "context" "testing" + "github.com/kwilteam/kwil-db/internal/engine/execution" "github.com/stretchr/testify/require" ) @@ -202,6 +203,22 @@ func Test_SQL(t *testing.T) { {"4a67d6ea-7ac8-453c-964e-5a144f9e3004"}, }, }, + { + name: "inferred type - failure", + sql: "select $id is null", + values: map[string]any{ + "$id": "4a67d6ea-7ac8-453c-964e-5a144f9e3004", + }, + err: execution.ErrCannotInferType, + }, + { + name: "inferred type - success", + sql: "select $id::text is null", + values: map[string]any{ + "$id": "4a67d6ea-7ac8-453c-964e-5a144f9e3004", + }, + want: [][]any{{false}}, + }, } for _, tt := range tests { diff --git a/parse/parse.go b/parse/parse.go index 0f3ff9ae1..b92c4d4f6 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -454,16 +454,18 @@ func setupParser(inputStream string, errLisName string) (errLis *errorListener, // It is used in both parsing tools, as well as in tests. // WARNING: This function should NEVER be used in consensus, since it is non-deterministic. func RecursivelyVisitPositions(v any, fn func(GetPositioner)) { + + visited := make(map[uintptr]struct{}) visitRecursive(reflect.ValueOf(v), reflect.TypeOf((*GetPositioner)(nil)).Elem(), func(v reflect.Value) { if v.CanInterface() { a := v.Interface().(GetPositioner) fn(a) } - }) + }, visited) } // visitRecursive is a recursive function that visits all types that implement the target interface. -func visitRecursive(v reflect.Value, target reflect.Type, fn func(reflect.Value)) { +func visitRecursive(v reflect.Value, target reflect.Type, fn func(reflect.Value), visited map[uintptr]struct{}) { if v.Type().Implements(target) { // check if the value is nil if !v.IsNil() { @@ -472,23 +474,36 @@ func visitRecursive(v reflect.Value, target reflect.Type, fn func(reflect.Value) } switch v.Kind() { - case reflect.Ptr, reflect.Interface: + case reflect.Interface: + if v.IsNil() { + return + } + + visitRecursive(v.Elem(), target, fn, visited) + case reflect.Ptr: if v.IsNil() { return } - visitRecursive(v.Elem(), target, fn) + // check if we have visited this pointer before + ptr := v.Pointer() + if _, ok := visited[ptr]; ok { + return + } + visited[ptr] = struct{}{} + + visitRecursive(v.Elem(), target, fn, visited) case reflect.Struct: for i := 0; i < v.NumField(); i++ { - visitRecursive(v.Field(i), target, fn) + visitRecursive(v.Field(i), target, fn, visited) } case reflect.Slice, reflect.Array: for i := 0; i < v.Len(); i++ { - visitRecursive(v.Index(i), target, fn) + visitRecursive(v.Index(i), target, fn, visited) } case reflect.Map: for _, key := range v.MapKeys() { - visitRecursive(v.MapIndex(key), target, fn) + visitRecursive(v.MapIndex(key), target, fn, visited) } } }