Skip to content

Commit

Permalink
engine: cache prepared statements
Browse files Browse the repository at this point in the history
This PR adds caching for statements in the interpreter, removing the need to redundantly
parse, analyze, and generate SQL for statements that don't change. The cache is invalidated
any time any table is changed; this is because a change in the DB schema may change what
is necessary to make a query deterministic.

This PR also fixes a bug that I coincidentally found while clicking through the query planner
looking for something else: the planner was not properly applying default ordering to window functions.
  • Loading branch information
brennanjl authored Feb 21, 2025
1 parent 0f0bd06 commit 38f5803
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 77 deletions.
252 changes: 201 additions & 51 deletions node/engine/interpreter/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package interpreter
import (
"fmt"
"strings"
"sync"

"github.com/kwilteam/kwil-db/common"
"github.com/kwilteam/kwil-db/core/types"
Expand Down Expand Up @@ -152,29 +153,161 @@ func (e *executionContext) query(sql string, fn func(*row) error) error {
e.queryActive = true
defer func() { e.queryActive = false }()

generatedSQL, analyzed, args, err := e.prepareQuery(sql)
if err != nil {
return err
}

// get the scan values as well:
var scanValues []value
for _, field := range analyzed.Plan.Relation().Fields {
scalar, err := field.Scalar()
if err != nil {
return err
}

zVal, err := newZeroValue(scalar)
if err != nil {
return err
}

scanValues = append(scanValues, zVal)
}

cols := make([]string, len(analyzed.Plan.Relation().Fields))
for i, field := range analyzed.Plan.Relation().Fields {
cols[i] = field.Name
}

return query(e.engineCtx.TxContext.Ctx, e.db, generatedSQL, scanValues, func() error {
if len(scanValues) != len(cols) {
// should never happen, but just in case
return fmt.Errorf("node bug: scan values and columns are not the same length")
}

return fn(&row{
columns: cols,
Values: scanValues,
})
}, args)
}

// getValues resolves each variable name to its current value via getVariable.
// The returned slice has one entry per name, in the same order as names.
// It stops at the first lookup failure and returns that error with a nil slice.
func (e *executionContext) getValues(names []string) ([]value, error) {
	resolved := make([]value, 0, len(names))
	for _, varName := range names {
		v, err := e.getVariable(varName)
		if err != nil {
			return nil, err
		}
		resolved = append(resolved, v)
	}
	return resolved, nil
}

// prepareQuery prepares a query for execution.
// It looks the statement up in the statement cache by namespace and SQL text.
// If the cache does not hold the form required by the current mode, it parses
// the SQL, creates a logical plan, generates Postgres SQL, and caches the result.
//
// A statement has two forms. The deterministic form is used when the context
// can mutate state; the non-deterministic form is used otherwise. makePlan
// reads e.canMutateState, so a form can only be planned correctly while the
// context is in the matching mode. Each form is therefore built lazily, the
// first time it is requested in its own mode, instead of building both at once
// (which would store two copies of whichever form the current mode produces).
//
// It returns the generated Postgres SQL, the analyzed plan, and the current
// values of the variables referenced by the generated SQL, in parameter order.
// Planner failures are wrapped with engine.ErrQueryPlanner and SQL generation
// failures with engine.ErrPGGen.
//
// NOTE(review): the cache key is (namespace, sql), but planning and generation
// also consult e.getVariableType. Identical SQL text used with differently
// typed variables in the same namespace would share one entry — confirm this
// cannot happen between cache clears.
func (e *executionContext) prepareQuery(sql string) (pgSql string, plan *logical.AnalyzedPlan, args []value, err error) {
	// if it is mutating state it must be deterministic
	deterministic := e.canMutateState

	cached, hit := statementCache.get(e.scope.namespace, sql)
	if hit {
		switch {
		case deterministic && cached.deterministicPlan != nil:
			values, err := e.getValues(cached.deterministicParams)
			if err != nil {
				return "", nil, nil, err
			}
			return cached.deterministicSQL, cached.deterministicPlan, values, nil
		case !deterministic && cached.nonDeterministicPlan != nil:
			values, err := e.getValues(cached.nonDeterministicParams)
			if err != nil {
				return "", nil, nil, err
			}
			return cached.nonDeterministicSQL, cached.nonDeterministicPlan, values, nil
		}
		// The entry exists but the form needed for this mode has not been
		// built yet; fall through and build it.
	}

	// Planning may rewrite the AST, and the SQL is generated from that same
	// AST, so every form is built from its own freshly parsed statement.
	ast, err := getAST(sql)
	if err != nil {
		return "", nil, nil, err
	}

	analyzed, err := makePlan(e, ast)
	if err != nil {
		return "", nil, nil, fmt.Errorf("%w: %w", engine.ErrQueryPlanner, err)
	}

	generated, params, err := pggenerate.GenerateSQL(ast, e.scope.namespace, e.getVariableType)
	if err != nil {
		return "", nil, nil, fmt.Errorf("%w: %w", engine.ErrPGGen, err)
	}

	// Copy-on-write: entries handed out by the cache may be in use by other
	// callers, so never mutate one in place. Start from the existing entry
	// (if any) to keep the other form, then publish the merged copy.
	entry := &preparedStatement{}
	if hit {
		*entry = *cached
	}
	if deterministic {
		entry.deterministicPlan = analyzed
		entry.deterministicSQL = generated
		entry.deterministicParams = params
	} else {
		entry.nonDeterministicPlan = analyzed
		entry.nonDeterministicSQL = generated
		entry.nonDeterministicParams = params
	}
	statementCache.set(e.scope.namespace, sql, entry)

	values, err := e.getValues(params)
	if err != nil {
		return "", nil, nil, err
	}
	return generated, analyzed, values, nil
}

// getAST gets the AST of a SQL statement.
func getAST(sql string) (*parse.SQLStatement, error) {
res, err := parse.Parse(sql)
if err != nil {
return fmt.Errorf("%w: invalid query '%s': %w", engine.ErrParse, sql, err)
return nil, fmt.Errorf("%w: invalid query '%s': %w", engine.ErrParse, sql, err)
}

if len(res) != 1 {
// this is an node bug b/c `query` is only called with a single statement
// from the interpreter
return fmt.Errorf("node bug: expected exactly 1 statement, got %d", len(res))
return nil, fmt.Errorf("node bug: expected exactly 1 statement, got %d", len(res))
}

sqlStmt, ok := res[0].(*parse.SQLStatement)
if !ok {
return fmt.Errorf("node bug: expected *parse.SQLStatement, got %T", res[0])
return nil, fmt.Errorf("node bug: expected *parse.SQLStatement, got %T", res[0])
}

// create a logical plan. This will make the query deterministic (if necessary),
// as well as tell us what the return types will be.
analyzed, err := logical.CreateLogicalPlan(
sqlStmt,
return sqlStmt, nil
}

// makePlan creates a logical plan from a SQL statement.
func makePlan(e *executionContext, ast *parse.SQLStatement) (*logical.AnalyzedPlan, error) {
return logical.CreateLogicalPlan(
ast,
e.getTable,
e.getVariableType,
func(objName string) (obj map[string]*types.DataType, err2 error) {
func(objName string) (obj map[string]*types.DataType, err error) {
val, err := e.getVariable(objName)
if err != nil {
return nil, err
Expand Down Expand Up @@ -207,58 +340,73 @@ func (e *executionContext) query(sql string, fn func(*row) error) error {
e.canMutateState,
e.scope.namespace,
)
if err != nil {
return fmt.Errorf("%w: %w", engine.ErrQueryPlanner, err)
}
}

generatedSQL, params, err := pggenerate.GenerateSQL(sqlStmt, e.scope.namespace, e.getVariableType)
if err != nil {
return fmt.Errorf("%w: %w", engine.ErrPGGen, err)
}
// preparedStatement is a SQL statement that has been parsed and planned
// against a schema (a set of tables with some actions).
// It separates into two forms: deterministic and non-deterministic.
// This is necessary because we use the AST to generate Postgres SQL
// queries, so we actually modify the AST to make it deterministic.
//
// prepareQuery hands out the deterministic form when the execution context
// can mutate state, and the non-deterministic form otherwise.
type preparedStatement struct {
// deterministicPlan is the analyzed logical plan of the deterministic form.
deterministicPlan *logical.AnalyzedPlan
// deterministicSQL is the generated Postgres SQL of the deterministic form.
deterministicSQL string
// the params for deterministic and non-deterministic
// queries _should_ be the same, but I am keeping them separate
// because it might change based on the implementation of the planner
//
// deterministicParams holds the variable names whose values are passed
// as arguments to deterministicSQL, in order.
deterministicParams []string
// nonDeterministicPlan is the analyzed logical plan of the non-deterministic form.
nonDeterministicPlan *logical.AnalyzedPlan
// nonDeterministicSQL is the generated Postgres SQL of the non-deterministic form.
nonDeterministicSQL string
// nonDeterministicParams holds the variable names whose values are passed
// as arguments to nonDeterministicSQL, in order.
nonDeterministicParams []string
}

// get the params we will pass
var args []value
for _, param := range params {
val, err := e.getVariable(param)
if err != nil {
return err
}
// preparedStatements is a cache of parsed and planned statements.
// The package-level statementCache is the instance used by the interpreter.
// The whole cache is emptied by clear, which reloadNamespaceCache calls
// after it reloads a namespace's tables.
type preparedStatements struct {
// mu guards statements: readers take the read lock, writers the write lock.
mu sync.RWMutex
// statements maps a namespace to a map of statements to two parsed forms.
// The inner key is the SQL text of the statement.
statements map[string]map[string]*preparedStatement
}

// get gets a prepared statement from the cache.
func (p *preparedStatements) get(namespace, query string) (*preparedStatement, bool) {
p.mu.RLock()
defer p.mu.RUnlock()

args = append(args, val)
ns, ok := p.statements[namespace]
if !ok {
return nil, false
}

// get the scan values as well:
var scanValues []value
for _, field := range analyzed.Plan.Relation().Fields {
scalar, err := field.Scalar()
if err != nil {
return err
}
stmt, ok := ns[query]
if !ok {
return nil, false
}

zVal, err := newZeroValue(scalar)
if err != nil {
return err
}
return stmt, true
}

scanValues = append(scanValues, zVal)
}
// set sets a prepared statement in the cache.
func (p *preparedStatements) set(namespace, query string, stmt *preparedStatement) {
p.mu.Lock()
defer p.mu.Unlock()

cols := make([]string, len(analyzed.Plan.Relation().Fields))
for i, field := range analyzed.Plan.Relation().Fields {
cols[i] = field.Name
if _, ok := p.statements[namespace]; !ok {
p.statements[namespace] = make(map[string]*preparedStatement)
}

return query(e.engineCtx.TxContext.Ctx, e.db, generatedSQL, scanValues, func() error {
if len(scanValues) != len(cols) {
// should never happen, but just in case
return fmt.Errorf("node bug: scan values and columns are not the same length")
}
p.statements[namespace][query] = stmt
}

return fn(&row{
columns: cols,
Values: scanValues,
})
}, args)
// clear empties the entire cache: the statements of every namespace are
// dropped by swapping in a fresh, empty map under the write lock.
func (p *preparedStatements) clear() {
	empty := make(map[string]map[string]*preparedStatement)

	p.mu.Lock()
	p.statements = empty
	p.mu.Unlock()
}

// statementCache is the package-level cache of prepared statements,
// keyed first by namespace and then by SQL text. It starts out empty.
var statementCache = &preparedStatements{
	statements: map[string]map[string]*preparedStatement{},
}

// executable is the interface and function to call a built-in Postgres function,
Expand Down Expand Up @@ -406,8 +554,8 @@ func (e *executionContext) getVariable(name string) (value, error) {
}
}

// reloadTables reloads the cached tables from the database for the current namespace.
func (e *executionContext) reloadTables() error {
// reloadNamespaceCache reloads the cached tables from the database for the current namespace.
func (e *executionContext) reloadNamespaceCache() error {
tables, err := listTablesInNamespace(e.engineCtx.TxContext.Ctx, e.db, e.scope.namespace)
if err != nil {
return err
Expand All @@ -420,6 +568,8 @@ func (e *executionContext) reloadTables() error {
ns.tables[table.Name] = table
}

statementCache.clear()

return nil
}

Expand Down
45 changes: 45 additions & 0 deletions node/engine/interpreter/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3046,3 +3046,48 @@ func newTestInterp(t *testing.T, tx sql.DB, seeds []string, includeTestTables bo

return interp
}

// Test_NamespaceDropsCache verifies that dropping a namespace invalidates the
// statement cache. The same action (with identical SQL text) is created twice
// in a namespace of the same name: first against an INT column, then, after
// the namespace is dropped and recreated, against a TEXT column. If a stale
// cached statement survived the drop, the second call would try to insert a
// string into an integer column and fail.
func Test_NamespaceDropsCache(t *testing.T) {
	db := newTestDB(t, nil, nil)
	ctx := context.Background()

	tx, err := db.BeginTx(ctx)
	require.NoError(t, err)
	defer tx.Rollback(ctx) // always rollback

	interp := newTestInterp(t, tx, nil, true)

	// mustExec runs a single statement and fails the test on error.
	mustExec := func(stmt string) {
		t.Helper()
		require.NoError(t, interp.ExecuteWithoutEngineCtx(ctx, tx, stmt, nil, nil))
	}
	// mustCall invokes test_ns.smthn with one argument and fails the test on error.
	mustCall := func(arg any) {
		t.Helper()
		_, callErr := interp.CallWithoutEngineCtx(ctx, tx, "test_ns", "smthn", []any{arg}, nil)
		require.NoError(t, callErr)
	}

	// first incarnation: integer primary key
	mustExec(`CREATE NAMESPACE test_ns;`)
	mustExec(`{test_ns}CREATE TABLE test_table (id INT PRIMARY KEY);`)
	mustExec(`{test_ns}CREATE ACTION smthn($a int) public { INSERT INTO test_table (id) VALUES ($a); }`)
	mustCall(1)

	// drop the namespace
	mustExec(`DROP NAMESPACE test_ns;`)

	// second incarnation: same names, but a text primary key
	mustExec(`CREATE NAMESPACE test_ns;`)
	mustExec(`{test_ns}CREATE TABLE test_table (id TEXT PRIMARY KEY);`)
	mustExec(`{test_ns}CREATE ACTION smthn($a text) public { INSERT INTO test_table (id) VALUES ($a); }`)
	mustCall("hello")
}
Loading

0 comments on commit 38f5803

Please sign in to comment.