Skip to content

Commit

Permalink
Parse: tree update for visitor (#602)
Browse files Browse the repository at this point in the history
* parse: ast tree visitor, and walker

* parse: tree.Relation insteadof TableOrSubquery+JoinClause

* parse: change statement structure

* parse: move sql-grammar-go in kwil-db

* update kuneiform version

* made select node names consistent

* renamed selects

---------

Co-authored-by: Brennan Lamey <66885902+brennanjl@users.noreply.github.com>
  • Loading branch information
Yaiba and brennanjl authored Mar 25, 2024
1 parent 4f6acde commit 4a4632e
Show file tree
Hide file tree
Showing 86 changed files with 18,741 additions and 3,177 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "proto"]
path = proto
url = https://github.com/kwilteam/proto
[submodule "parse/sql/antlr-grammar"]
path = parse/sql/antlr-grammar
url = https://github.com/kwilteam/sql-grammar
13 changes: 13 additions & 0 deletions Taskfile-antlr.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
version: '3'

tasks:
sql:
desc: Generate sql grammar go code.
cmds:
- rm -rf parse/sql/grammar/*
- rm -rf parse/sql/antlr-grammar/{gen,.antlr}/*
- cd parse/sql/antlr-grammar/ && ./generate.sh Go grammar ../grammar
sources:
- parse/sql/antlr-grammar/*.g4
generates:
- parse/sql/grammar/*.{go,interp,tokens}
1 change: 1 addition & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ version: "3"

includes:
pb: ./Taskfile-pb.yml
antlr: ./Taskfile-antlr.yml

tasks:
default:
Expand Down
5 changes: 2 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ require (
github.com/jackc/pglogrepl v0.0.0-20231111135425-1627ab1b5780
github.com/jackc/pgx/v5 v5.5.2
github.com/jpillora/backoff v1.0.0
github.com/kwilteam/kuneiform v0.6.0
github.com/kwilteam/kuneiform v0.6.1-0.20240325162722-e0e0101eca92
github.com/kwilteam/kwil-db/core v0.1.0
github.com/kwilteam/kwil-db/parse v0.1.1
github.com/kwilteam/kwil-db/parse v0.1.2
github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7
github.com/manifoldco/promptui v0.9.0
github.com/mitchellh/mapstructure v1.5.0
Expand Down Expand Up @@ -93,7 +93,6 @@ require (
github.com/jmhodges/levigo v1.0.0 // indirect
github.com/klauspost/compress v1.17.0 // indirect
github.com/kwilteam/action-grammar-go v0.1.1 // indirect
github.com/kwilteam/sql-grammar-go v0.1.0 // indirect
github.com/lib/pq v1.10.7 // indirect
github.com/libp2p/go-buffer-pool v0.1.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,10 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kwilteam/action-grammar-go v0.1.1 h1:0NeWrIN0B+pQMyiTwW/kWtqLWl7P4ExmfHwaXaJ14zI=
github.com/kwilteam/action-grammar-go v0.1.1/go.mod h1:hHGHtnrJpASW9P+F7pdr/EP2M1Hxy1N9Wx/TmjVdV6I=
github.com/kwilteam/kuneiform v0.6.0 h1:Y8VWrJN1cl9idqX+LBSQd+c3m/JjDDRInBSKq3i27NY=
github.com/kwilteam/kuneiform v0.6.0/go.mod h1:b3Ce6falEDBQ0xgLpa/hjFjUQoD8aFEg96yewS/3wzg=
github.com/kwilteam/kuneiform v0.6.1-0.20240325162722-e0e0101eca92 h1:RoFJdrFt0zI6Y1t23PUTMYgykhCPlEOzxQ4nEd+ymFI=
github.com/kwilteam/kuneiform v0.6.1-0.20240325162722-e0e0101eca92/go.mod h1:+9V+E5I5sEL643ZHaeiF6bVCJqp56lGti+wXZdJ/YYQ=
github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7 h1:YiPBu0pOeYOtOVfwKQqdWB07SUef9LvngF4bVFD+x34=
github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7/go.mod h1:+BrFrV+3qcdYIfptqjwatE5gT19azuRHJzw77wMPY8c=
github.com/kwilteam/sql-grammar-go v0.1.0 h1:rSS7DER9PWVDmFwNyoInG5oXrn+E9UrZkjref84L4Qk=
github.com/kwilteam/sql-grammar-go v0.1.0/go.mod h1:A9AXaH5Vl/uPsY88fWqvU9O7z7P4YfvndaGyc8s//2s=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c=
Expand Down
16 changes: 8 additions & 8 deletions internal/engine/execution/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,16 +453,16 @@ func makeExecutables(exprs []tree.Expression) ([]evaluatable, error) {
}

// clean expression, since it is submitted by the user
err := expr.Accept(clean.NewStatementCleaner())
err := expr.Walk(clean.NewStatementCleaner())
if err != nil {
return nil, err
}

// The schema walker is not necessary for inline expressions, since
// we do not support table references in inline expressions.
accept := sqlanalyzer.NewAcceptRecoverer(expr)
paramVisitor := parameters.NewParametersVisitor()
err = accept.Accept(paramVisitor)
walker := sqlanalyzer.NewWalkerRecoverer(expr)
paramVisitor := parameters.NewParametersWalker()
err = walker.Walk(paramVisitor)
if err != nil {
return nil, fmt.Errorf("error replacing parameters: %w", err)
}
Expand All @@ -471,9 +471,9 @@ func makeExecutables(exprs []tree.Expression) ([]evaluatable, error) {
// statements This query needs to be run in "simple" execution mode
// rather than "extended" execution mode, which asks the database for
// OID (placeholder types) that it can't know since there's no FOR table.
selectTree := &tree.Select{
SelectStmt: &tree.SelectStmt{
SelectCores: []*tree.SelectCore{
selectTree := &tree.SelectStmt{
Stmt: &tree.SelectCore{
SelectCores: []*tree.SimpleSelect{
{
SelectType: tree.SelectTypeAll,
Columns: []tree.ResultColumn{
Expand All @@ -486,7 +486,7 @@ func makeExecutables(exprs []tree.Expression) ([]evaluatable, error) {
},
}

stmt, err := selectTree.ToSQL()
stmt, err := tree.SafeToSQL(selectTree)
if err != nil {
return nil, err
}
Expand Down
32 changes: 16 additions & 16 deletions internal/engine/sqlanalyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,24 @@ import (
"github.com/kwilteam/kwil-db/parse/sql/tree"
)

// AcceptRecoverer is a wrapper around a statement that implements the accepter interface
// it catches panics and returns them as errors
type AcceptRecoverer struct {
tree.Accepter
// WalkerRecoverer is a wrapper around a statement that implements the AstWalker
// interface, it catches panics and returns them as errors
type WalkerRecoverer struct {
inner tree.AstWalker
}

func NewAcceptRecoverer(a tree.Accepter) *AcceptRecoverer {
return &AcceptRecoverer{a}
func NewWalkerRecoverer(a tree.AstWalker) *WalkerRecoverer {
return &WalkerRecoverer{a}
}

func (a *AcceptRecoverer) Accept(walker tree.Walker) (err error) {
func (a *WalkerRecoverer) Walk(walker tree.AstListener) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic while walking statement: %v", r)
}
}()

return a.Accepter.Accept(walker)
return a.inner.Walk(walker)
}

// ApplyRules analyzes the given statement and returns the transformed statement.
Expand All @@ -48,38 +48,38 @@ func ApplyRules(stmt string, flags VerifyFlag, tables []*common.Table, pgSchemaN
return nil, fmt.Errorf("error parsing statement: %w", err)
}

accept := &AcceptRecoverer{parsed}
walker := &WalkerRecoverer{parsed}

clnr := clean.NewStatementCleaner()
err = accept.Accept(clnr)
err = walker.Walk(clnr)
if err != nil {
return nil, fmt.Errorf("error cleaning statement: %w", err)
}

schemaWalker := schema.NewSchemaWalker(pgSchemaName)
err = accept.Accept(schemaWalker)
err = walker.Walk(schemaWalker)
if err != nil {
return nil, fmt.Errorf("error applying schema rules: %w", err)
}

if flags&NoCartesianProduct != 0 {
err := accept.Accept(join.NewJoinWalker())
err := walker.Walk(join.NewJoinWalker())
if err != nil {
return nil, fmt.Errorf("error applying join rules: %w", err)
}
}

if flags&GuaranteedOrder != 0 {
err := accept.Accept(order.NewOrderWalker(cleanedTables))
err := walker.Walk(order.NewOrderWalker(cleanedTables))
if err != nil {
return nil, fmt.Errorf("error enforcing guaranteed order: %w", err)
}
}

orderedParams := make([]string, 0)
if flags&ReplaceNamedParameters != 0 {
paramVisitor := parameters.NewParametersVisitor()
err := accept.Accept(paramVisitor)
paramVisitor := parameters.NewParametersWalker()
err := walker.Walk(paramVisitor)
if err != nil {
return nil, fmt.Errorf("error replacing named parameters: %w", err)
}
Expand All @@ -91,7 +91,7 @@ func ApplyRules(stmt string, flags VerifyFlag, tables []*common.Table, pgSchemaN
return nil, fmt.Errorf("error determining mutativity: %w", err)
}

generated, err := parsed.ToSQL()
generated, err := tree.SafeToSQL(parsed)
if err != nil {
return nil, fmt.Errorf("error generating SQL: %w", err)
}
Expand Down
20 changes: 10 additions & 10 deletions internal/engine/sqlanalyzer/attributes/select_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ type RelationAttribute struct {
// tbl1.col, col, col AS alias, col*5 AS alias, etc.
// If a statement has "SELECT * FROM tbl",
// then the result column expressions will be tbl.col_1, tbl.col_2, etc.
func GetSelectCoreRelationAttributes(selectCore *tree.SelectCore, tables []*common.Table) ([]*RelationAttribute, error) {
func GetSelectCoreRelationAttributes(selectCore *tree.SimpleSelect, tables []*common.Table) ([]*RelationAttribute, error) {
walker := newSelectCoreWalker(tables)
err := selectCore.Accept(walker)
err := selectCore.Walk(walker)
if err != nil {
return nil, fmt.Errorf("error analyzing select core: %w", err)
}
Expand All @@ -55,7 +55,7 @@ func GetSelectCoreRelationAttributes(selectCore *tree.SelectCore, tables []*comm

func newSelectCoreWalker(tables []*common.Table) *selectCoreAnalyzer {
return &selectCoreAnalyzer{
Walker: tree.NewBaseWalker(),
AstListener: tree.NewBaseListener(),
context: newSelectCoreContext(nil),
schemaTables: tables,
detectedAttributes: []*RelationAttribute{},
Expand All @@ -64,7 +64,7 @@ func newSelectCoreWalker(tables []*common.Table) *selectCoreAnalyzer {

// selectCoreAnalyzer will walk the tree and identify the returned attributes for the select core
type selectCoreAnalyzer struct {
tree.Walker
tree.AstListener
context *selectCoreContext
schemaTables []*common.Table

Expand Down Expand Up @@ -197,14 +197,14 @@ func newSelectCoreContext(parent *selectCoreContext) *selectCoreContext {
}

// EnterSelectCore creates a new scope.
func (s *selectCoreAnalyzer) EnterSelectCore(node *tree.SelectCore) error {
func (s *selectCoreAnalyzer) EnterSelectCore(node *tree.SimpleSelect) error {
s.newScope()

return nil
}

// ExitSelectCore pops the current scope.
func (s *selectCoreAnalyzer) ExitSelectCore(node *tree.SelectCore) error {
func (s *selectCoreAnalyzer) ExitSelectCore(node *tree.SimpleSelect) error {
var err error
s.detectedAttributes, err = s.context.relations()
if err != nil {
Expand All @@ -215,8 +215,8 @@ func (s *selectCoreAnalyzer) ExitSelectCore(node *tree.SelectCore) error {
return nil
}

// EnterTableOrSubqueryTable adds the table to the list of used tables.
func (s *selectCoreAnalyzer) EnterTableOrSubqueryTable(node *tree.TableOrSubqueryTable) error {
// EnterRelationTable adds the table to the list of used tables.
func (s *selectCoreAnalyzer) EnterRelationTable(node *tree.RelationTable) error {
tbl, err := findTable(s.schemaTables, node.Name)
if err != nil {
return err
Expand Down Expand Up @@ -278,8 +278,8 @@ func findColumn(columns []*common.Column, name string) (*common.Column, error) {
}

// addTableIfNotPresent adds the table name to the column if it is not already present.
func addTableIfNotPresent(tableName string, expr tree.Accepter) error {
return expr.Accept(&tree.ImplementedWalker{
func addTableIfNotPresent(tableName string, expr tree.AstWalker) error {
return expr.Walk(&tree.ImplementedListener{
FuncEnterExpressionColumn: func(col *tree.ExpressionColumn) error {
if col.Table == "" {
col.Table = tableName
Expand Down
8 changes: 4 additions & 4 deletions internal/engine/sqlanalyzer/attributes/select_core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,18 @@ func TestGetSelectCoreRelationAttributes(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ast, err := sqlparser.Parse(tt.stmt)
stmt, err := sqlparser.Parse(tt.stmt)
if err != nil {
t.Errorf("GetSelectCoreRelationAttributes() error = %v", err)
return
}
selectStmt, okj := ast.(*tree.Select)
selectStmt, okj := stmt.(*tree.SelectStmt)
if !okj {
t.Errorf("test case %s is not a select statement", tt.name)
return
}

got, err := attributes.GetSelectCoreRelationAttributes(selectStmt.SelectStmt.SelectCores[0], tt.tables)
got, err := attributes.GetSelectCoreRelationAttributes(selectStmt.Stmt.SelectCores[0], tt.tables)
if (err != nil) != tt.wantErr {
t.Errorf("GetSelectCoreRelationAttributes() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -211,7 +211,7 @@ func TestGetSelectCoreRelationAttributes(t *testing.T) {

assert.ElementsMatch(t, tt.resultTableCols, genTable.Columns, "GetSelectCoreRelationAttributes() got = %v, want %v", got, tt.want)

sql, err := selectStmt.ToSQL()
sql, err := tree.SafeToSQL(selectStmt)
assert.NoErrorf(t, err, "error converting query to SQL: %s", err)

err = postgres.CheckSyntaxReplaceDollar(sql)
Expand Down
10 changes: 5 additions & 5 deletions internal/engine/sqlanalyzer/attributes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (
// If it is invalid, it will return an error.
func predictReturnType(expr tree.Expression, tables []*common.Table) (common.DataType, error) {
w := &returnTypeWalker{
Walker: tree.NewBaseWalker(),
tables: tables,
AstListener: tree.NewBaseListener(),
tables: tables,
}

err := expr.Accept(w)
err := expr.Walk(w)
if err != nil {
return common.TEXT, fmt.Errorf("error predicting return type: %w", err)
}
Expand All @@ -38,13 +38,13 @@ func errReturnExpr(expr tree.Expression) error {
}

type returnTypeWalker struct {
tree.Walker
tree.AstListener
detected bool
detectedType common.DataType
tables []*common.Table
}

var _ tree.Walker = &returnTypeWalker{}
var _ tree.AstListener = &returnTypeWalker{}

func (r *returnTypeWalker) EnterExpressionArithmetic(p0 *tree.ExpressionArithmetic) error {
r.set(common.INT)
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/sqlanalyzer/clean/clean.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The walker in this package implements all the tree.Walker methods, even if it
doesn't do anything. This is to ensure that if we need to add more cleaning / validation
rules, we know that we've covered all the nodes.
For example, EnterDelete does nothing, but if we later set a limit on the amount of
For example, EnterDeleteStmt does nothing, but if we later set a limit on the amount of
CTEs allowed, then we would add it there.
*/
package clean
Expand Down
Loading

0 comments on commit 4a4632e

Please sign in to comment.