From 0dd49779e3aa7bbaf337ca9d1c8155edf4eebb40 Mon Sep 17 00:00:00 2001 From: Yaiba <4yaiba@gmail.com> Date: Thu, 7 Mar 2024 17:24:46 -0600 Subject: [PATCH 1/2] SafeToSql util fn --- schema/validator.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/schema/validator.go b/schema/validator.go index 5df56c8..0aa9b35 100644 --- a/schema/validator.go +++ b/schema/validator.go @@ -244,22 +244,22 @@ func (c *ContextValidator) visitActions(actions []Action) error { } for _, statement := range a.Statements { - stmt, err := actparser.Parse(statement) + actStmt, err := actparser.Parse(statement) if err != nil { return fmt.Errorf("%w: %s", err, statement) } - switch s := stmt.(type) { + switch s := actStmt.(type) { case *actparser.DMLStmt: - astTree, err := sqlparser.ParseSql(statement, 1, nil, false) + sqlStmt, err := sqlparser.ParseSql(statement, 1, nil, false, false) if err != nil { return fmt.Errorf("%w: %s", err, statement) } - if _, err := astTree.ToSQL(); err != nil { + if _, err := tree.SafeToSQL(sqlStmt); err != nil { return fmt.Errorf("%w: %s", err, statement) } // TODO: validate reference in SQL statement - //switch t := astTree.(type) { + //switch t := sqlStmt.(type) { //case *tree.Select: case *actparser.ActionCallStmt: if _, ok := c.actionCtx[s.Method]; !ok { From e0e0101eca92819e2e503ccaa02d4633c7ca1945 Mon Sep 17 00:00:00 2001 From: Yaiba <4yaiba@gmail.com> Date: Mon, 25 Mar 2024 11:27:22 -0500 Subject: [PATCH 2/2] update gomod --- go.mod | 3 +-- go.sum | 6 ++---- schema/validator.go | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 3cb5766..e62b79e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.21 require ( github.com/antlr4-go/antlr/v4 v4.13.0 - github.com/kwilteam/kwil-db/parse v0.1.2 + github.com/kwilteam/kwil-db/parse v0.1.2-0.20240325162245-a841d745a8a2 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 ) @@ -14,7 +14,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/kwilteam/action-grammar-go v0.1.1 // indirect - github.com/kwilteam/sql-grammar-go v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/go.sum b/go.sum index ab053ec..af743fd 100644 --- a/go.sum +++ b/go.sum @@ -12,10 +12,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kwilteam/action-grammar-go v0.1.1 h1:0NeWrIN0B+pQMyiTwW/kWtqLWl7P4ExmfHwaXaJ14zI= github.com/kwilteam/action-grammar-go v0.1.1/go.mod h1:hHGHtnrJpASW9P+F7pdr/EP2M1Hxy1N9Wx/TmjVdV6I= -github.com/kwilteam/kwil-db/parse v0.1.2 h1:RE8vzX+hZlDCz329hpoP7Az/v2BxfWZUQQQMkEVRWng= -github.com/kwilteam/kwil-db/parse v0.1.2/go.mod h1:lGzcCdSvjVrJj71nPuDQRYfF9Jnqj94wFPHEbbDq0+Y= -github.com/kwilteam/sql-grammar-go v0.1.0 h1:rSS7DER9PWVDmFwNyoInG5oXrn+E9UrZkjref84L4Qk= -github.com/kwilteam/sql-grammar-go v0.1.0/go.mod h1:A9AXaH5Vl/uPsY88fWqvU9O7z7P4YfvndaGyc8s//2s= +github.com/kwilteam/kwil-db/parse v0.1.2-0.20240325162245-a841d745a8a2 h1:EOY9P5bnQJoUzk7HqxiPkYxDM12pynq2Asy4PBpkZZQ= +github.com/kwilteam/kwil-db/parse v0.1.2-0.20240325162245-a841d745a8a2/go.mod h1:ZgelAtf4gWAOecXOrESXhFqRei3KvhJ4H7Ent28qAQs= github.com/pganalyze/pg_query_go/v5 v5.1.0 h1:MlxQqHZnvA3cbRQYyIrjxEjzo560P6MyTgtlaf3pmXg= github.com/pganalyze/pg_query_go/v5 v5.1.0/go.mod h1:FsglvxidZsVN+Ltw3Ai6nTgPVcK2BPukH3jCDEqc1Ug= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/schema/validator.go b/schema/validator.go index 0aa9b35..96b3e90 100644 --- a/schema/validator.go +++ b/schema/validator.go @@ -251,7 +251,7 @@ func (c *ContextValidator) visitActions(actions []Action) error { switch s := actStmt.(type) { case *actparser.DMLStmt: - sqlStmt, err := sqlparser.ParseSql(statement, 1, nil, false, false) + sqlStmt, err := sqlparser.Parse(statement) if err != nil { return fmt.Errorf("%w: %s", err, statement) }