diff --git a/internal/engine/generate/foreign_procedure.go b/internal/engine/generate/foreign_procedure.go index 24fc0fe20..1d3e1464c 100644 --- a/internal/engine/generate/foreign_procedure.go +++ b/internal/engine/generate/foreign_procedure.go @@ -7,90 +7,6 @@ import ( "github.com/kwilteam/kwil-db/core/types" ) -/* -Generated foreign procedure calls are responsible for making calls to procedures in other schemas. -The dbid and procedure can be passed in dynamically. It should follow the following general structure, with variations -for inputs/outputs: - -create or replace function _fp_template(_dbid TEXT, _procedure TEXT, _arg1 TEXT, _arg2 INT8, OUT _out_1 TEXT) as $$ -DECLARE - _schema_owner BYTEA; - _is_view BOOLEAN; - _is_owner_only BOOLEAN; - _is_public BOOLEAN; - _returns_table BOOLEAN; - _expected_input_types TEXT[]; - _expected_return_types TEXT[]; -BEGIN - - SELECT p.param_types, p.return_types, p.is_view, p.owner_only, p.public, s.owner, p.returns_table - INTO _expected_input_types, _expected_return_types, _is_view, _is_owner_only, _is_public, _schema_owner, _returns_table - FROM kwild_internal.procedures as p INNER JOIN kwild_internal.kwil_schemas as s - ON s.schema_id = p.id - WHERE p.name = _procedure AND s.dbid = _dbid; - - -- Schema owner cannot be nil, and will only be nil - -- if the procedure is not found - IF _schema_owner IS NULL THEN - RAISE EXCEPTION 'Procedure "%" not found in schema "%"', _procedure, _dbid; - END IF; - - -- we now ensure that: - -- 1. if we are in a read-only connection, that the procedure is view - -- 2. if it is owner only, the signer is the owner - -- 3. it is public - -- 4. our input types match the expected input types - -- 5. our return types match the expected return types - - -- 1. 
if we are in a read-only connection, that the procedure is view - IF _is_view = FALSE AND current_setting('is_read_only') = 'on' THEN - RAISE EXCEPTION 'Non-view procedure "%" called in view-only connection', _procedure; - END IF; - - -- 2. if it is owner only, the signer is the owner - IF _is_owner_only = TRUE AND _schema_owner != current_setting('ctx.signer')::BYTEA THEN - RAISE EXCEPTION 'Procedure "%" is owner-only and cannot be called by user "%" in schema "%"', _procedure, current_setting('ctx.signer'), _dbid; - END IF; - - -- 3. it is public - IF _is_public = FALSE THEN - RAISE EXCEPTION 'Procedure "%" is not public and cannot be foreign called', _procedure; - END IF; - - -- 4. our input types match the expected input types - IF array_length(_expected_input_types, 1) != 1 THEN - RAISE EXCEPTION 'Procedure "%" expects exactly one input type, but got %', _procedure, array_length(_expected_input_types, 1); - END IF; - - -- since _arg1 is text, we want to ensure that the input type is text - IF _expected_input_types[1] != 'TEXT' THEN - RAISE EXCEPTION 'Procedure "%" expects input type "TEXT", but got "%"', _procedure, _expected_input_types[1]; - END IF; - - IF _expected_input_types[2] != 'INT' THEN - RAISE EXCEPTION 'Procedure "%" expects input type "INT", but got "%"', _procedure, _expected_input_types[2]; - END IF; - - -- 5. our return types match the expected return types - IF array_length(_expected_return_types, 1) != 1 THEN - RAISE EXCEPTION 'Procedure "%" expects exactly one return type, but got %', _procedure, array_length(_expected_return_types, 1); - END IF; - - -- since _out_1 is text, we want to ensure that the return type is text - IF _expected_return_types[1] != 'TEXT' THEN - RAISE EXCEPTION 'Procedure "%" expects return type "TEXT", but got "%"', _procedure, _expected_return_types[1]; - END IF; - - -- we now call the procedure. we prefix ds_ to the dbid, as per the Kwil rules. 
- EXECUTE format('SELECT * FROM ds_%I.%I($1, $2)', _dbid, _procedure) INTO _out_1 USING _arg1, _arg2; - - -- or, to return a table, we can do: - -- RETURN QUERY EXECUTE format('SELECT * FROM ds_%I.%I($1, $2)', _dbid, _procedure) USING _arg1, _arg2; - -END; -$$ LANGUAGE plpgsql; -*/ - // This is implicitly // coupled to the schema defined in internal/engine.execution/queries.go, and therefore is implicitly // a circular dependency. I am unsure how to resolve this, but am punting on it for now since the structure @@ -175,6 +91,7 @@ func GenerateForeignProcedure(proc *types.ForeignProcedure, pgSchema string) (st _is_public BOOLEAN; _returns_table BOOLEAN; _expected_input_types TEXT[]; + _expected_return_names TEXT[]; _expected_return_types TEXT[];`) // begin block @@ -182,8 +99,8 @@ func GenerateForeignProcedure(proc *types.ForeignProcedure, pgSchema string) (st // select the procedure info, and perform checks 1-3 str.WriteString(` - SELECT p.param_types, p.return_types, p.is_view, p.owner_only, p.public, s.owner, p.returns_table - INTO _expected_input_types, _expected_return_types, _is_view, _is_owner_only, _is_public, _schema_owner, _returns_table + SELECT p.param_types, p.return_types, p.return_names, p.is_view, p.owner_only, p.public, s.owner, p.returns_table + INTO _expected_input_types, _expected_return_types, _expected_return_names, _is_view, _is_owner_only, _is_public, _schema_owner, _returns_table FROM kwild_internal.procedures as p INNER JOIN kwild_internal.kwil_schemas as s ON p.schema_id = s.id WHERE p.name = _procedure AND s.dbid = _dbid; @@ -192,7 +109,7 @@ func GenerateForeignProcedure(proc *types.ForeignProcedure, pgSchema string) (st RAISE EXCEPTION 'Procedure "%" not found in schema "%"', _procedure, _dbid; END IF; - IF _is_view = FALSE AND current_setting('is_read_only') = 'on' THEN + IF _is_view = FALSE AND current_setting('transaction_read_only')::boolean = true THEN RAISE EXCEPTION 'Non-view procedure "%" called in view-only connection', 
_procedure; END IF; @@ -213,17 +130,17 @@ func GenerateForeignProcedure(proc *types.ForeignProcedure, pgSchema string) (st // first check the length of the array str.WriteString(fmt.Sprintf(` IF array_length(_expected_input_types, 1) IS NOT NULL THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects no inputs, but got procedure "%%" requires %% inputs', _procedure, array_length(_expected_input_types, 1); + RAISE EXCEPTION 'Foreign procedure definition "%s" expects no args, but procedure "%%" located at DBID "%%" requires %% arg(s)', _procedure, _dbid, array_length(_expected_input_types, 1); END IF; `, proc.Name)) } else { str.WriteString(fmt.Sprintf(` IF array_length(_expected_input_types, 1) IS NULL THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d inputs, but procedure "%%" requires no inputs', _procedure; + RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d args, but procedure "%%" located at DBID "%%" requires no args', _procedure, _dbid; END IF; IF array_length(_expected_input_types, 1) != %d THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d inputs, but procedure "%%" requires %% inputs', _procedure, array_length(_expected_input_types, 1); + RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d args, but procedure "%%" located at DBID "%%" requires %% arg(s)', _procedure, _dbid, array_length(_expected_input_types, 1); END IF;`, proc.Name, len(proc.Parameters), len(proc.Parameters), proc.Name, len(proc.Parameters))) } @@ -231,48 +148,58 @@ func GenerateForeignProcedure(proc *types.ForeignProcedure, pgSchema string) (st for i, in := range proc.Parameters { str.WriteString(fmt.Sprintf(` IF _expected_input_types[%d] != '%s' THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects input type "%s", but procedure "%%" requires %%', _procedure, _expected_input_types[%d]; + RAISE EXCEPTION 'Foreign procedure definition "%s" expects arg type "%s", but procedure "%%" located at DBID "%%" 
requires %%', _procedure, _dbid, _expected_input_types[%d]; END IF;`, i+1, in.String(), proc.Name, in.String(), i+1)) } - // if the proc returns a table, ensure that the called procedure returns a table - if proc.Returns != nil && proc.Returns.IsTable { + // if there is an expected return, check that the return fields are the same count and type. + // If it returns a table, also check to make sure that the return names are the same. + if proc.Returns != nil { + // if foreign proc returns a table, check that the called procedure returns a table + // if foreign proc does not return a table, check that the called procedure does not return a table + if proc.Returns.IsTable { + str.WriteString(fmt.Sprintf(` + IF _returns_table = FALSE THEN + RAISE EXCEPTION 'Foreign procedure definition "%s" expects a table return, but procedure "%%" located at DBID "%%" does not return a table', _procedure, _dbid; + END IF;`, proc.Name)) + } else { + str.WriteString(fmt.Sprintf(` + IF _returns_table = TRUE THEN + RAISE EXCEPTION 'Foreign procedure definition "%s" expects a non-table return, but procedure "%%" located at DBID "%%" returns a table', _procedure, _dbid; + END IF;`, proc.Name)) + } + str.WriteString(fmt.Sprintf(` - IF _returns_table = FALSE THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects a table return, but procedure "%%" does not return a table', _procedure; - END IF;`, proc.Name)) + IF array_length(_expected_return_types, 1) IS NULL THEN + RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d returns, but procedure "%%" located at DBID "%%" returns nothing', _procedure, _dbid; + END IF; + + IF array_length(_expected_return_types, 1) != %d THEN + RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d returns, but procedure "%%" located at DBID "%%" returns %% fields', _procedure, _dbid, array_length(_expected_return_types, 1); + END IF;`, proc.Name, len(proc.Returns.Fields), len(proc.Returns.Fields), proc.Name, len(proc.Returns.Fields))) - 
// If a table is returned, we also need to ensure that it returns the exact correct column names and types. + // check that the return types match for i, out := range proc.Returns.Fields { - // check the return type str.WriteString(fmt.Sprintf(` - IF _expected_return_types[%d] != '%s' THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects return type "%s" at column position %d, but procedure "%%" requires %%', _procedure, _expected_return_types[%d]; - END IF;`, i+1, out.Type.String(), proc.Name, out.Type.String(), i+1, i+1)) + IF _expected_return_types[%d] != '%s' THEN + RAISE EXCEPTION 'Foreign procedure definition "%s" expects return type "%s" at return position %d, but procedure "%%" located at DBID "%%" returns %%', _procedure, _dbid, _expected_return_types[%d]; + END IF;`, i+1, out.Type.String(), proc.Name, out.Type.String(), i+1, i+1)) - // check the return name - str.WriteString(fmt.Sprintf(` - IF _expected_return_types[%d] != '%s' THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects return name "%s" at column position %d, but procedure "%%" requires %%', _procedure, _expected_return_types[%d]; - END IF;`, i+1, out.Name, proc.Name, out.Name, i+1, i+1)) + // if it returns a table, check that the return names match + if proc.Returns.IsTable { + str.WriteString(fmt.Sprintf(` + IF _expected_return_names[%d] != '%s' THEN + RAISE EXCEPTION 'Foreign procedure definition "%s" expects return name "%s" at return column position %d, but procedure "%%" located at DBID "%%" returns %%', _procedure, _dbid, _expected_return_names[%d]; + END IF;`, i+1, out.Name, proc.Name, out.Name, i+1, i+1)) + } } + } else { - // else check that the called procedure does not return a table + // if not expecting returns, ensure that the expected return types are nil str.WriteString(fmt.Sprintf(` - IF _returns_table = TRUE THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects a non-table return, but procedure "%%" returns a table', _procedure; + IF 
_expected_return_types IS NOT NULL THEN + RAISE EXCEPTION 'Foreign procedure definition "%s" expects no returns, but procedure "%%" located at DBID "%%" returns non-nil value(s)', _procedure, _dbid; END IF;`, proc.Name)) - - // since we are not returning a table, if we are returning anything, we need to check that the return types match. - // we do not care about the return names, since they are not tables. - if proc.Returns != nil { - for i, out := range proc.Returns.Fields { - str.WriteString(fmt.Sprintf(` - IF _expected_return_types[%d] != '%s' THEN - RAISE EXCEPTION 'Foreign procedure definition "%s" expects return type "%s" at return position %d, but procedure "%%" requires %%', _procedure, _expected_return_types[%d]; - END IF;`, i+1, out.Type.String(), proc.Name, out.Type.String(), i+1, i+1)) - } - } } // now we call the procedure. diff --git a/internal/engine/generate/foreign_procedure_test.go b/internal/engine/generate/foreign_procedure_test.go deleted file mode 100644 index 524b40be6..000000000 --- a/internal/engine/generate/foreign_procedure_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package generate - -import ( - "strings" - "testing" - "unicode" - - "github.com/kwilteam/kwil-db/core/types" - "github.com/stretchr/testify/require" -) - -var schemaName = "schema" - -func Test_ForeignProcedureGen(t *testing.T) { - type testcase struct { - name string - procedure *types.ForeignProcedure - want string - } - - tests := []testcase{ - { - name: "procedure has no inputs and no outputs", - procedure: &types.ForeignProcedure{ - Name: "test", - }, - want: ` -CREATE OR REPLACE FUNCTION schema._fp_test(_dbid TEXT, _procedure TEXT) RETURNS VOID AS $$ -DECLARE - _schema_owner BYTEA; - _is_view BOOLEAN; - _is_owner_only BOOLEAN; - _is_public BOOLEAN; - _returns_table BOOLEAN; - _expected_input_types TEXT[]; - _expected_return_types TEXT[]; -BEGIN - - SELECT p.param_types, p.return_types, p.is_view, p.owner_only, p.public, s.owner, p.returns_table - INTO 
_expected_input_types, _expected_return_types, _is_view, _is_owner_only, _is_public, _schema_owner, _returns_table - FROM kwild_internal.procedures as p INNER JOIN kwild_internal.kwil_schemas as s - ON p.schema_id = s.id - WHERE p.name = _procedure AND s.dbid = _dbid; - - IF _schema_owner IS NULL THEN - RAISE EXCEPTION 'Procedure "%" not found in schema "%"', _procedure, _dbid; - END IF; - - IF _is_view = FALSE AND current_setting('is_read_only') = 'on' THEN - RAISE EXCEPTION 'Non-view procedure "%" called in view-only connection', _procedure; - END IF; - - IF _is_owner_only = TRUE AND _schema_owner != current_setting('ctx.signer')::BYTEA THEN - RAISE EXCEPTION 'Procedure "%" is owner-only and cannot be called by signer "%" in schema "%"', _procedure, current_setting('ctx.signer'), _dbid; - END IF; - - IF _is_public = FALSE THEN - RAISE EXCEPTION 'Procedure "%" is not public and cannot be foreign called', _procedure; - END IF; - - IF array_length(_expected_input_types, 1) IS NOT NULL THEN - RAISE EXCEPTION 'Foreign procedure definition "test" expects no inputs, but got procedure "%" requires % inputs', _procedure, array_length(_expected_input_types, 1); - END IF; - - IF _returns_table = TRUE THEN - RAISE EXCEPTION 'Foreign procedure definition "test" expects a non-table return, but procedure "%" returns a table', _procedure; - END IF; - -EXECUTE format('SELECT * FROM ds_%I.%I()', _dbid, _procedure); - -END; -$$ LANGUAGE plpgsql;`, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := GenerateForeignProcedure(test.procedure, schemaName) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // fmt.Println("GOT:") - // fmt.Println(got) - // fmt.Printf("\n\n\nWANT:") - // fmt.Println(test.want) - // panic("") - - require.Equal(t, removeWhitespace(test.want), removeWhitespace(got)) - }) - } -} - -func removeWhitespace(s string) string { - return strings.Map(func(r rune) rune { - if unicode.IsSpace(r) { - return -1 - 
} - return r - }, s) -} diff --git a/internal/engine/integration/procedure_test.go b/internal/engine/integration/procedure_test.go index e3ec8f096..c31a85aef 100644 --- a/internal/engine/integration/procedure_test.go +++ b/internal/engine/integration/procedure_test.go @@ -11,7 +11,11 @@ import ( "testing" "github.com/kwilteam/kwil-db/common" + "github.com/kwilteam/kwil-db/common/sql" "github.com/kwilteam/kwil-db/core/crypto" + "github.com/kwilteam/kwil-db/core/types" + "github.com/kwilteam/kwil-db/core/utils/order" + "github.com/kwilteam/kwil-db/internal/engine/execution" "github.com/kwilteam/kwil-db/parse" "github.com/stretchr/testify/require" ) @@ -21,67 +25,6 @@ import ( // as mock data. The test is then able to define its own procedure, the inputs, // outputs, and expected error (if any). func Test_Procedures(t *testing.T) { - schema := ` - database ecclesia; - - table users { - id uuid primary key, - name text not null maxlen(100) minlen(4) unique, - wallet_address text not null - } - - table posts { - id uuid primary key, - user_id uuid not null, - content text not null maxlen(300), - foreign key (user_id) references users(id) - } - - procedure create_user($name text) public { - INSERT INTO users (id, name, wallet_address) - VALUES (uuid_generate_v5('985b93a4-2045-44d6-bde4-442a4e498bc6'::uuid, @txid), - $name, - @caller - ); - } - - procedure owns_user($wallet text, $name text) public view returns (owns bool) { - $exists bool := false; - for $row in SELECT * FROM users WHERE wallet_address = $wallet - AND name = $name { - $exists := true; - } - - return $exists; - } - - procedure id_from_name($name text) public view returns (id uuid) { - for $row in SELECT id FROM users WHERE name = $name { - return $row.id; - } - error('user not found'); - } - - procedure create_post($username text, $content text) public { - if owns_user(@caller, $username) == false { - error('caller does not own user'); - } - - INSERT INTO posts (id, user_id, content) - VALUES 
(uuid_generate_v5('985b93a4-2045-44d6-bde4-442a4e498bc6'::uuid, @txid), - id_from_name($username), - $content - ); - } - ` - - // maps usernames to post content. - initialData := map[string][]string{ - "satoshi": {"hello world", "goodbye world", "buy $btc to grow laser eyes"}, - "zeus": {"i am zeus", "i am the god of thunder", "i am the god of lightning"}, - "wendys_drive_through_lady": {"hi how can I help you", "no I don't know what the federal reserve is", "sir this is a wendys"}, - } - type testcase struct { name string procedure string @@ -203,6 +146,20 @@ func Test_Procedures(t *testing.T) { inputs: []any{hex.EncodeToString([]byte("hello"))}, outputs: [][]any{{base64.StdEncoding.EncodeToString([]byte("hello")), []byte("hello"), crypto.Sha256([]byte("hello"))}}, }, + { + name: "join on subquery", + procedure: `procedure join_on_subquery() public view returns table(name text, content text) { + return SELECT u.name, p.content FROM users u + INNER JOIN (select content, user_id from posts) p ON u.id = p.user_id + WHERE u.name = 'satoshi'; + }`, + // should come out LIFO, due to default ordering + outputs: [][]any{ + {"satoshi", "buy $btc to grow laser eyes"}, + {"satoshi", "goodbye world"}, + {"satoshi", "hello world"}, + }, + }, } for _, test := range tests { @@ -220,56 +177,198 @@ func Test_Procedures(t *testing.T) { defer tx.Rollback(ctx) // deploy schema - parsed, err := parse.ParseSchema([]byte(schema + test.procedure)) - require.NoError(t, err) - require.NoError(t, parsed.Err()) + dbid := deployAndSeed(t, global, tx, test.procedure) - err = global.CreateDataset(ctx, tx, parsed.Schema, &common.TransactionData{ - Signer: []byte("deployer"), - Caller: "deployer", - TxID: "deploydb", + // parse out procedure name + procedureName := parseProcedureName(test.procedure) + + // execute test procedure + res, err := global.Procedure(ctx, tx, &common.ExecutionData{ + TransactionData: common.TransactionData{ + Signer: []byte("test_signer"), + Caller: "test_caller", + 
TxID: "test", + }, + Dataset: dbid, + Procedure: procedureName, + Args: test.inputs, }) + if test.err != nil { + require.Error(t, err) + require.ErrorIs(t, err, test.err) + return + } require.NoError(t, err) - // get dbid - dbs, err := global.ListDatasets([]byte("deployer")) - require.NoError(t, err) - require.Len(t, dbs, 1) - dbid := dbs[0].DBID - - // create initial data - for username, posts := range initialData { - _, err = global.Procedure(ctx, tx, &common.ExecutionData{ - TransactionData: common.TransactionData{ - Signer: []byte("username_signer"), - Caller: "username_caller", - TxID: "create_user_" + username, - }, - Dataset: dbid, - Procedure: "create_user", - Args: []any{username}, - }) - require.NoError(t, err) - - for i, post := range posts { - _, err = global.Procedure(ctx, tx, &common.ExecutionData{ - TransactionData: common.TransactionData{ - Signer: []byte("username_signer"), - Caller: "username_caller", - TxID: "create_post_" + username + "_" + fmt.Sprint(i), - }, - Dataset: dbid, - Procedure: "create_post", - Args: []any{username, post}, - }) - require.NoError(t, err) + require.Len(t, res.Rows, len(test.outputs)) + + for i, output := range test.outputs { + require.Len(t, res.Rows[i], len(output)) + for j, val := range output { + require.Equal(t, val, res.Rows[i][j]) } } + }) + } +} - // parse out procedure name - procs := strings.Split(test.procedure, " ") - procedureName := strings.Split(procs[1], "(")[0] - procedureName = strings.TrimSpace(procedureName) +func Test_ForeignProcedures(t *testing.T) { + type testcase struct { + name string + // foreign is the foreign procedure definition. + // It will be deployed in a separate schema. + foreign string + // otherProc is the procedure that calls the foreign procedure. + // It will be included with the foreign procedure. + // It should be formattable to allow the caller to format with + // the target dbid, and the target procedure should be hardcoded. 
+ otherProc string + // inputs are the inputs to the test procedure. + inputs []any + // outputs are the expected outputs from the test procedure. + outputs [][]any + // if wantErr is not empty, the test will expect an error containing this string. + // We use a string, instead of Go's error type, because we are reading errors raised + // from Postgres, which are strings. + wantErr string + } + + tests := []testcase{ + { + name: "foreign procedure takes nothing, returns nothing", + foreign: `foreign procedure do_something()`, + otherProc: `procedure call_foreign() public { + do_something['%s', 'delete_users'](); + }`, + }, + { + name: "foreign procedure takes nothing, returns table", + foreign: `foreign procedure get_users() returns table(id uuid, name text, wallet_address text)`, + otherProc: `procedure call_foreign() public returns table(username text) { + return select name as username from get_users['%s', 'get_users'](); + }`, + outputs: [][]any{ + {"satoshi"}, + {"wendys_drive_through_lady"}, + {"zeus"}, + }, + }, + { + name: "foreign procedure takes values, returns values", + foreign: `foreign procedure id_from_name($name text) returns (id uuid)`, + otherProc: `procedure call_foreign($name text) public returns (id uuid) { + return id_from_name['%s', 'id_from_name']($name); + }`, + inputs: []any{"satoshi"}, + outputs: [][]any{{satoshisUUID}}, + }, + { + name: "foreign procedure expects no args, implementation expects some", + foreign: `foreign procedure id_from_name() returns (id uuid)`, + otherProc: `procedure call_foreign() public returns (id uuid) { + return id_from_name['%s', 'id_from_name'](); + }`, + wantErr: `requires 1 arg(s)`, + }, + { + name: "foreign procedure expects args, implementation expects none", + foreign: `foreign procedure get_users($name text) returns table(id uuid, name text, wallet_address text)`, + otherProc: `procedure call_foreign() public returns table(username text) { + return select name as username from get_users['%s', 
'get_users']('satoshi'); + }`, + wantErr: "requires no args", + }, + { + name: "foreign procedure expects 2 args, implementation expects 2", + foreign: `foreign procedure id_from_name($name text, $name2 text) returns (id uuid)`, + otherProc: `procedure call_foreign() public returns (id uuid) { + return id_from_name['%s', 'id_from_name']('satoshi', 'zeus'); + }`, + wantErr: "requires 1 arg(s)", + }, + { + name: "foreign procedure returns 1 arg, implementation returns none", + foreign: `foreign procedure delete_users() returns (text)`, + otherProc: `procedure call_foreign() public returns (text) { + return delete_users['%s', 'delete_users'](); + }`, + wantErr: "returns nothing", + }, + { + name: "foreign procedure returns 0 args, implementation returns 1", + foreign: `foreign procedure id_from_name($name text)`, + otherProc: `procedure call_foreign() public { + id_from_name['%s', 'id_from_name']('satoshi'); + }`, + wantErr: "returns non-nil value(s)", + }, + { + name: "foreign procedure returns table, implementation returns non-table", + foreign: `foreign procedure id_from_name($name text) returns table(id uuid)`, + otherProc: `procedure call_foreign() public { + select id from id_from_name['%s', 'id_from_name']('satoshi'); + }`, + wantErr: "does not return a table", + }, + { + name: "foreign procedure does not return table, implementation returns table", + foreign: `foreign procedure get_users() returns (id uuid, name text, wallet_address text)`, + otherProc: `procedure call_foreign() public returns table(username text) { + $id, $name, $wallet := get_users['%s', 'get_users'](); + }`, + wantErr: "returns a table", + }, + { + name: "foreign procedure returns table, implementation returns nothing", + foreign: `foreign procedure create_user($name text) returns table(id uuid)`, + otherProc: `procedure call_foreign() public { + create_user['%s', 'create_user']('satoshi'); + }`, + wantErr: "does not return a table", + }, + { + name: "procedures returning scalar return 
different named values (ok)", + // returns value "uid" instead of impl's "id" + foreign: `foreign procedure id_from_name($name text) returns (uid uuid)`, + otherProc: `procedure call_foreign() public returns (id uuid) { + return id_from_name['%s', 'id_from_name']('satoshi'); + }`, + outputs: [][]any{{satoshisUUID}}, + }, + { + name: "procedure returning table return different column names (failure)", + foreign: `foreign procedure get_users() returns table(uid uuid, name text, wallet_address text)`, + otherProc: `procedure call_foreign() public returns table(name text) { + return select name from get_users['%s', 'get_users'](); + }`, + wantErr: "returns id", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + global, db, err := setup(t) + if err != nil { + t.Fatal(err) + } + defer cleanup(t, db) + + ctx := context.Background() + + tx, err := db.BeginOuterTx(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + // deploy the main test schema + foreignDBID := deployAndSeed(t, global, tx) + + // deploy the new schema that will call the main one + // first, format the procedure with the foreign DBID + otherProc := fmt.Sprintf(test.otherProc, foreignDBID) + // deploy the new schema + mainDBID := deploy(t, global, tx, fmt.Sprintf("database db2;\n%s\n%s", test.foreign, otherProc)) + + procedureName := parseProcedureName(otherProc) // execute test procedure res, err := global.Procedure(ctx, tx, &common.ExecutionData{ @@ -278,19 +377,18 @@ func Test_Procedures(t *testing.T) { Caller: "test_caller", TxID: "test", }, - Dataset: dbid, + Dataset: mainDBID, Procedure: procedureName, Args: test.inputs, }) - if test.err != nil { + if test.wantErr != "" { require.Error(t, err) - require.ErrorIs(t, err, test.err) + require.Contains(t, err.Error(), test.wantErr) return } require.NoError(t, err) require.Len(t, res.Rows, len(test.outputs)) - for i, output := range test.outputs { require.Len(t, res.Rows[i], len(output)) for j, val := range output { 
@@ -300,3 +398,147 @@ func Test_Procedures(t *testing.T) { }) } } + +// testSchema is a schema that can be deployed with deployAndSeed +var testSchema = ` +database ecclesia; + +table users { + id uuid primary key, + name text not null maxlen(100) minlen(4) unique, + wallet_address text not null +} + +table posts { + id uuid primary key, + user_id uuid not null, + content text not null maxlen(300), + foreign key (user_id) references users(id) on delete cascade +} + +procedure create_user($name text) public { + INSERT INTO users (id, name, wallet_address) + VALUES (uuid_generate_v5('985b93a4-2045-44d6-bde4-442a4e498bc6'::uuid, @txid), + $name, + @caller + ); +} + +procedure owns_user($wallet text, $name text) public view returns (owns bool) { + $exists bool := false; + for $row in SELECT * FROM users WHERE wallet_address = $wallet + AND name = $name { + $exists := true; + } + + return $exists; +} + +procedure id_from_name($name text) public view returns (id uuid) { + for $row in SELECT id FROM users WHERE name = $name { + return $row.id; + } + error('user not found'); +} + +procedure create_post($username text, $content text) public { + if owns_user(@caller, $username) == false { + error('caller does not own user'); + } + + INSERT INTO posts (id, user_id, content) + VALUES (uuid_generate_v5('985b93a4-2045-44d6-bde4-442a4e498bc6'::uuid, @txid), + id_from_name($username), + $content + ); +} + +// the following procedures serve no utility, and are made only to test foreign calls +// to different signatures. +procedure delete_users() public { + DELETE FROM users; +} + +procedure get_users() public returns table(id uuid, name text, wallet_address text) { + return SELECT * FROM users; +} +` + +// maps usernames to post content. 
+var initialData = map[string][]string{ + "satoshi": {"hello world", "goodbye world", "buy $btc to grow laser eyes"}, + "zeus": {"i am zeus", "i am the god of thunder", "i am the god of lightning"}, + "wendys_drive_through_lady": {"hi how can I help you", "no I don't know what the federal reserve is", "sir this is a wendys"}, +} + +var satoshisUUID = &types.UUID{0x38, 0xeb, 0x77, 0xcb, 0x1e, 0x5a, 0x56, 0xc0, 0x85, 0x63, 0x2e, 0x25, 0x34, 0xd6, 0x7b, 0x96} + +// deploy deploys a schema +func deploy(t *testing.T, global *execution.GlobalContext, db sql.DB, schema string) (dbid string) { + ctx := context.Background() + + parsed, err := parse.ParseAndValidate([]byte(schema)) + require.NoError(t, err) + require.NoError(t, parsed.Err()) + + d := txData() + err = global.CreateDataset(ctx, db, parsed.Schema, &d) + require.NoError(t, err) + + // get dbid + dbs, err := global.ListDatasets(owner) + require.NoError(t, err) + + for _, db := range dbs { + if db.Name == parsed.Schema.Name { + dbid = db.DBID + break + } + } + + return dbid +} + +// deployAndSeed deploys the test schema and seeds it with data +func deployAndSeed(t *testing.T, global *execution.GlobalContext, db sql.DB, extraProcedures ...string) (dbid string) { + ctx := context.Background() + + schema := testSchema + for _, proc := range extraProcedures { + schema += proc + "\n" + } + + // deploy schema + dbid = deploy(t, global, db, schema) + + // create initial data + for _, kv := range order.OrderMap(initialData) { + _, err := global.Procedure(ctx, db, &common.ExecutionData{ + TransactionData: txData(), + Dataset: dbid, + Procedure: "create_user", + Args: []any{kv.Key}, + }) + require.NoError(t, err) + + for _, post := range kv.Value { + _, err = global.Procedure(ctx, db, &common.ExecutionData{ + TransactionData: txData(), + Dataset: dbid, + Procedure: "create_post", + Args: []any{kv.Key, post}, + }) + require.NoError(t, err) + } + } + + return dbid +} + +// parseProcedureName parses the procedure name from a 
procedure definition +func parseProcedureName(proc string) string { + procs := strings.Split(proc, " ") + procedureName := strings.Split(procs[1], "(")[0] + procedureName = strings.TrimSpace(procedureName) + return procedureName +} diff --git a/internal/engine/integration/schema_test.go b/internal/engine/integration/schema_test.go index a62438c4e..88597d651 100644 --- a/internal/engine/integration/schema_test.go +++ b/internal/engine/integration/schema_test.go @@ -167,6 +167,40 @@ func Test_Schemas(t *testing.T) { require.Equal(t, int64(100), res.Rows[0][0]) }, }, + { + name: "write data to foreign procedure", + fn: func(t *testing.T, global *execution.GlobalContext, db sql.DB) { + usersDBID, social_media, _ := deployAllSchemas(t, global, db) + + ctx := context.Background() + + // create user. we do this in the social_media db to ensure the + // procedure can write to a foreign dataset + _, err := global.Procedure(ctx, db, &common.ExecutionData{ + Dataset: social_media, + Procedure: "create_user", + Args: []any{"satoshi"}, + TransactionData: txData(), + }) + require.NoError(t, err) + + // get the user by name + res, err := global.Procedure(ctx, db, &common.ExecutionData{ + Dataset: usersDBID, + Procedure: "get_user_by_name", + Args: []any{"satoshi"}, + TransactionData: txData(), + }) + require.NoError(t, err) + + // check the columns. 
should be owner, name + require.Len(t, res.Columns, 2) + + // check the values + require.Len(t, res.Rows, 1) + require.Equal(t, "test_owner", res.Rows[0][1]) + }, + }, } for _, tc := range testCases { diff --git a/internal/engine/integration/schemas/social_media.kf b/internal/engine/integration/schemas/social_media.kf index ca05e440a..447caacde 100644 --- a/internal/engine/integration/schemas/social_media.kf +++ b/internal/engine/integration/schemas/social_media.kf @@ -82,11 +82,11 @@ procedure get_recent_posts_by_size($username text, $size int, $limit int) public } } - // the following foreign procedures define procedures // that the users db has. It returns redundant data since it needs // to match the procedure signature defines in users. foreign procedure get_user($address text) returns (uuid, text) +foreign procedure foreign_create_user(text) // table keyvalue is a kv table to track metadata // for foreign calls. @@ -119,4 +119,12 @@ procedure get_user_id($address text) public view returns (id uuid) { $user_id, _ := get_user[$dbid, $procedure](@caller); return $user_id; -} \ No newline at end of file +} + +// this simply tests that we can write data to foreign procedures. 
+procedure create_user($name text) public { + $dbid text := admin_get('dbid'); + $procedure text := admin_get('userbyowner'); + + foreign_create_user[$dbid, 'create_user']($name); +} diff --git a/internal/engine/integration/setup_test.go b/internal/engine/integration/setup_test.go index 761500052..9c7627ffc 100644 --- a/internal/engine/integration/setup_test.go +++ b/internal/engine/integration/setup_test.go @@ -25,6 +25,9 @@ func TestMain(m *testing.M) { // cleanup deletes all schemas and closes the database func cleanup(t *testing.T, db *pg.DB) { + txCounter = 0 // reset the global tx counter, which is necessary to properly + // encapsulate each test and make their results independent of each other + db.AutoCommit(true) defer db.AutoCommit(false) defer db.Close() diff --git a/parse/analyze.go b/parse/analyze.go index eee4460fb..796747650 100644 --- a/parse/analyze.go +++ b/parse/analyze.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/kwilteam/kwil-db/core/types" - "github.com/kwilteam/kwil-db/core/utils/order" ) /* @@ -420,29 +419,30 @@ func (s *sqlAnalyzer) expectedNumeric(node Node, t *types.DataType) { // but it returns something else. It will attempt to read the actual type and create an error // message that is helpful for the end user. func (s *sqlAnalyzer) expressionTypeErr(e Expression) *types.DataType { - // if expression is a receiver from a loop, it will be a map - _, ok := e.Accept(s).(map[string]*types.DataType) - if ok { + switch v := e.Accept(s).(type) { + case *types.DataType: + // if it is a basic expression returning a scalar (e.g. "'hello'" or "abs(-1)"), + // or a procedure that returns exactly one scalar value. + // This should never happen, since expressionTypeErr is called when the expression + // does not return a *types.DataType. + panic("api misuse: expressionTypeErr should only be called when the expression does not return a *types.DataType") + case map[string]*types.DataType: + // if it is a loop receiver on a select statement (e.g. 
"for $row in select * from table") s.errs.AddErr(e, ErrType, "invalid usage of compound type. you must reference a field using $compound.field notation") - return cast(e, types.UnknownType) - } - - // if expression is a procedure call that returns a table, it will be a slice of attributes - _, ok = e.Accept(s).([]*Attribute) - if ok { - s.errs.AddErr(e, ErrType, "procedure returns table, not a scalar value") - return cast(e, types.UnknownType) - } - - // if it is a procedure call that returns many values, it will be a slice of data types - vals, ok := e.Accept(s).([]*types.DataType) - if ok { - s.errs.AddErr(e, ErrType, "expected procedure to return a single value, returns %d", len(vals)) - return cast(e, types.UnknownType) - + case []*types.DataType: + // if it is a procedure than returns several scalar values + s.errs.AddErr(e, ErrType, "expected procedure to return a single value, returns %d values", len(v)) + case *returnsTable: + // if it is a procedure that returns a table + s.errs.AddErr(e, ErrType, "procedure returns table, not scalar values") + case nil: + // if it is a procedure that returns nothing + s.errs.AddErr(e, ErrType, "procedure does not return any value") + default: + // unknown + s.errs.AddErr(e, ErrType, "internal bug: could not infer expected type") } - s.errs.AddErr(e, ErrType, "could not infer expected type") return cast(e, types.UnknownType) } @@ -567,6 +567,12 @@ func (s *sqlAnalyzer) VisitExpressionFunctionCall(p0 *ExpressionFunctionCall) an } } + // callers of this visitor know that a nil return means a function does not + // return anything. 
We explicitly return nil instead of a nil *types.DataType + if returnType == nil { + return nil + } + return cast(p0, returnType) } @@ -625,7 +631,7 @@ func (s *sqlAnalyzer) returnProcedureReturnExpr(p0 ExpressionCall, procedureName if p0.GetTypeCast() != nil { s.errs.AddErr(p0, ErrType, "cannot typecast procedure %s because does not return a value", procedureName) } - return types.NullType + return nil } // if it returns a table, we need to return it as a set of attributes. @@ -638,7 +644,9 @@ func (s *sqlAnalyzer) returnProcedureReturnExpr(p0 ExpressionCall, procedureName } } - return attrs + return &returnsTable{ + attrs: attrs, + } } switch len(ret.Fields) { @@ -661,6 +669,13 @@ func (s *sqlAnalyzer) returnProcedureReturnExpr(p0 ExpressionCall, procedureName } } +// returnsTable is a special struct returned by returnProcedureReturnExpr when a procedure returns a table. +// It is used internally to detect when a procedure returns a table, so that we can properly throw type errors +// with helpful messages when a procedure returning a table is used in a position where a scalar value is expected. +type returnsTable struct { + attrs []*Attribute +} + func (s *sqlAnalyzer) VisitExpressionVariable(p0 *ExpressionVariable) any { dt, ok := s.blockContext.variables[p0.String()] if !ok { @@ -1316,7 +1331,13 @@ func (s *sqlAnalyzer) VisitSelectStatement(p0 *SelectStatement) any { // if it is not a compound, then we apply the following default ordering rules (after the user defined): // 1. Each primary key for each schema table joined is ordered in ascending order. // The tables and columns for all joined tables will be sorted alphabetically. - // If table aliases are used, they will be used instead of the name. + // If table aliases are used, they will be used instead of the name. This must include + // subqueries and function joins; even though those are ordered, they still need to + // be ordered in the outermost select. 
+ // see: https://www.reddit.com/r/PostgreSQL/comments/u6icv9/is_original_sort_order_preserve_after_joining/ + // TODO: we can likely make some significant optimizations here by only applying ordering + // on the outermost query UNLESS aggregates are used in the subquery, but that is a future + // optimization. // 2. If the select core contains DISTINCT, then the above does not apply, and // we order by all columns returned, in the order they are returned. // 3. If there is a group by clause, none of the above apply, and instead we order by @@ -1362,17 +1383,34 @@ func (s *sqlAnalyzer) VisitSelectStatement(p0 *SelectStatement) any { } } else { // if not distinct, order by primary keys in all joined tables - for _, tbl := range order.OrderMap(rel1Scope.joinedTables) { - pks, err := tbl.Value.GetPrimaryKey() - if err != nil { - s.errs.AddErr(p0, err, "could not get primary key for table %s", tbl.Key) + for _, rel := range rel1Scope.joinedRelations { + // if it is a table, we only order by primary key. + // otherwise, order by all columns. + tbl, ok := rel1Scope.joinedTables[rel.Name] + if ok { + pks, err := tbl.GetPrimaryKey() + if err != nil { + s.errs.AddErr(p0, err, "could not get primary key for table %s", rel.Name) + } + + for _, pk := range pks { + p0.Ordering = append(p0.Ordering, &OrderingTerm{ + Expression: &ExpressionColumn{ + Table: rel.Name, + Column: pk, + }, + }) + } + + continue } - for _, pk := range pks { + // if not a table, order by all columns + for _, attr := range rel.Attributes { p0.Ordering = append(p0.Ordering, &OrderingTerm{ Expression: &ExpressionColumn{ - Table: tbl.Key, - Column: pk, + Table: rel.Name, + Column: attr.Name, }, }) } @@ -1666,7 +1704,7 @@ func (s *sqlAnalyzer) VisitRelationFunctionCall(p0 *RelationFunctionCall) any { // the function call here must return []*Attribute // this logic is handled in returnProcedureReturnExpr. 
- ret, ok := p0.FunctionCall.Accept(s).([]*Attribute) + ret, ok := p0.FunctionCall.Accept(s).(*returnsTable) if !ok { s.errs.AddErr(p0, ErrType, "cannot join procedure that does not return type table") } @@ -1690,7 +1728,7 @@ func (s *sqlAnalyzer) VisitRelationFunctionCall(p0 *RelationFunctionCall) any { err := s.sqlCtx.joinRelation(&Relation{ Name: p0.Alias, - Attributes: ret, + Attributes: ret.attrs, }) if err != nil { s.errs.AddErr(p0, err, p0.Alias) @@ -1754,7 +1792,6 @@ func (s *sqlAnalyzer) VisitUpdateStatement(p0 *UpdateStatement) any { if !ok { s.expressionTypeErr(p0.Where) return []*Attribute{} - } s.expect(p0.Where, whereType, types.BoolType) @@ -2185,23 +2222,30 @@ func (p *procedureAnalyzer) VisitProcedureStmtCall(p0 *ProcedureStmtCall) any { alreadyMutative := p.sqlResult.Mutative var callReturns []*types.DataType - // it might return a single value - returns1, ok := p0.Call.Accept(p).(*types.DataType) - if ok { - // if it returns null, then we do not need to assign it to a variable. - if !returns1.EqualsStrict(types.NullType) { - callReturns = append(callReturns, returns1) - } - } else { - // or it might return multiple values - returns2, ok := p0.Call.Accept(p).([]*types.DataType) - if !ok { - p.errs.AddErr(p0.Call, ErrType, "expected function/procedure to return one or more variables") + + // procedure calls can return many different types of values. + switch v := p0.Call.Accept(p).(type) { + case *types.DataType: + callReturns = []*types.DataType{v} + case []*types.DataType: + callReturns = v + case *returnsTable: + // if a procedure that returns a table is being called in a + // procedure, we need to ensure there are no receivers, since + // it is impossible to assign a table to a variable. 
+ // we will also not add these to the callReturns, since they are + // table columns, and not assignable variables + if len(p0.Receivers) != 0 { + p.errs.AddErr(p0, ErrResultShape, "procedure returns table, cannot assign to variable(s)") return zeroProcedureReturn() } - - callReturns = returns2 + case nil: + // do nothing + default: + p.expressionTypeErr(p0.Call) + return zeroProcedureReturn() } + // if calling the `error` function, then this branch will return exits := false if p0.Call.FunctionName() == "error" { @@ -2214,6 +2258,14 @@ func (p *procedureAnalyzer) VisitProcedureStmtCall(p0 *ProcedureStmtCall) any { p.errs.AddErr(p0, ErrViewMutatesState, `view procedure calls non-view procedure "%s"`, p0.Call.FunctionName()) } + // users can discard returns by simply not having receivers. + // if there are no receivers, we can return early. + if len(p0.Receivers) == 0 { + return &procedureStmtResult{ + willReturn: exits, + } + } + // we do not have to capture all return values, but we need to ensure // we do not have more receivers than return values. if len(p0.Receivers) != len(callReturns) { @@ -2280,16 +2332,18 @@ func (p *procedureAnalyzer) VisitProcedureStmtForLoop(p0 *ProcedureStmtForLoop) // we do not mark declared here since these are loop receivers, // and they get tracked in a separate slice than other variables. if ok { - // if here, we are likely looping over an array. + // if here, we are looping over an array or range. // we need to use the returned type, but remove the IsArray rec := scalarVal.Copy() rec.IsArray = false p.variables[p0.Receiver.String()] = rec tracker.dataType = rec } else { + // if we are here, we are looping over a select. 
compound, ok := res.(map[string]*types.DataType) if !ok { - panic("expected loop term to return scalar or compound type") + p.expressionTypeErr(p0.LoopTerm) + return zeroProcedureReturn() } p.anonymousVariables[p0.Receiver.String()] = compound // we do not set the tracker type here, since it is an anonymous variable. diff --git a/parse/functions.go b/parse/functions.go index abdb2222b..598fb3e63 100644 --- a/parse/functions.go +++ b/parse/functions.go @@ -42,7 +42,7 @@ var ( return nil, wrapErrArgumentType(types.TextType, args[0]) } - return types.NullType, nil + return nil, nil }, PGFormat: func(inputs []string, distinct bool, star bool) (string, error) { if star { diff --git a/parse/parse_test.go b/parse/parse_test.go index c9d485655..605a5bf2d 100644 --- a/parse/parse_test.go +++ b/parse/parse_test.go @@ -892,6 +892,29 @@ var ( }, Body: `return select id from users;`, } + + foreignProcGetUser = &types.ForeignProcedure{ + Name: "get_user_id", + Parameters: []*types.DataType{ + types.TextType, + }, + Returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{ + { + Name: "id", + Type: types.IntType, + }, + }, + }, + } + + foreignProcCreateUser = &types.ForeignProcedure{ + Name: "foreign_create_user", + Parameters: []*types.DataType{ + types.IntType, + types.TextType, + }, + } ) func Test_Procedure(t *testing.T) { @@ -929,6 +952,37 @@ func Test_Procedure(t *testing.T) { }, }, }, + { + name: "procedure applies default ordering to selects", + proc: ` + select * from users; + `, + want: &parse.ProcedureParseResult{ + AST: []parse.ProcedureStmt{ + &parse.ProcedureStmtSQL{ + SQL: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnWildcard{}, + }, + From: &parse.RelationTable{ + Table: "users", + }, + }, + }, + Ordering: []*parse.OrderingTerm{ + { + Expression: exprColumn("users", "id"), + }, + }, + }, + }, + }, + }, + }, + }, { name: "for loop", proc: ` @@ -1406,6 +1460,11 
@@ func Test_Procedure(t *testing.T) { }, }, }, + Ordering: []*parse.OrderingTerm{ + { + Expression: exprColumn("", "id"), + }, + }, }, }, }, @@ -1445,6 +1504,68 @@ func Test_Procedure(t *testing.T) { }, }, }, + { + // this tests for regression on a previously known bug + name: "foreign procedure returning nothing to a variable", + returns: &types.ProcedureReturn{ + Fields: []*types.NamedType{ + { + Name: "id", + Type: types.IntType, + }, + }, + }, + proc: ` + return foreign_create_user['xbd', 'create_user'](1, 'user1'); + `, + err: parse.ErrType, + }, + { + // regression test for a previously known bug + name: "calling a procedure that returns nothing works fine", + proc: ` + foreign_create_user['xbd', 'create_user'](1, 'user1'); + `, + want: &parse.ProcedureParseResult{ + AST: []parse.ProcedureStmt{ + &parse.ProcedureStmtCall{ + Call: &parse.ExpressionForeignCall{ + Name: "foreign_create_user", + ContextualArgs: []parse.Expression{ + exprLit("xbd"), + exprLit("create_user"), + }, + Args: []parse.Expression{ + exprLit(1), + exprLit("user1"), + }, + }, + }, + }, + }, + }, + { + name: "assigning a variable with error is invalid", + proc: `$a := error('error message');`, + err: parse.ErrResultShape, + }, + { + // this is a regression test for a previous bug + name: "discarding return values of a function is ok", + proc: `abs(-1);`, + want: &parse.ProcedureParseResult{ + AST: []parse.ProcedureStmt{ + &parse.ProcedureStmtCall{ + Call: &parse.ExpressionFunctionCall{ + Name: "abs", + Args: []parse.Expression{ + exprLit(-1), + }, + }, + }, + }, + }, + }, } for _, tt := range tests { @@ -1475,6 +1596,10 @@ func Test_Procedure(t *testing.T) { proc, procGetAllUserIds, }, + ForeignProcedures: []*types.ForeignProcedure{ + foreignProcGetUser, + foreignProcCreateUser, + }, }) require.NoError(t, err) @@ -1801,7 +1926,12 @@ func Test_SQL(t *testing.T) { }, }, }, - // no ordering since the procedure implementation is ordered + // apply default ordering + Ordering: 
[]*parse.OrderingTerm{ + { + Expression: exprColumn("", "id"), + }, + }, }, }, }, @@ -1985,6 +2115,125 @@ func Test_SQL(t *testing.T) { sql: `SELECT count(*) FROM users order by count(*) DESC;`, err: parse.ErrAggregate, }, + { + name: "ordering for subqueries", + sql: `SELECT u.username, p.id FROM (SELECT * FROM users) as u inner join (SELECT * FROM posts) as p on u.id = p.author_id;`, + want: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnExpression{ + Expression: exprColumn("u", "username"), + }, + &parse.ResultColumnExpression{ + Expression: exprColumn("p", "id"), + }, + }, + From: &parse.RelationSubquery{ + Subquery: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnWildcard{}, + }, + From: &parse.RelationTable{ + Table: "users", + }, + }, + }, + Ordering: []*parse.OrderingTerm{ + { + Expression: exprColumn("users", "id"), + }, + }, + }, + Alias: "u", + }, + Joins: []*parse.Join{ + { + Type: parse.JoinTypeInner, + Relation: &parse.RelationSubquery{ + Subquery: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnWildcard{}, + }, + From: &parse.RelationTable{ + Table: "posts", + }, + }, + }, + Ordering: []*parse.OrderingTerm{ + { + Expression: exprColumn("posts", "id"), + }, + }, + }, + Alias: "p", + }, + On: &parse.ExpressionComparison{ + Left: exprColumn("u", "id"), + Operator: parse.ComparisonOperatorEqual, + Right: exprColumn("p", "author_id"), + }, + }, + }, + }, + }, + Ordering: []*parse.OrderingTerm{ + { + Expression: exprColumn("u", "id"), + }, + { + Expression: exprColumn("u", "username"), + }, + { + Expression: exprColumn("p", "id"), + }, + { + Expression: exprColumn("p", "author_id"), + }, + }, + }, + }, + }, + { + name: "select against subquery with table join", + sql: `SELECT u.username, p.id FROM (SELECT * FROM users) 
inner join posts as p on users.id = p.author_id;`, + err: parse.ErrUnnamedJoin, + }, + { + name: "default ordering on procedure call", + sql: `SELECT * FROM get_all_user_ids();`, + want: &parse.SQLStatement{ + SQL: &parse.SelectStatement{ + SelectCores: []*parse.SelectCore{ + { + Columns: []parse.ResultColumn{ + &parse.ResultColumnWildcard{}, + }, + From: &parse.RelationFunctionCall{ + FunctionCall: &parse.ExpressionFunctionCall{ + Name: "get_all_user_ids", + }, + }, + }, + }, + Ordering: []*parse.OrderingTerm{ + { + Expression: exprColumn("", "id"), + }, + }, + }, + }, + }, + { + name: "join against unnamed function call fails", + sql: `SELECT * FROM users inner join get_all_user_ids() on users.id = u.id;`, + err: parse.ErrUnnamedJoin, + }, } for _, tt := range tests { @@ -2011,8 +2260,8 @@ func Test_SQL(t *testing.T) { return } - if !deepCompare(res.AST, tt.want) { - t.Errorf("unexpected AST:\n%s", diff(res.AST, tt.want)) + if !deepCompare(tt.want, res.AST) { + t.Errorf("unexpected AST:\n%s", diff(tt.want, res.AST)) } }) } diff --git a/test/acceptance/test-data/users.kf b/test/acceptance/test-data/users.kf index 30d9733f3..37824839d 100644 --- a/test/acceptance/test-data/users.kf +++ b/test/acceptance/test-data/users.kf @@ -92,8 +92,8 @@ procedure create_post($content text) public { ); } -procedure get_recent_posts($username text) public view returns table(id uuid, content text) { -return SELECT p.id, p.content from posts as p +procedure get_recent_posts($username text) public view returns table(id uuid, content text, post_num int) { +return SELECT p.id, p.content, p.post_num from posts as p inner join users as u on p.author_id = u.id WHERE u.name = $username ORDER BY p.post_num DESC; @@ -109,7 +109,7 @@ procedure get_recent_posts_by_size($username text, $size int, $limit int) public } $count int := 0; - for $row in select * from get_recent_posts($username) { + for $row in select * from get_recent_posts($username) order by post_num DESC { if $count == $limit { 
break; } @@ -126,7 +126,9 @@ procedure get_recent_posts_by_size($username text, $size int, $limit int) public procedure reverse_latest_posts($username text, $limit int) public view returns (content text[]) { $content text[]; - for $post in select * from get_recent_posts($username) { + // we need to re-apply ordering here since Postgres doesn't guarantee ordering + // propagates from subqueries and procedures. + for $post in select * from get_recent_posts($username) order by post_num DESC { $content := array_append($content, $post.content); }