Skip to content

Commit

Permalink
fix bug when writing to foreign proc (#771)
Browse files Browse the repository at this point in the history
updated comment

fix bugs on foreign procedures

fix build tag

fixed act test

addressed gavins feedback

fix failing unit test
  • Loading branch information
brennanjl authored May 28, 2024
1 parent ae11433 commit 933dca6
Show file tree
Hide file tree
Showing 10 changed files with 807 additions and 389 deletions.
167 changes: 47 additions & 120 deletions internal/engine/generate/foreign_procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,90 +7,6 @@ import (
"github.com/kwilteam/kwil-db/core/types"
)

/*
Generated foreign procedure calls are responsible for making calls to procedures in other schemas.
The dbid and procedure can be passed in dynamically. It should follow the following general structure, with variations
for inputs/outputs:
create or replace function _fp_template(_dbid TEXT, _procedure TEXT, _arg1 TEXT, _arg2 INT8, OUT _out_1 TEXT) as $$
DECLARE
_schema_owner BYTEA;
_is_view BOOLEAN;
_is_owner_only BOOLEAN;
_is_public BOOLEAN;
_returns_table BOOLEAN;
_expected_input_types TEXT[];
_expected_return_types TEXT[];
BEGIN
SELECT p.param_types, p.return_types, p.is_view, p.owner_only, p.public, s.owner, p.returns_table
INTO _expected_input_types, _expected_return_types, _is_view, _is_owner_only, _is_public, _schema_owner, _returns_table
FROM kwild_internal.procedures as p INNER JOIN kwild_internal.kwil_schemas as s
ON s.schema_id = p.id
WHERE p.name = _procedure AND s.dbid = _dbid;
-- Schema owner cannot be nil, and will only be nil
-- if the procedure is not found
IF _schema_owner IS NULL THEN
RAISE EXCEPTION 'Procedure "%" not found in schema "%"', _procedure, _dbid;
END IF;
-- we now ensure that:
-- 1. if we are in a read-only connection, that the procedure is view
-- 2. if it is owner only, the signer is the owner
-- 3. it is public
-- 4. our input types match the expected input types
-- 5. our return types match the expected return types
-- 1. if we are in a read-only connection, that the procedure is view
IF _is_view = FALSE AND current_setting('is_read_only') = 'on' THEN
RAISE EXCEPTION 'Non-view procedure "%" called in view-only connection', _procedure;
END IF;
-- 2. if it is owner only, the signer is the owner
IF _is_owner_only = TRUE AND _schema_owner != current_setting('ctx.signer')::BYTEA THEN
RAISE EXCEPTION 'Procedure "%" is owner-only and cannot be called by user "%" in schema "%"', _procedure, current_setting('ctx.signer'), _dbid;
END IF;
-- 3. it is public
IF _is_public = FALSE THEN
RAISE EXCEPTION 'Procedure "%" is not public and cannot be foreign called', _procedure;
END IF;
-- 4. our input types match the expected input types
IF array_length(_expected_input_types, 1) != 1 THEN
RAISE EXCEPTION 'Procedure "%" expects exactly one input type, but got %', _procedure, array_length(_expected_input_types, 1);
END IF;
-- since _arg1 is text, we want to ensure that the input type is text
IF _expected_input_types[1] != 'TEXT' THEN
RAISE EXCEPTION 'Procedure "%" expects input type "TEXT", but got "%"', _procedure, _expected_input_types[1];
END IF;
IF _expected_input_types[2] != 'INT' THEN
RAISE EXCEPTION 'Procedure "%" expects input type "INT", but got "%"', _procedure, _expected_input_types[2];
END IF;
-- 5. our return types match the expected return types
IF array_length(_expected_return_types, 1) != 1 THEN
RAISE EXCEPTION 'Procedure "%" expects exactly one return type, but got %', _procedure, array_length(_expected_return_types, 1);
END IF;
-- since _out_1 is text, we want to ensure that the return type is text
IF _expected_return_types[1] != 'TEXT' THEN
RAISE EXCEPTION 'Procedure "%" expects return type "TEXT", but got "%"', _procedure, _expected_return_types[1];
END IF;
-- we now call the procedure. we prefix ds_ to the dbid, as per the Kwil rules.
EXECUTE format('SELECT * FROM ds_%I.%I($1, $2)', _dbid, _procedure) INTO _out_1 USING _arg1, _arg2;
-- or, to return a table, we can do:
-- RETURN QUERY EXECUTE format('SELECT * FROM ds_%I.%I($1, $2)', _dbid, _procedure) USING _arg1, _arg2;
END;
$$ LANGUAGE plpgsql;
*/

// This is implicitly
// coupled to the schema defined in internal/engine.execution/queries.go, and therefore is implicitly
// a circular dependency. I am unsure how to resolve this, but am punting on it for now since the structure
Expand Down Expand Up @@ -175,15 +91,16 @@ func GenerateForeignProcedure(proc *types.ForeignProcedure, pgSchema string) (st
_is_public BOOLEAN;
_returns_table BOOLEAN;
_expected_input_types TEXT[];
_expected_return_names TEXT[];
_expected_return_types TEXT[];`)

// begin block
str.WriteString("\nBEGIN")

// select the procedure info, and perform checks 1-3
str.WriteString(`
SELECT p.param_types, p.return_types, p.is_view, p.owner_only, p.public, s.owner, p.returns_table
INTO _expected_input_types, _expected_return_types, _is_view, _is_owner_only, _is_public, _schema_owner, _returns_table
SELECT p.param_types, p.return_types, p.return_names, p.is_view, p.owner_only, p.public, s.owner, p.returns_table
INTO _expected_input_types, _expected_return_types, _expected_return_names, _is_view, _is_owner_only, _is_public, _schema_owner, _returns_table
FROM kwild_internal.procedures as p INNER JOIN kwild_internal.kwil_schemas as s
ON p.schema_id = s.id
WHERE p.name = _procedure AND s.dbid = _dbid;
Expand All @@ -192,7 +109,7 @@ func GenerateForeignProcedure(proc *types.ForeignProcedure, pgSchema string) (st
RAISE EXCEPTION 'Procedure "%" not found in schema "%"', _procedure, _dbid;
END IF;
IF _is_view = FALSE AND current_setting('is_read_only') = 'on' THEN
IF _is_view = FALSE AND current_setting('transaction_read_only')::boolean = true THEN
RAISE EXCEPTION 'Non-view procedure "%" called in view-only connection', _procedure;
END IF;
Expand All @@ -213,66 +130,76 @@ func GenerateForeignProcedure(proc *types.ForeignProcedure, pgSchema string) (st
// first check the length of the array
str.WriteString(fmt.Sprintf(`
IF array_length(_expected_input_types, 1) IS NOT NULL THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects no inputs, but got procedure "%%" requires %% inputs', _procedure, array_length(_expected_input_types, 1);
RAISE EXCEPTION 'Foreign procedure definition "%s" expects no args, but procedure "%%" located at DBID "%%" requires %% arg(s)', _procedure, _dbid, array_length(_expected_input_types, 1);
END IF;
`, proc.Name))
} else {
str.WriteString(fmt.Sprintf(`
IF array_length(_expected_input_types, 1) IS NULL THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d inputs, but procedure "%%" requires no inputs', _procedure;
RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d args, but procedure "%%" located at DBID "%%" requires no args', _procedure, _dbid;
END IF;
IF array_length(_expected_input_types, 1) != %d THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d inputs, but procedure "%%" requires %% inputs', _procedure, array_length(_expected_input_types, 1);
RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d args, but procedure "%%" located at DBID "%%" requires %% arg(s)', _procedure, _dbid, array_length(_expected_input_types, 1);
END IF;`, proc.Name, len(proc.Parameters), len(proc.Parameters), proc.Name, len(proc.Parameters)))
}

// now we check that the types match
for i, in := range proc.Parameters {
str.WriteString(fmt.Sprintf(`
IF _expected_input_types[%d] != '%s' THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects input type "%s", but procedure "%%" requires %%', _procedure, _expected_input_types[%d];
RAISE EXCEPTION 'Foreign procedure definition "%s" expects arg type "%s", but procedure "%%" located at DBID "%%" requires %%', _procedure, _dbid, _expected_input_types[%d];
END IF;`, i+1, in.String(), proc.Name, in.String(), i+1))
}

// if the proc returns a table, ensure that the called procedure returns a table
if proc.Returns != nil && proc.Returns.IsTable {
// if there is an expected return, check that the return fields are the same count and type.
// If it returns a table, also check to make sure that the return names are the same.
if proc.Returns != nil {
// if foreign proc returns a table, check that the called procedure returns a table
// if foreign proc does not return a table, check that the called procedure does not return a table
if proc.Returns.IsTable {
str.WriteString(fmt.Sprintf(`
IF _returns_table = FALSE THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects a table return, but procedure "%%" located at DBID "%%" does not return a table', _procedure, _dbid;
END IF;`, proc.Name))
} else {
str.WriteString(fmt.Sprintf(`
IF _returns_table = TRUE THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects a non-table return, but procedure "%%" located at DBID "%%" returns a table', _procedure, _dbid;
END IF;`, proc.Name))
}

str.WriteString(fmt.Sprintf(`
IF _returns_table = FALSE THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects a table return, but procedure "%%" does not return a table', _procedure;
END IF;`, proc.Name))
IF array_length(_expected_return_types, 1) IS NULL THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d returns, but procedure "%%" located at DBID "%%" returns nothing', _procedure, _dbid;
END IF;
IF array_length(_expected_return_types, 1) != %d THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects %d returns, but procedure "%%" located at DBID "%%" returns %% fields', _procedure, _dbid, array_length(_expected_return_types, 1);
END IF;`, proc.Name, len(proc.Returns.Fields), len(proc.Returns.Fields), proc.Name, len(proc.Returns.Fields)))

// If a table is returned, we also need to ensure that it returns the exact correct column names and types.
// check that the return types match
for i, out := range proc.Returns.Fields {
// check the return type
str.WriteString(fmt.Sprintf(`
IF _expected_return_types[%d] != '%s' THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects return type "%s" at column position %d, but procedure "%%" requires %%', _procedure, _expected_return_types[%d];
END IF;`, i+1, out.Type.String(), proc.Name, out.Type.String(), i+1, i+1))
IF _expected_return_types[%d] != '%s' THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects return type "%s" at return position %d, but procedure "%%" located at DBID "%%" returns %%', _procedure, _dbid, _expected_return_types[%d];
END IF;`, i+1, out.Type.String(), proc.Name, out.Type.String(), i+1, i+1))

// check the return name
str.WriteString(fmt.Sprintf(`
IF _expected_return_types[%d] != '%s' THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects return name "%s" at column position %d, but procedure "%%" requires %%', _procedure, _expected_return_types[%d];
END IF;`, i+1, out.Name, proc.Name, out.Name, i+1, i+1))
// if it returns a table, check that the return names match
if proc.Returns.IsTable {
str.WriteString(fmt.Sprintf(`
IF _expected_return_names[%d] != '%s' THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects return name "%s" at return column position %d, but procedure "%%" located at DBID "%%" returns %%', _procedure, _dbid, _expected_return_names[%d];
END IF;`, i+1, out.Name, proc.Name, out.Name, i+1, i+1))
}
}

} else {
// else check that the called procedure does not return a table
// if not expecting returns, ensure that the expected return types are nil
str.WriteString(fmt.Sprintf(`
IF _returns_table = TRUE THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects a non-table return, but procedure "%%" returns a table', _procedure;
IF _expected_return_types IS NOT NULL THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects no returns, but procedure "%%" located at DBID "%%" returns non-nil value(s)', _procedure, _dbid;
END IF;`, proc.Name))

// since we are not returning a table, if we are returning anything, we need to check that the return types match.
// we do not care about the return names, since they are not tables.
if proc.Returns != nil {
for i, out := range proc.Returns.Fields {
str.WriteString(fmt.Sprintf(`
IF _expected_return_types[%d] != '%s' THEN
RAISE EXCEPTION 'Foreign procedure definition "%s" expects return type "%s" at return position %d, but procedure "%%" requires %%', _procedure, _expected_return_types[%d];
END IF;`, i+1, out.Type.String(), proc.Name, out.Type.String(), i+1, i+1))
}
}
}

// now we call the procedure.
Expand Down
101 changes: 0 additions & 101 deletions internal/engine/generate/foreign_procedure_test.go

This file was deleted.

Loading

0 comments on commit 933dca6

Please sign in to comment.