Skip to content

Commit 7ffbca6

Browse files
psteinroeCopilot
andauthored
fix: sql fn params (#366)
the idea is to replace the fn params with a default value based on their type: ```sql create or replace function users.select_no_ref (user_id int4) returns table ( first_name text ) language sql security invoker as $$ select first_name FROM users_hidden.users where id = user_id; $$; ``` will become ```sql create or replace function users.select_no_ref (user_id int4) returns table ( first_name text ) language sql security invoker as $$ select first_name FROM users_hidden.users where id = 0; -- <-- here $$; ``` ## Todo - [x] pass params to typechecker - [x] implement `apply_identifiers` fixes #353 fixes #352 --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 400715f commit 7ffbca6

File tree

18 files changed

+747
-86
lines changed

18 files changed

+747
-86
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/pgt_completions/src/context/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ impl<'a> CompletionContext<'a> {
270270
.insert(Some(WrappingClause::Select), new);
271271
}
272272
}
273+
_ => {}
273274
};
274275
}
275276
}

crates/pgt_schema_cache/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ pub use schema_cache::SchemaCache;
1919
pub use schemas::Schema;
2020
pub use tables::{ReplicaIdentity, Table};
2121
pub use triggers::{Trigger, TriggerAffected, TriggerEvent};
22+
pub use types::{PostgresType, PostgresTypeAttribute};

crates/pgt_schema_cache/src/types.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ use crate::schema_cache::SchemaCacheItem;
66

77
#[derive(Debug, Clone, Default)]
88
pub struct TypeAttributes {
9-
attrs: Vec<PostgresTypeAttribute>,
9+
pub attrs: Vec<PostgresTypeAttribute>,
1010
}
1111

1212
#[derive(Debug, Clone, Default, Deserialize)]
1313
pub struct PostgresTypeAttribute {
14-
name: String,
15-
type_id: i64,
14+
pub name: String,
15+
pub type_id: i64,
1616
}
1717

1818
impl From<Option<JsonValue>> for TypeAttributes {

crates/pgt_treesitter_queries/src/lib.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ mod tests {
7070

7171
use crate::{
7272
TreeSitterQueriesExecutor,
73-
queries::{RelationMatch, TableAliasMatch},
73+
queries::{ParameterMatch, RelationMatch, TableAliasMatch},
7474
};
7575

7676
#[test]
@@ -207,11 +207,11 @@ where
207207
select
208208
*
209209
from (
210-
select *
210+
select *
211211
from (
212212
select *
213213
from private.something
214-
) as sq2
214+
) as sq2
215215
join private.tableau pt1
216216
on sq2.id = pt1.id
217217
) as sq1
@@ -255,4 +255,33 @@ on sq1.id = pt.id;
255255
assert_eq!(results[0].get_schema(sql), Some("private".into()));
256256
assert_eq!(results[0].get_table(sql), "something");
257257
}
258+
259+
#[test]
260+
fn extracts_parameters() {
261+
let sql = r#"select v_test + fn_name.custom_type.v_test2 + $3 + custom_type.v_test3;"#;
262+
263+
let mut parser = tree_sitter::Parser::new();
264+
parser.set_language(tree_sitter_sql::language()).unwrap();
265+
266+
let tree = parser.parse(sql, None).unwrap();
267+
268+
let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);
269+
270+
executor.add_query_results::<ParameterMatch>();
271+
272+
let results: Vec<&ParameterMatch> = executor
273+
.get_iter(None)
274+
.filter_map(|q| q.try_into().ok())
275+
.collect();
276+
277+
assert_eq!(results.len(), 4);
278+
279+
assert_eq!(results[0].get_path(sql), "v_test");
280+
281+
assert_eq!(results[1].get_path(sql), "fn_name.custom_type.v_test2");
282+
283+
assert_eq!(results[2].get_path(sql), "$3");
284+
285+
assert_eq!(results[3].get_path(sql), "custom_type.v_test3");
286+
}
258287
}

crates/pgt_treesitter_queries/src/queries/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
mod parameters;
12
mod relations;
23
mod select_columns;
34
mod table_aliases;
45

6+
pub use parameters::*;
57
pub use relations::*;
68
pub use select_columns::*;
79
pub use table_aliases::*;
810

911
#[derive(Debug)]
1012
pub enum QueryResult<'a> {
1113
Relation(RelationMatch<'a>),
14+
Parameter(ParameterMatch<'a>),
1215
TableAliases(TableAliasMatch<'a>),
1316
SelectClauseColumns(SelectColumnMatch<'a>),
1417
}
@@ -26,6 +29,12 @@ impl QueryResult<'_> {
2629

2730
start >= range.start_point && end <= range.end_point
2831
}
32+
Self::Parameter(pm) => {
33+
let node_range = pm.node.range();
34+
35+
node_range.start_point >= range.start_point
36+
&& node_range.end_point <= range.end_point
37+
}
2938
QueryResult::TableAliases(m) => {
3039
let start = m.table.start_position();
3140
let end = m.alias.end_position();
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
use std::sync::LazyLock;
2+
3+
use crate::{Query, QueryResult};
4+
5+
use super::QueryTryFrom;
6+
7+
static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
8+
static QUERY_STR: &str = r#"
9+
[
10+
(field
11+
(identifier)) @reference
12+
(field
13+
(object_reference)
14+
"." (identifier)) @reference
15+
(parameter) @parameter
16+
]
17+
"#;
18+
tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query")
19+
});
20+
21+
#[derive(Debug)]
22+
pub struct ParameterMatch<'a> {
23+
pub(crate) node: tree_sitter::Node<'a>,
24+
}
25+
26+
impl ParameterMatch<'_> {
27+
pub fn get_path(&self, sql: &str) -> String {
28+
self.node
29+
.utf8_text(sql.as_bytes())
30+
.expect("Failed to get path from ParameterMatch")
31+
.to_string()
32+
}
33+
34+
pub fn get_range(&self) -> tree_sitter::Range {
35+
self.node.range()
36+
}
37+
38+
pub fn get_byte_range(&self) -> std::ops::Range<usize> {
39+
let range = self.node.range();
40+
range.start_byte..range.end_byte
41+
}
42+
}
43+
44+
impl<'a> TryFrom<&'a QueryResult<'a>> for &'a ParameterMatch<'a> {
45+
type Error = String;
46+
47+
fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
48+
match q {
49+
QueryResult::Parameter(r) => Ok(r),
50+
51+
#[allow(unreachable_patterns)]
52+
_ => Err("Invalid QueryResult type".into()),
53+
}
54+
}
55+
}
56+
57+
impl<'a> QueryTryFrom<'a> for ParameterMatch<'a> {
58+
type Ref = &'a ParameterMatch<'a>;
59+
}
60+
61+
impl<'a> Query<'a> for ParameterMatch<'a> {
62+
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
63+
let mut cursor = tree_sitter::QueryCursor::new();
64+
65+
let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());
66+
67+
matches
68+
.filter_map(|m| {
69+
let captures = m.captures;
70+
71+
// We expect exactly one capture for a parameter
72+
if captures.len() != 1 {
73+
return None;
74+
}
75+
76+
Some(QueryResult::Parameter(ParameterMatch {
77+
node: captures[0].node,
78+
}))
79+
})
80+
.collect()
81+
}
82+
}

crates/pgt_typecheck/Cargo.toml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ version = "0.0.0"
1212

1313

1414
[dependencies]
15-
pgt_console.workspace = true
16-
pgt_diagnostics.workspace = true
17-
pgt_query_ext.workspace = true
18-
pgt_schema_cache.workspace = true
19-
pgt_text_size.workspace = true
20-
sqlx.workspace = true
21-
tokio.workspace = true
22-
tree-sitter.workspace = true
23-
tree_sitter_sql.workspace = true
15+
pgt_console.workspace = true
16+
pgt_diagnostics.workspace = true
17+
pgt_query_ext.workspace = true
18+
pgt_schema_cache.workspace = true
19+
pgt_text_size.workspace = true
20+
pgt_treesitter_queries.workspace = true
21+
sqlx.workspace = true
22+
tokio.workspace = true
23+
tree-sitter.workspace = true
24+
tree_sitter_sql.workspace = true
2425

2526
[dev-dependencies]
2627
insta.workspace = true

crates/pgt_typecheck/src/diagnostics.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,21 +97,26 @@ impl Advices for TypecheckAdvices {
9797
pub(crate) fn create_type_error(
9898
pg_err: &PgDatabaseError,
9999
ts: &tree_sitter::Tree,
100+
positions_valid: bool,
100101
) -> TypecheckDiagnostic {
101102
let position = pg_err.position().and_then(|pos| match pos {
102103
sqlx::postgres::PgErrorPosition::Original(pos) => Some(pos - 1),
103104
_ => None,
104105
});
105106

106107
let range = position.and_then(|pos| {
107-
ts.root_node()
108-
.named_descendant_for_byte_range(pos, pos)
109-
.map(|node| {
110-
TextRange::new(
111-
node.start_byte().try_into().unwrap(),
112-
node.end_byte().try_into().unwrap(),
113-
)
114-
})
108+
if positions_valid {
109+
ts.root_node()
110+
.named_descendant_for_byte_range(pos, pos)
111+
.map(|node| {
112+
TextRange::new(
113+
node.start_byte().try_into().unwrap(),
114+
node.end_byte().try_into().unwrap(),
115+
)
116+
})
117+
} else {
118+
None
119+
}
115120
});
116121

117122
let severity = match pg_err.severity() {

crates/pgt_typecheck/src/lib.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
mod diagnostics;
2+
mod typed_identifier;
23

34
pub use diagnostics::TypecheckDiagnostic;
45
use diagnostics::create_type_error;
56
use pgt_text_size::TextRange;
67
use sqlx::postgres::PgDatabaseError;
78
pub use sqlx::postgres::PgSeverity;
89
use sqlx::{Executor, PgPool};
10+
use typed_identifier::apply_identifiers;
11+
pub use typed_identifier::{IdentifierType, TypedIdentifier};
912

1013
#[derive(Debug)]
1114
pub struct TypecheckParams<'a> {
1215
pub conn: &'a PgPool,
1316
pub sql: &'a str,
1417
pub ast: &'a pgt_query_ext::NodeEnum,
1518
pub tree: &'a tree_sitter::Tree,
19+
pub schema_cache: &'a pgt_schema_cache::SchemaCache,
20+
pub identifiers: Vec<TypedIdentifier>,
1621
}
1722

1823
#[derive(Debug, Clone)]
@@ -51,13 +56,24 @@ pub async fn check_sql(
5156
// each typecheck operation.
5257
conn.close_on_drop();
5358

54-
let res = conn.prepare(params.sql).await;
59+
let (prepared, positions_valid) = apply_identifiers(
60+
params.identifiers,
61+
params.schema_cache,
62+
params.tree,
63+
params.sql,
64+
);
65+
66+
let res = conn.prepare(&prepared).await;
5567

5668
match res {
5769
Ok(_) => Ok(None),
5870
Err(sqlx::Error::Database(err)) => {
5971
let pg_err = err.downcast_ref::<PgDatabaseError>();
60-
Ok(Some(create_type_error(pg_err, params.tree)))
72+
Ok(Some(create_type_error(
73+
pg_err,
74+
params.tree,
75+
positions_valid,
76+
)))
6177
}
6278
Err(err) => Err(err),
6379
}

0 commit comments

Comments
 (0)