Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse: tree update for visitor #602

Merged
merged 7 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "proto"]
path = proto
url = https://github.com/kwilteam/proto
[submodule "parse/sql/antlr-grammar"]
path = parse/sql/antlr-grammar
url = https://github.com/kwilteam/sql-grammar
13 changes: 13 additions & 0 deletions Taskfile-antlr.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
version: '3'

tasks:
sql:
desc: Generate sql grammar go code.
cmds:
- rm -rf parse/sql/grammar/*
- rm -rf parse/sql/antlr-grammar/{gen,.antlr}/*
- cd parse/sql/antlr-grammar/ && ./generate.sh Go grammar ../grammar
sources:
- parse/sql/antlr-grammar/*.g4
generates:
- parse/sql/grammar/*.{go,interp,tokens}
1 change: 1 addition & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ version: "3"

includes:
pb: ./Taskfile-pb.yml
antlr: ./Taskfile-antlr.yml

tasks:
default:
Expand Down
5 changes: 2 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ require (
github.com/jackc/pglogrepl v0.0.0-20231111135425-1627ab1b5780
github.com/jackc/pgx/v5 v5.5.2
github.com/jpillora/backoff v1.0.0
github.com/kwilteam/kuneiform v0.6.0
github.com/kwilteam/kuneiform v0.6.1-0.20240325162722-e0e0101eca92
github.com/kwilteam/kwil-db/core v0.1.0
github.com/kwilteam/kwil-db/parse v0.1.1
github.com/kwilteam/kwil-db/parse v0.1.2
github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7
github.com/manifoldco/promptui v0.9.0
github.com/mitchellh/mapstructure v1.5.0
Expand Down Expand Up @@ -93,7 +93,6 @@ require (
github.com/jmhodges/levigo v1.0.0 // indirect
github.com/klauspost/compress v1.17.0 // indirect
github.com/kwilteam/action-grammar-go v0.1.1 // indirect
github.com/kwilteam/sql-grammar-go v0.1.0 // indirect
github.com/lib/pq v1.10.7 // indirect
github.com/libp2p/go-buffer-pool v0.1.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,10 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kwilteam/action-grammar-go v0.1.1 h1:0NeWrIN0B+pQMyiTwW/kWtqLWl7P4ExmfHwaXaJ14zI=
github.com/kwilteam/action-grammar-go v0.1.1/go.mod h1:hHGHtnrJpASW9P+F7pdr/EP2M1Hxy1N9Wx/TmjVdV6I=
github.com/kwilteam/kuneiform v0.6.0 h1:Y8VWrJN1cl9idqX+LBSQd+c3m/JjDDRInBSKq3i27NY=
github.com/kwilteam/kuneiform v0.6.0/go.mod h1:b3Ce6falEDBQ0xgLpa/hjFjUQoD8aFEg96yewS/3wzg=
github.com/kwilteam/kuneiform v0.6.1-0.20240325162722-e0e0101eca92 h1:RoFJdrFt0zI6Y1t23PUTMYgykhCPlEOzxQ4nEd+ymFI=
github.com/kwilteam/kuneiform v0.6.1-0.20240325162722-e0e0101eca92/go.mod h1:+9V+E5I5sEL643ZHaeiF6bVCJqp56lGti+wXZdJ/YYQ=
github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7 h1:YiPBu0pOeYOtOVfwKQqdWB07SUef9LvngF4bVFD+x34=
github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7/go.mod h1:+BrFrV+3qcdYIfptqjwatE5gT19azuRHJzw77wMPY8c=
github.com/kwilteam/sql-grammar-go v0.1.0 h1:rSS7DER9PWVDmFwNyoInG5oXrn+E9UrZkjref84L4Qk=
github.com/kwilteam/sql-grammar-go v0.1.0/go.mod h1:A9AXaH5Vl/uPsY88fWqvU9O7z7P4YfvndaGyc8s//2s=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c=
Expand Down
16 changes: 8 additions & 8 deletions internal/engine/execution/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,16 +453,16 @@ func makeExecutables(exprs []tree.Expression) ([]evaluatable, error) {
}

// clean expression, since it is submitted by the user
err := expr.Accept(clean.NewStatementCleaner())
err := expr.Walk(clean.NewStatementCleaner())
if err != nil {
return nil, err
}

// The schema walker is not necessary for inline expressions, since
// we do not support table references in inline expressions.
accept := sqlanalyzer.NewAcceptRecoverer(expr)
paramVisitor := parameters.NewParametersVisitor()
err = accept.Accept(paramVisitor)
walker := sqlanalyzer.NewWalkerRecoverer(expr)
paramVisitor := parameters.NewParametersWalker()
err = walker.Walk(paramVisitor)
if err != nil {
return nil, fmt.Errorf("error replacing parameters: %w", err)
}
Expand All @@ -471,9 +471,9 @@ func makeExecutables(exprs []tree.Expression) ([]evaluatable, error) {
// statements This query needs to be run in "simple" execution mode
// rather than "extended" execution mode, which asks the database for
// OID (placeholder types) that it can't know since there's no FOR table.
selectTree := &tree.Select{
SelectStmt: &tree.SelectStmt{
SelectCores: []*tree.SelectCore{
selectTree := &tree.SelectStmt{
Stmt: &tree.SelectCore{
SelectCores: []*tree.SimpleSelect{
{
SelectType: tree.SelectTypeAll,
Columns: []tree.ResultColumn{
Expand All @@ -486,7 +486,7 @@ func makeExecutables(exprs []tree.Expression) ([]evaluatable, error) {
},
}

stmt, err := selectTree.ToSQL()
stmt, err := tree.SafeToSQL(selectTree)
if err != nil {
return nil, err
}
Expand Down
32 changes: 16 additions & 16 deletions internal/engine/sqlanalyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,24 @@ import (
"github.com/kwilteam/kwil-db/parse/sql/tree"
)

// AcceptRecoverer is a wrapper around a statement that implements the accepter interface
// it catches panics and returns them as errors
type AcceptRecoverer struct {
tree.Accepter
// WalkerRecoverer is a wrapper around a statement that implements the AstWalker
// interface, it catches panics and returns them as errors
type WalkerRecoverer struct {
inner tree.AstWalker
}

func NewAcceptRecoverer(a tree.Accepter) *AcceptRecoverer {
return &AcceptRecoverer{a}
func NewWalkerRecoverer(a tree.AstWalker) *WalkerRecoverer {
return &WalkerRecoverer{a}
}

func (a *AcceptRecoverer) Accept(walker tree.Walker) (err error) {
func (a *WalkerRecoverer) Walk(walker tree.AstListener) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic while walking statement: %v", r)
}
}()

return a.Accepter.Accept(walker)
return a.inner.Walk(walker)
}

// ApplyRules analyzes the given statement and returns the transformed statement.
Expand All @@ -48,38 +48,38 @@ func ApplyRules(stmt string, flags VerifyFlag, tables []*common.Table, pgSchemaN
return nil, fmt.Errorf("error parsing statement: %w", err)
}

accept := &AcceptRecoverer{parsed}
walker := &WalkerRecoverer{parsed}

clnr := clean.NewStatementCleaner()
err = accept.Accept(clnr)
err = walker.Walk(clnr)
if err != nil {
return nil, fmt.Errorf("error cleaning statement: %w", err)
}

schemaWalker := schema.NewSchemaWalker(pgSchemaName)
err = accept.Accept(schemaWalker)
err = walker.Walk(schemaWalker)
if err != nil {
return nil, fmt.Errorf("error applying schema rules: %w", err)
}

if flags&NoCartesianProduct != 0 {
err := accept.Accept(join.NewJoinWalker())
err := walker.Walk(join.NewJoinWalker())
if err != nil {
return nil, fmt.Errorf("error applying join rules: %w", err)
}
}

if flags&GuaranteedOrder != 0 {
err := accept.Accept(order.NewOrderWalker(cleanedTables))
err := walker.Walk(order.NewOrderWalker(cleanedTables))
if err != nil {
return nil, fmt.Errorf("error enforcing guaranteed order: %w", err)
}
}

orderedParams := make([]string, 0)
if flags&ReplaceNamedParameters != 0 {
paramVisitor := parameters.NewParametersVisitor()
err := accept.Accept(paramVisitor)
paramVisitor := parameters.NewParametersWalker()
err := walker.Walk(paramVisitor)
if err != nil {
return nil, fmt.Errorf("error replacing named parameters: %w", err)
}
Expand All @@ -91,7 +91,7 @@ func ApplyRules(stmt string, flags VerifyFlag, tables []*common.Table, pgSchemaN
return nil, fmt.Errorf("error determining mutativity: %w", err)
}

generated, err := parsed.ToSQL()
generated, err := tree.SafeToSQL(parsed)
if err != nil {
return nil, fmt.Errorf("error generating SQL: %w", err)
}
Expand Down
20 changes: 10 additions & 10 deletions internal/engine/sqlanalyzer/attributes/select_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ type RelationAttribute struct {
// tbl1.col, col, col AS alias, col*5 AS alias, etc.
// If a statement has "SELECT * FROM tbl",
// then the result column expressions will be tbl.col_1, tbl.col_2, etc.
func GetSelectCoreRelationAttributes(selectCore *tree.SelectCore, tables []*common.Table) ([]*RelationAttribute, error) {
func GetSelectCoreRelationAttributes(selectCore *tree.SimpleSelect, tables []*common.Table) ([]*RelationAttribute, error) {
walker := newSelectCoreWalker(tables)
err := selectCore.Accept(walker)
err := selectCore.Walk(walker)
if err != nil {
return nil, fmt.Errorf("error analyzing select core: %w", err)
}
Expand All @@ -55,7 +55,7 @@ func GetSelectCoreRelationAttributes(selectCore *tree.SelectCore, tables []*comm

func newSelectCoreWalker(tables []*common.Table) *selectCoreAnalyzer {
return &selectCoreAnalyzer{
Walker: tree.NewBaseWalker(),
AstListener: tree.NewBaseListener(),
context: newSelectCoreContext(nil),
schemaTables: tables,
detectedAttributes: []*RelationAttribute{},
Expand All @@ -64,7 +64,7 @@ func newSelectCoreWalker(tables []*common.Table) *selectCoreAnalyzer {

// selectCoreAnalyzer will walk the tree and identify the returned attributes for the select core
type selectCoreAnalyzer struct {
tree.Walker
tree.AstListener
context *selectCoreContext
schemaTables []*common.Table

Expand Down Expand Up @@ -197,14 +197,14 @@ func newSelectCoreContext(parent *selectCoreContext) *selectCoreContext {
}

// EnterSelectCore creates a new scope.
func (s *selectCoreAnalyzer) EnterSelectCore(node *tree.SelectCore) error {
func (s *selectCoreAnalyzer) EnterSelectCore(node *tree.SimpleSelect) error {
s.newScope()

return nil
}

// ExitSelectCore pops the current scope.
func (s *selectCoreAnalyzer) ExitSelectCore(node *tree.SelectCore) error {
func (s *selectCoreAnalyzer) ExitSelectCore(node *tree.SimpleSelect) error {
var err error
s.detectedAttributes, err = s.context.relations()
if err != nil {
Expand All @@ -215,8 +215,8 @@ func (s *selectCoreAnalyzer) ExitSelectCore(node *tree.SelectCore) error {
return nil
}

// EnterTableOrSubqueryTable adds the table to the list of used tables.
func (s *selectCoreAnalyzer) EnterTableOrSubqueryTable(node *tree.TableOrSubqueryTable) error {
// EnterRelationTable adds the table to the list of used tables.
func (s *selectCoreAnalyzer) EnterRelationTable(node *tree.RelationTable) error {
tbl, err := findTable(s.schemaTables, node.Name)
if err != nil {
return err
Expand Down Expand Up @@ -278,8 +278,8 @@ func findColumn(columns []*common.Column, name string) (*common.Column, error) {
}

// addTableIfNotPresent adds the table name to the column if it is not already present.
func addTableIfNotPresent(tableName string, expr tree.Accepter) error {
return expr.Accept(&tree.ImplementedWalker{
func addTableIfNotPresent(tableName string, expr tree.AstWalker) error {
return expr.Walk(&tree.ImplementedListener{
FuncEnterExpressionColumn: func(col *tree.ExpressionColumn) error {
if col.Table == "" {
col.Table = tableName
Expand Down
8 changes: 4 additions & 4 deletions internal/engine/sqlanalyzer/attributes/select_core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,18 @@ func TestGetSelectCoreRelationAttributes(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ast, err := sqlparser.Parse(tt.stmt)
stmt, err := sqlparser.Parse(tt.stmt)
if err != nil {
t.Errorf("GetSelectCoreRelationAttributes() error = %v", err)
return
}
selectStmt, okj := ast.(*tree.Select)
selectStmt, okj := stmt.(*tree.SelectStmt)
if !okj {
t.Errorf("test case %s is not a select statement", tt.name)
return
}

got, err := attributes.GetSelectCoreRelationAttributes(selectStmt.SelectStmt.SelectCores[0], tt.tables)
got, err := attributes.GetSelectCoreRelationAttributes(selectStmt.Stmt.SelectCores[0], tt.tables)
if (err != nil) != tt.wantErr {
t.Errorf("GetSelectCoreRelationAttributes() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -211,7 +211,7 @@ func TestGetSelectCoreRelationAttributes(t *testing.T) {

assert.ElementsMatch(t, tt.resultTableCols, genTable.Columns, "GetSelectCoreRelationAttributes() got = %v, want %v", got, tt.want)

sql, err := selectStmt.ToSQL()
sql, err := tree.SafeToSQL(selectStmt)
assert.NoErrorf(t, err, "error converting query to SQL: %s", err)

err = postgres.CheckSyntaxReplaceDollar(sql)
Expand Down
10 changes: 5 additions & 5 deletions internal/engine/sqlanalyzer/attributes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (
// If it is invalid, it will return an error.
func predictReturnType(expr tree.Expression, tables []*common.Table) (common.DataType, error) {
w := &returnTypeWalker{
Walker: tree.NewBaseWalker(),
tables: tables,
AstListener: tree.NewBaseListener(),
tables: tables,
}

err := expr.Accept(w)
err := expr.Walk(w)
if err != nil {
return common.TEXT, fmt.Errorf("error predicting return type: %w", err)
}
Expand All @@ -38,13 +38,13 @@ func errReturnExpr(expr tree.Expression) error {
}

type returnTypeWalker struct {
tree.Walker
tree.AstListener
detected bool
detectedType common.DataType
tables []*common.Table
}

var _ tree.Walker = &returnTypeWalker{}
var _ tree.AstListener = &returnTypeWalker{}

func (r *returnTypeWalker) EnterExpressionArithmetic(p0 *tree.ExpressionArithmetic) error {
r.set(common.INT)
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/sqlanalyzer/clean/clean.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The walker in this package implements all the tree.Walker methods, even if it
doesn't do anything. This is to ensure that if we need to add more cleaning / validation
rules, we know that we've covered all the nodes.

For example, EnterDelete does nothing, but if we later set a limit on the amount of
For example, EnterDeleteStmt does nothing, but if we later set a limit on the amount of
CTEs allowed, then we would add it there.
*/
package clean
Expand Down
Loading
Loading