From 00b09f4c1e148e03325583cb40a9df112be5342a Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 8 May 2025 09:25:49 +0200 Subject: [PATCH 01/18] =?UTF-8?q?laying=20the=20infra=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/pgt_completions/src/item.rs | 2 ++ crates/pgt_completions/src/providers/mod.rs | 2 ++ .../pgt_completions/src/providers/policies.rs | 35 +++++++++++++++++++ crates/pgt_completions/src/relevance.rs | 1 + .../src/relevance/filtering.rs | 3 ++ .../pgt_completions/src/relevance/scoring.rs | 5 +++ crates/pgt_lsp/src/handlers/completions.rs | 1 + crates/pgt_schema_cache/src/lib.rs | 1 + crates/pgt_schema_cache/src/policies.rs | 16 ++++----- 9 files changed, 58 insertions(+), 8 deletions(-) create mode 100644 crates/pgt_completions/src/providers/policies.rs diff --git a/crates/pgt_completions/src/item.rs b/crates/pgt_completions/src/item.rs index f37d0efb..702fc766 100644 --- a/crates/pgt_completions/src/item.rs +++ b/crates/pgt_completions/src/item.rs @@ -11,6 +11,7 @@ pub enum CompletionItemKind { Function, Column, Schema, + Policy, } impl Display for CompletionItemKind { @@ -20,6 +21,7 @@ impl Display for CompletionItemKind { CompletionItemKind::Function => "Function", CompletionItemKind::Column => "Column", CompletionItemKind::Schema => "Schema", + CompletionItemKind::Policy => "Policy", }; write!(f, "{txt}") diff --git a/crates/pgt_completions/src/providers/mod.rs b/crates/pgt_completions/src/providers/mod.rs index 82e32cdf..7b07cee8 100644 --- a/crates/pgt_completions/src/providers/mod.rs +++ b/crates/pgt_completions/src/providers/mod.rs @@ -1,10 +1,12 @@ mod columns; mod functions; mod helper; +mod policies; mod schemas; mod tables; pub use columns::*; pub use functions::*; +pub use policies::*; pub use schemas::*; pub use tables::*; diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs new file mode 100644 index 
00000000..8e2533a3 --- /dev/null +++ b/crates/pgt_completions/src/providers/policies.rs @@ -0,0 +1,35 @@ +use crate::{ + CompletionItemKind, + builder::{CompletionBuilder, PossibleCompletionItem}, + context::{CompletionContext, WrappingClause}, + relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, +}; + +use super::helper::{find_matching_alias_for_table, get_completion_text_with_schema_or_alias}; + +pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { + let available_policies = &ctx.schema_cache.policies; + + for pol in available_policies { + let relevance = CompletionRelevanceData::Policy(pol); + + let mut item = PossibleCompletionItem { + label: pol.name.clone(), + score: CompletionScore::from(relevance.clone()), + filter: CompletionFilter::from(relevance), + description: format!("Table: {}.{}", pol.schema_name, pol.table_name), + kind: CompletionItemKind::Column, + completion_text: None, + }; + + // autocomplete with the alias in a join clause if we find one + if matches!(ctx.wrapping_clause_type, Some(WrappingClause::Join { .. 
})) { + item.completion_text = find_matching_alias_for_table(ctx, pol.table_name.as_str()) + .and_then(|alias| { + get_completion_text_with_schema_or_alias(ctx, pol.name.as_str(), alias.as_str()) + }); + } + + builder.add_item(item); + } +} diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index 911a6433..f51c3c52 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -7,4 +7,5 @@ pub(crate) enum CompletionRelevanceData<'a> { Function(&'a pgt_schema_cache::Function), Column(&'a pgt_schema_cache::Column), Schema(&'a pgt_schema_cache::Schema), + Policy(&'a pgt_schema_cache::Policy), } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index ec12201c..b3299c68 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -130,6 +130,9 @@ impl CompletionFilter<'_> { // we should never allow schema suggestions if there already was one. 
false } + + // no aliases and schemas for policies + CompletionRelevanceData::Policy(_) => false, }; if !matches { diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 71c01023..d9c4d937 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -45,6 +45,7 @@ impl CompletionScore<'_> { CompletionRelevanceData::Table(t) => t.name.as_str(), CompletionRelevanceData::Column(c) => c.name.as_str(), CompletionRelevanceData::Schema(s) => s.name.as_str(), + CompletionRelevanceData::Policy(p) => p.name.as_str(), }; let fz_matcher = SkimMatcherV2::default(); @@ -116,6 +117,7 @@ impl CompletionScore<'_> { WrappingClause::Delete if !has_mentioned_schema => 15, _ => -50, }, + CompletionRelevanceData::Policy(_) => 0, } } @@ -149,6 +151,7 @@ impl CompletionScore<'_> { WrappingNode::Relation if !has_mentioned_schema && has_node_text => 0, _ => -50, }, + CompletionRelevanceData::Policy(_) => 0, } } @@ -182,6 +185,7 @@ impl CompletionScore<'_> { CompletionRelevanceData::Table(t) => t.schema.as_str(), CompletionRelevanceData::Column(c) => c.schema_name.as_str(), CompletionRelevanceData::Schema(s) => s.name.as_str(), + CompletionRelevanceData::Policy(p) => p.name.as_str(), } } @@ -189,6 +193,7 @@ impl CompletionScore<'_> { match self.data { CompletionRelevanceData::Column(c) => Some(c.table_name.as_str()), CompletionRelevanceData::Table(t) => Some(t.name.as_str()), + CompletionRelevanceData::Policy(p) => Some(p.table_name.as_str()), _ => None, } } diff --git a/crates/pgt_lsp/src/handlers/completions.rs b/crates/pgt_lsp/src/handlers/completions.rs index e1a7508c..7da4fdf2 100644 --- a/crates/pgt_lsp/src/handlers/completions.rs +++ b/crates/pgt_lsp/src/handlers/completions.rs @@ -65,5 +65,6 @@ fn to_lsp_types_completion_item_kind( pgt_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, pgt_completions::CompletionItemKind::Column => 
lsp_types::CompletionItemKind::FIELD, pgt_completions::CompletionItemKind::Schema => lsp_types::CompletionItemKind::CLASS, + pgt_completions::CompletionItemKind::Policy => lsp_types::CompletionItemKind::VALUE, } } diff --git a/crates/pgt_schema_cache/src/lib.rs b/crates/pgt_schema_cache/src/lib.rs index fc717fbe..65c6b750 100644 --- a/crates/pgt_schema_cache/src/lib.rs +++ b/crates/pgt_schema_cache/src/lib.rs @@ -13,6 +13,7 @@ mod versions; pub use columns::*; pub use functions::{Behavior, Function, FunctionArg, FunctionArgs}; +pub use policies::{Policy, PolicyCommand}; pub use schema_cache::SchemaCache; pub use schemas::Schema; pub use tables::{ReplicaIdentity, Table}; diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs index 46a3ab18..7396522c 100644 --- a/crates/pgt_schema_cache/src/policies.rs +++ b/crates/pgt_schema_cache/src/policies.rs @@ -56,14 +56,14 @@ impl From for Policy { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Policy { - name: String, - table_name: String, - schema_name: String, - is_permissive: bool, - command: PolicyCommand, - role_names: Vec, - security_qualification: Option, - with_check: Option, + pub name: String, + pub table_name: String, + pub schema_name: String, + pub is_permissive: bool, + pub command: PolicyCommand, + pub role_names: Vec, + pub security_qualification: Option, + pub with_check: Option, } impl SchemaCacheItem for Policy { From 482735c03718f07caa378ec8d70aca20bf40ff0a Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 8 May 2025 11:00:40 +0200 Subject: [PATCH 02/18] steps steps steps --- crates/pgt_completions/src/complete.rs | 5 +- .../src/{ => context}/context.rs | 572 ++++++++++-------- crates/pgt_completions/src/context/mod.rs | 4 + .../src/context/policy_parser.rs | 71 +++ .../pgt_completions/src/providers/helper.rs | 2 +- .../pgt_completions/src/providers/policies.rs | 18 +- .../src/relevance/filtering.rs | 24 +- .../pgt_completions/src/relevance/scoring.rs | 2 + 8 
files changed, 415 insertions(+), 283 deletions(-) rename crates/pgt_completions/src/{ => context}/context.rs (51%) create mode 100644 crates/pgt_completions/src/context/mod.rs create mode 100644 crates/pgt_completions/src/context/policy_parser.rs diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index 442ee546..5bc5d41c 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -4,7 +4,9 @@ use crate::{ builder::CompletionBuilder, context::CompletionContext, item::CompletionItem, - providers::{complete_columns, complete_functions, complete_schemas, complete_tables}, + providers::{ + complete_columns, complete_functions, complete_policies, complete_schemas, complete_tables, + }, sanitization::SanitizedCompletionParams, }; @@ -33,6 +35,7 @@ pub fn complete(params: CompletionParams) -> Vec { complete_functions(&ctx, &mut builder); complete_columns(&ctx, &mut builder); complete_schemas(&ctx, &mut builder); + complete_policies(&ctx, &mut builder); builder.finish() } diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context/context.rs similarity index 51% rename from crates/pgt_completions/src/context.rs rename to crates/pgt_completions/src/context/context.rs index d96d0d53..166652a2 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use pgt_schema_cache::SchemaCache; +use pgt_text_size::TextRange; use pgt_treesitter_queries::{ TreeSitterQueriesExecutor, queries::{self, QueryResult}, @@ -41,6 +42,44 @@ pub enum WrappingNode { Assignment, } +pub(crate) enum NodeUnderCursor<'a> { + TsNode(tree_sitter::Node<'a>), + CustomNode { + text: NodeText<'a>, + range: TextRange, + kind: String, + }, +} + +impl<'a> NodeUnderCursor<'a> { + pub fn start_byte(&self) -> usize { + match self { + NodeUnderCursor::TsNode(node) => node.start_byte(), + 
NodeUnderCursor::CustomNode { range, .. } => range.start().into(), + } + } + + pub fn end_byte(&self) -> usize { + match self { + NodeUnderCursor::TsNode(node) => node.end_byte(), + NodeUnderCursor::CustomNode { range, .. } => range.end().into(), + } + } + + pub fn kind(&self) -> &str { + match self { + NodeUnderCursor::TsNode(node) => node.kind(), + NodeUnderCursor::CustomNode { kind, .. } => kind.as_str(), + } + } +} + +impl<'a> From> for NodeUnderCursor<'a> { + fn from(node: tree_sitter::Node<'a>) -> Self { + NodeUnderCursor::TsNode(node) + } +} + impl TryFrom<&str> for WrappingNode { type Error = String; @@ -71,7 +110,7 @@ impl TryFrom for WrappingNode { } pub(crate) struct CompletionContext<'a> { - pub node_under_cursor: Option>, + pub node_under_cursor: Option>, pub tree: &'a tree_sitter::Tree, pub text: &'a str, @@ -130,8 +169,17 @@ impl<'a> CompletionContext<'a> { is_in_error_node: false, }; - ctx.gather_tree_context(); - ctx.gather_info_from_ts_queries(); + // policy handling is important to Supabase, but they are a PostgreSQL specific extension, + // so the tree_sitter_sql language does not support it. + // We infer the context manually. 
+ if params.text.to_lowercase().starts_with("create policy") + || params.text.to_lowercase().starts_with("alter policy") + || params.text.to_lowercase().starts_with("drop policy") + { + } else { + ctx.gather_tree_context(); + ctx.gather_info_from_ts_queries(); + } ctx } @@ -175,7 +223,7 @@ impl<'a> CompletionContext<'a> { } } - pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option> { + fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option> { let source = self.text; ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { if SanitizedCompletionParams::is_sanitized_token(txt) { @@ -187,12 +235,18 @@ impl<'a> CompletionContext<'a> { } pub fn get_node_under_cursor_content(&self) -> Option { - self.node_under_cursor - .and_then(|n| self.get_ts_node_content(n)) - .and_then(|txt| match txt { + match self.node_under_cursor.as_ref()? { + NodeUnderCursor::TsNode(node) => { + self.get_ts_node_content(node).and_then(|nt| match nt { + NodeText::Replaced => None, + NodeText::Original(c) => Some(c.to_string()), + }) + } + NodeUnderCursor::CustomNode { text, .. 
} => match text { NodeText::Replaced => None, NodeText::Original(c) => Some(c.to_string()), - }) + }, + } } fn gather_tree_context(&mut self) { @@ -230,7 +284,7 @@ impl<'a> CompletionContext<'a> { // prevent infinite recursion – this can happen if we only have a PROGRAM node if current_node_kind == parent_node_kind { - self.node_under_cursor = Some(current_node); + self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); return; } @@ -269,7 +323,7 @@ impl<'a> CompletionContext<'a> { match current_node_kind { "object_reference" | "field" => { - let content = self.get_ts_node_content(current_node); + let content = self.get_ts_node_content(¤t_node); if let Some(node_txt) = content { match node_txt { NodeText::Original(txt) => { @@ -301,7 +355,7 @@ impl<'a> CompletionContext<'a> { // We have arrived at the leaf node if current_node.child_count() == 0 { - self.node_under_cursor = Some(current_node); + self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); return; } @@ -314,7 +368,7 @@ impl<'a> CompletionContext<'a> { node: tree_sitter::Node<'a>, ) -> Option> { if node.kind().starts_with("keyword_") { - if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt { + if let Some(txt) = self.get_ts_node_content(&node).and_then(|txt| match txt { NodeText::Original(txt) => Some(txt), NodeText::Replaced => None, }) { @@ -365,259 +419,259 @@ impl<'a> CompletionContext<'a> { } } -#[cfg(test)] -mod tests { - use crate::{ - context::{CompletionContext, NodeText, WrappingClause}, - sanitization::SanitizedCompletionParams, - test_helper::{CURSOR_POS, get_text_and_position}, - }; - - fn get_tree(input: &str) -> tree_sitter::Tree { - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Couldn't set language"); - - parser.parse(input, None).expect("Unable to parse tree") - } - - #[test] - fn identifies_clauses() { - let test_cases = vec![ - ( - format!("Select {}* from users;", CURSOR_POS), 
- WrappingClause::Select, - ), - ( - format!("Select * from u{};", CURSOR_POS), - WrappingClause::From, - ), - ( - format!("Select {}* from users where n = 1;", CURSOR_POS), - WrappingClause::Select, - ), - ( - format!("Select * from users where {}n = 1;", CURSOR_POS), - WrappingClause::Where, - ), - ( - format!("update users set u{} = 1 where n = 2;", CURSOR_POS), - WrappingClause::Update, - ), - ( - format!("update users set u = 1 where n{} = 2;", CURSOR_POS), - WrappingClause::Where, - ), - ( - format!("delete{} from users;", CURSOR_POS), - WrappingClause::Delete, - ), - ( - format!("delete from {}users;", CURSOR_POS), - WrappingClause::From, - ), - ( - format!("select name, age, location from public.u{}sers", CURSOR_POS), - WrappingClause::From, - ), - ]; - - for (query, expected_clause) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - assert_eq!(ctx.wrapping_clause_type, Some(expected_clause)); - } - } - - #[test] - fn identifies_schema() { - let test_cases = vec![ - ( - format!("Select * from private.u{}", CURSOR_POS), - Some("private"), - ), - ( - format!("Select * from private.u{}sers()", CURSOR_POS), - Some("private"), - ), - (format!("Select * from u{}sers", CURSOR_POS), None), - (format!("Select * from u{}sers()", CURSOR_POS), None), - ]; - - for (query, expected_schema) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - assert_eq!( - 
ctx.schema_or_alias_name, - expected_schema.map(|f| f.to_string()) - ); - } - } - - #[test] - fn identifies_invocation() { - let test_cases = vec![ - (format!("Select * from u{}sers", CURSOR_POS), false), - (format!("Select * from u{}sers()", CURSOR_POS), true), - (format!("Select cool{};", CURSOR_POS), false), - (format!("Select cool{}();", CURSOR_POS), true), - ( - format!("Select upp{}ercase as title from users;", CURSOR_POS), - false, - ), - ( - format!("Select upp{}ercase(name) as title from users;", CURSOR_POS), - true, - ), - ]; - - for (query, is_invocation) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - assert_eq!(ctx.is_invocation, is_invocation); - } - } - - #[test] - fn does_not_fail_on_leading_whitespace() { - let cases = vec![ - format!("{} select * from", CURSOR_POS), - format!(" {} select * from", CURSOR_POS), - ]; - - for query in cases { - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - let node = ctx.node_under_cursor.unwrap(); - - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("select")) - ); - - assert_eq!( - ctx.wrapping_clause_type, - Some(crate::context::WrappingClause::Select) - ); - } - } - - #[test] - fn does_not_fail_on_trailing_whitespace() { - let query = format!("select * from {}", CURSOR_POS); - - let (position, text) = get_text_and_position(query.as_str().into()); +// #[cfg(test)] +// mod tests { +// use 
crate::{ +// context::{CompletionContext, NodeText, WrappingClause}, +// sanitization::SanitizedCompletionParams, +// test_helper::{CURSOR_POS, get_text_and_position}, +// }; + +// fn get_tree(input: &str) -> tree_sitter::Tree { +// let mut parser = tree_sitter::Parser::new(); +// parser +// .set_language(tree_sitter_sql::language()) +// .expect("Couldn't set language"); + +// parser.parse(input, None).expect("Unable to parse tree") +// } + +// #[test] +// fn identifies_clauses() { +// let test_cases = vec![ +// ( +// format!("Select {}* from users;", CURSOR_POS), +// WrappingClause::Select, +// ), +// ( +// format!("Select * from u{};", CURSOR_POS), +// WrappingClause::From, +// ), +// ( +// format!("Select {}* from users where n = 1;", CURSOR_POS), +// WrappingClause::Select, +// ), +// ( +// format!("Select * from users where {}n = 1;", CURSOR_POS), +// WrappingClause::Where, +// ), +// ( +// format!("update users set u{} = 1 where n = 2;", CURSOR_POS), +// WrappingClause::Update, +// ), +// ( +// format!("update users set u = 1 where n{} = 2;", CURSOR_POS), +// WrappingClause::Where, +// ), +// ( +// format!("delete{} from users;", CURSOR_POS), +// WrappingClause::Delete, +// ), +// ( +// format!("delete from {}users;", CURSOR_POS), +// WrappingClause::From, +// ), +// ( +// format!("select name, age, location from public.u{}sers", CURSOR_POS), +// WrappingClause::From, +// ), +// ]; + +// for (query, expected_clause) in test_cases { +// let (position, text) = get_text_and_position(query.as_str().into()); + +// let tree = get_tree(text.as_str()); + +// let params = SanitizedCompletionParams { +// position: (position as u32).into(), +// text, +// tree: std::borrow::Cow::Owned(tree), +// schema: &pgt_schema_cache::SchemaCache::default(), +// }; + +// let ctx = CompletionContext::new(¶ms); + +// assert_eq!(ctx.wrapping_clause_type, Some(expected_clause)); +// } +// } + +// #[test] +// fn identifies_schema() { +// let test_cases = vec![ +// ( +// format!("Select * 
from private.u{}", CURSOR_POS), +// Some("private"), +// ), +// ( +// format!("Select * from private.u{}sers()", CURSOR_POS), +// Some("private"), +// ), +// (format!("Select * from u{}sers", CURSOR_POS), None), +// (format!("Select * from u{}sers()", CURSOR_POS), None), +// ]; + +// for (query, expected_schema) in test_cases { +// let (position, text) = get_text_and_position(query.as_str().into()); + +// let tree = get_tree(text.as_str()); +// let params = SanitizedCompletionParams { +// position: (position as u32).into(), +// text, +// tree: std::borrow::Cow::Owned(tree), +// schema: &pgt_schema_cache::SchemaCache::default(), +// }; + +// let ctx = CompletionContext::new(¶ms); + +// assert_eq!( +// ctx.schema_or_alias_name, +// expected_schema.map(|f| f.to_string()) +// ); +// } +// } + +// #[test] +// fn identifies_invocation() { +// let test_cases = vec![ +// (format!("Select * from u{}sers", CURSOR_POS), false), +// (format!("Select * from u{}sers()", CURSOR_POS), true), +// (format!("Select cool{};", CURSOR_POS), false), +// (format!("Select cool{}();", CURSOR_POS), true), +// ( +// format!("Select upp{}ercase as title from users;", CURSOR_POS), +// false, +// ), +// ( +// format!("Select upp{}ercase(name) as title from users;", CURSOR_POS), +// true, +// ), +// ]; + +// for (query, is_invocation) in test_cases { +// let (position, text) = get_text_and_position(query.as_str().into()); + +// let tree = get_tree(text.as_str()); +// let params = SanitizedCompletionParams { +// position: (position as u32).into(), +// text, +// tree: std::borrow::Cow::Owned(tree), +// schema: &pgt_schema_cache::SchemaCache::default(), +// }; + +// let ctx = CompletionContext::new(¶ms); + +// assert_eq!(ctx.is_invocation, is_invocation); +// } +// } + +// #[test] +// fn does_not_fail_on_leading_whitespace() { +// let cases = vec![ +// format!("{} select * from", CURSOR_POS), +// format!(" {} select * from", CURSOR_POS), +// ]; + +// for query in cases { +// let (position, text) = 
get_text_and_position(query.as_str().into()); + +// let tree = get_tree(text.as_str()); + +// let params = SanitizedCompletionParams { +// position: (position as u32).into(), +// text, +// tree: std::borrow::Cow::Owned(tree), +// schema: &pgt_schema_cache::SchemaCache::default(), +// }; + +// let ctx = CompletionContext::new(¶ms); + +// let node = ctx.node_under_cursor.unwrap(); + +// assert_eq!( +// ctx.get_ts_node_content(node), +// Some(NodeText::Original("select")) +// ); + +// assert_eq!( +// ctx.wrapping_clause_type, +// Some(crate::context::WrappingClause::Select) +// ); +// } +// } + +// #[test] +// fn does_not_fail_on_trailing_whitespace() { +// let query = format!("select * from {}", CURSOR_POS); + +// let (position, text) = get_text_and_position(query.as_str().into()); + +// let tree = get_tree(text.as_str()); + +// let params = SanitizedCompletionParams { +// position: (position as u32).into(), +// text, +// tree: std::borrow::Cow::Owned(tree), +// schema: &pgt_schema_cache::SchemaCache::default(), +// }; + +// let ctx = CompletionContext::new(¶ms); + +// let node = ctx.node_under_cursor.unwrap(); + +// assert_eq!( +// ctx.get_ts_node_content(node), +// Some(NodeText::Original("from")) +// ); +// } + +// #[test] +// fn does_not_fail_with_empty_statements() { +// let query = format!("{}", CURSOR_POS); + +// let (position, text) = get_text_and_position(query.as_str().into()); + +// let tree = get_tree(text.as_str()); + +// let params = SanitizedCompletionParams { +// position: (position as u32).into(), +// text, +// tree: std::borrow::Cow::Owned(tree), +// schema: &pgt_schema_cache::SchemaCache::default(), +// }; + +// let ctx = CompletionContext::new(¶ms); - let tree = get_tree(text.as_str()); +// let node = ctx.node_under_cursor.unwrap(); - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = 
CompletionContext::new(¶ms); +// assert_eq!(ctx.get_ts_node_content(&node), Some(NodeText::Original(""))); +// assert_eq!(ctx.wrapping_clause_type, None); +// } - let node = ctx.node_under_cursor.unwrap(); +// #[test] +// fn does_not_fail_on_incomplete_keywords() { +// // Instead of autocompleting "FROM", we'll assume that the user +// // is selecting a certain column name, such as `frozen_account`. +// let query = format!("select * fro{}", CURSOR_POS); - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("from")) - ); - } +// let (position, text) = get_text_and_position(query.as_str().into()); - #[test] - fn does_not_fail_with_empty_statements() { - let query = format!("{}", CURSOR_POS); +// let tree = get_tree(text.as_str()); - let (position, text) = get_text_and_position(query.as_str().into()); +// let params = SanitizedCompletionParams { +// position: (position as u32).into(), +// text, +// tree: std::borrow::Cow::Owned(tree), +// schema: &pgt_schema_cache::SchemaCache::default(), +// }; - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - let node = ctx.node_under_cursor.unwrap(); - - assert_eq!(ctx.get_ts_node_content(node), Some(NodeText::Original(""))); - assert_eq!(ctx.wrapping_clause_type, None); - } +// let ctx = CompletionContext::new(¶ms); - #[test] - fn does_not_fail_on_incomplete_keywords() { - // Instead of autocompleting "FROM", we'll assume that the user - // is selecting a certain column name, such as `frozen_account`. 
- let query = format!("select * fro{}", CURSOR_POS); +// let node = ctx.node_under_cursor.unwrap(); - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - let node = ctx.node_under_cursor.unwrap(); - - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("fro")) - ); - assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); - } -} +// assert_eq!( +// ctx.get_ts_node_content(node), +// Some(NodeText::Original("fro")) +// ); +// assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); +// } +// } diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs new file mode 100644 index 00000000..828b6477 --- /dev/null +++ b/crates/pgt_completions/src/context/mod.rs @@ -0,0 +1,4 @@ +mod context; +mod policy_parser; + +pub use context::*; diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs new file mode 100644 index 00000000..8d700ddd --- /dev/null +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -0,0 +1,71 @@ +use std::{iter::Peekable, str::SplitAsciiWhitespace}; + +#[derive(Default)] +pub enum PolicyStmtKind { + #[default] + Create, + + Alter, + Drop, +} + +#[derive(Default)] +pub struct PolicyContext { + table_name: String, + schema_name: Option, + statement_kind: PolicyStmtKind, +} + +pub struct PolicyParser<'a> { + tokens: Peekable>, + sql: &'a str, + context: PolicyContext, +} + +impl<'a> PolicyParser<'a> { + pub(crate) fn get_context(sql: &'a str, cursor_position: usize) -> PolicyContext { + let lower_cased = sql.to_ascii_lowercase(); + + let parser = PolicyParser { + tokens: lower_cased.split_ascii_whitespace().peekable(), 
+ sql, + context: PolicyContext::default(), + }; + + parser.parse() + } + + fn parse(mut self) -> PolicyContext { + while let Some(token) = self.tokens.next() { + self.handle_token(token); + } + + self.context + } + + fn handle_token(&mut self, token: &'a str) { + match token { + "create" if self.next_matches("policy") => { + self.context.statement_kind = PolicyStmtKind::Create; + } + "alter" if self.next_matches("policy") => { + self.context.statement_kind = PolicyStmtKind::Alter; + } + "drop" if self.next_matches("policy") => { + self.context.statement_kind = PolicyStmtKind::Drop; + } + + "on" => self.table_with_schema(), + + _ => {} + } + } + + fn next_matches(&mut self, it: &str) -> bool { + self.tokens.peek().is_some_and(|c| *c == it) + } + + fn table_with_schema(&mut self) { + let token = self.tokens.next(); + } +} diff --git a/crates/pgt_completions/src/providers/helper.rs b/crates/pgt_completions/src/providers/helper.rs index 999d6b37..a6bee236 100644 --- a/crates/pgt_completions/src/providers/helper.rs +++ b/crates/pgt_completions/src/providers/helper.rs @@ -22,7 +22,7 @@ pub(crate) fn get_completion_text_with_schema_or_alias( if schema_or_alias_name == "public" || ctx.schema_or_alias_name.is_some() { None } else { - let node = ctx.node_under_cursor.unwrap(); + let node = ctx.node_under_cursor.as_ref().unwrap(); let range = TextRange::new( TextSize::try_from(node.start_byte()).unwrap(), diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index 8e2533a3..380746c7 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -1,35 +1,25 @@ use crate::{ CompletionItemKind, builder::{CompletionBuilder, PossibleCompletionItem}, - context::{CompletionContext, WrappingClause}, + context::CompletionContext, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; -use 
super::helper::{find_matching_alias_for_table, get_completion_text_with_schema_or_alias}; - pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { let available_policies = &ctx.schema_cache.policies; for pol in available_policies { let relevance = CompletionRelevanceData::Policy(pol); - let mut item = PossibleCompletionItem { + let item = PossibleCompletionItem { label: pol.name.clone(), score: CompletionScore::from(relevance.clone()), filter: CompletionFilter::from(relevance), - description: format!("Table: {}.{}", pol.schema_name, pol.table_name), - kind: CompletionItemKind::Column, + description: format!("Table: {}", pol.table_name), + kind: CompletionItemKind::Policy, completion_text: None, }; - // autocomplete with the alias in a join clause if we find one - if matches!(ctx.wrapping_clause_type, Some(WrappingClause::Join { .. })) { - item.completion_text = find_matching_alias_for_table(ctx, pol.table_name.as_str()) - .and_then(|alias| { - get_completion_text_with_schema_or_alias(ctx, pol.name.as_str(), alias.as_str()) - }); - } - builder.add_item(item); } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index b3299c68..4c9fa139 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,4 +1,4 @@ -use crate::context::{CompletionContext, WrappingClause}; +use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause}; use super::CompletionRelevanceData; @@ -24,7 +24,11 @@ impl CompletionFilter<'_> { } fn completable_context(&self, ctx: &CompletionContext) -> Option<()> { - let current_node_kind = ctx.node_under_cursor.map(|n| n.kind()).unwrap_or(""); + let current_node_kind = ctx + .node_under_cursor + .as_ref() + .map(|n| n.kind()) + .unwrap_or(""); if current_node_kind.starts_with("keyword_") || current_node_kind == "=" @@ -36,20 +40,23 @@ impl CompletionFilter<'_> { } // No 
autocompletions if there are two identifiers without a separator. - if ctx.node_under_cursor.is_some_and(|n| { - n.prev_sibling().is_some_and(|p| { + if ctx.node_under_cursor.as_ref().is_some_and(|n| match n { + NodeUnderCursor::TsNode(node) => node.prev_sibling().is_some_and(|p| { (p.kind() == "identifier" || p.kind() == "object_reference") && n.kind() == "identifier" - }) + }), + NodeUnderCursor::CustomNode { .. } => false, }) { return None; } // no completions if we're right after an asterisk: // `select * {}` - if ctx.node_under_cursor.is_some_and(|n| { - n.prev_sibling() - .is_some_and(|p| (p.kind() == "all_fields") && n.kind() == "identifier") + if ctx.node_under_cursor.as_ref().is_some_and(|n| match n { + NodeUnderCursor::TsNode(node) => node + .prev_sibling() + .is_some_and(|p| (p.kind() == "all_fields") && n.kind() == "identifier"), + NodeUnderCursor::CustomNode { .. } => false, }) { return None; } @@ -83,6 +90,7 @@ impl CompletionFilter<'_> { WrappingClause::Join { on_node: Some(on) } => ctx .node_under_cursor + .as_ref() .is_some_and(|n| n.end_byte() < on.start_byte()), _ => false, diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index d9c4d937..0b0933e5 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -82,6 +82,7 @@ impl CompletionScore<'_> { WrappingClause::Join { on_node } if on_node.is_none_or(|on| { ctx.node_under_cursor + .as_ref() .is_none_or(|n| n.end_byte() < on.start_byte()) }) => { @@ -102,6 +103,7 @@ impl CompletionScore<'_> { WrappingClause::Join { on_node } if on_node.is_some_and(|on| { ctx.node_under_cursor + .as_ref() .is_some_and(|n| n.start_byte() > on.end_byte()) }) => { From 76224096c19aaebd6057fc4a519544d2509eaeda Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 9 May 2025 10:15:15 +0200 Subject: [PATCH 03/18] =?UTF-8?q?sofar=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- crates/pgt_completions/src/context/context.rs | 546 +++++++++-------- .../src/context/policy_parser.rs | 569 +++++++++++++++++- crates/pgt_lsp/src/adapters/mod.rs | 21 + crates/pgt_text_size/src/range.rs | 33 + 4 files changed, 889 insertions(+), 280 deletions(-) diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs index 166652a2..f92f57da 100644 --- a/crates/pgt_completions/src/context/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -172,14 +172,28 @@ impl<'a> CompletionContext<'a> { // policy handling is important to Supabase, but they are a PostgreSQL specific extension, // so the tree_sitter_sql language does not support it. // We infer the context manually. - if params.text.to_lowercase().starts_with("create policy") - || params.text.to_lowercase().starts_with("alter policy") - || params.text.to_lowercase().starts_with("drop policy") - { - } else { - ctx.gather_tree_context(); - ctx.gather_info_from_ts_queries(); - } + // if params.text.to_lowercase().starts_with("create policy") + // || params.text.to_lowercase().starts_with("alter policy") + // || params.text.to_lowercase().starts_with("drop policy") + // { + // } else { + ctx.gather_tree_context(); + ctx.gather_info_from_ts_queries(); + // } + + tracing::warn!("sql: {}", ctx.text); + tracing::warn!("position: {}", ctx.position); + tracing::warn!( + "node range: {} - {}", + ctx.node_under_cursor + .as_ref() + .map(|n| n.start_byte()) + .unwrap_or(0), + ctx.node_under_cursor + .as_ref() + .map(|n| n.end_byte()) + .unwrap_or(0) + ); ctx } @@ -419,259 +433,281 @@ impl<'a> CompletionContext<'a> { } } -// #[cfg(test)] -// mod tests { -// use crate::{ -// context::{CompletionContext, NodeText, WrappingClause}, -// sanitization::SanitizedCompletionParams, -// test_helper::{CURSOR_POS, get_text_and_position}, -// }; - -// fn get_tree(input: &str) -> tree_sitter::Tree { -// let mut parser = 
tree_sitter::Parser::new(); -// parser -// .set_language(tree_sitter_sql::language()) -// .expect("Couldn't set language"); - -// parser.parse(input, None).expect("Unable to parse tree") -// } - -// #[test] -// fn identifies_clauses() { -// let test_cases = vec![ -// ( -// format!("Select {}* from users;", CURSOR_POS), -// WrappingClause::Select, -// ), -// ( -// format!("Select * from u{};", CURSOR_POS), -// WrappingClause::From, -// ), -// ( -// format!("Select {}* from users where n = 1;", CURSOR_POS), -// WrappingClause::Select, -// ), -// ( -// format!("Select * from users where {}n = 1;", CURSOR_POS), -// WrappingClause::Where, -// ), -// ( -// format!("update users set u{} = 1 where n = 2;", CURSOR_POS), -// WrappingClause::Update, -// ), -// ( -// format!("update users set u = 1 where n{} = 2;", CURSOR_POS), -// WrappingClause::Where, -// ), -// ( -// format!("delete{} from users;", CURSOR_POS), -// WrappingClause::Delete, -// ), -// ( -// format!("delete from {}users;", CURSOR_POS), -// WrappingClause::From, -// ), -// ( -// format!("select name, age, location from public.u{}sers", CURSOR_POS), -// WrappingClause::From, -// ), -// ]; - -// for (query, expected_clause) in test_cases { -// let (position, text) = get_text_and_position(query.as_str().into()); - -// let tree = get_tree(text.as_str()); - -// let params = SanitizedCompletionParams { -// position: (position as u32).into(), -// text, -// tree: std::borrow::Cow::Owned(tree), -// schema: &pgt_schema_cache::SchemaCache::default(), -// }; - -// let ctx = CompletionContext::new(¶ms); - -// assert_eq!(ctx.wrapping_clause_type, Some(expected_clause)); -// } -// } - -// #[test] -// fn identifies_schema() { -// let test_cases = vec![ -// ( -// format!("Select * from private.u{}", CURSOR_POS), -// Some("private"), -// ), -// ( -// format!("Select * from private.u{}sers()", CURSOR_POS), -// Some("private"), -// ), -// (format!("Select * from u{}sers", CURSOR_POS), None), -// (format!("Select * from 
u{}sers()", CURSOR_POS), None), -// ]; - -// for (query, expected_schema) in test_cases { -// let (position, text) = get_text_and_position(query.as_str().into()); - -// let tree = get_tree(text.as_str()); -// let params = SanitizedCompletionParams { -// position: (position as u32).into(), -// text, -// tree: std::borrow::Cow::Owned(tree), -// schema: &pgt_schema_cache::SchemaCache::default(), -// }; - -// let ctx = CompletionContext::new(¶ms); - -// assert_eq!( -// ctx.schema_or_alias_name, -// expected_schema.map(|f| f.to_string()) -// ); -// } -// } - -// #[test] -// fn identifies_invocation() { -// let test_cases = vec![ -// (format!("Select * from u{}sers", CURSOR_POS), false), -// (format!("Select * from u{}sers()", CURSOR_POS), true), -// (format!("Select cool{};", CURSOR_POS), false), -// (format!("Select cool{}();", CURSOR_POS), true), -// ( -// format!("Select upp{}ercase as title from users;", CURSOR_POS), -// false, -// ), -// ( -// format!("Select upp{}ercase(name) as title from users;", CURSOR_POS), -// true, -// ), -// ]; - -// for (query, is_invocation) in test_cases { -// let (position, text) = get_text_and_position(query.as_str().into()); - -// let tree = get_tree(text.as_str()); -// let params = SanitizedCompletionParams { -// position: (position as u32).into(), -// text, -// tree: std::borrow::Cow::Owned(tree), -// schema: &pgt_schema_cache::SchemaCache::default(), -// }; - -// let ctx = CompletionContext::new(¶ms); - -// assert_eq!(ctx.is_invocation, is_invocation); -// } -// } - -// #[test] -// fn does_not_fail_on_leading_whitespace() { -// let cases = vec![ -// format!("{} select * from", CURSOR_POS), -// format!(" {} select * from", CURSOR_POS), -// ]; - -// for query in cases { -// let (position, text) = get_text_and_position(query.as_str().into()); - -// let tree = get_tree(text.as_str()); - -// let params = SanitizedCompletionParams { -// position: (position as u32).into(), -// text, -// tree: std::borrow::Cow::Owned(tree), -// schema: 
&pgt_schema_cache::SchemaCache::default(), -// }; - -// let ctx = CompletionContext::new(¶ms); - -// let node = ctx.node_under_cursor.unwrap(); - -// assert_eq!( -// ctx.get_ts_node_content(node), -// Some(NodeText::Original("select")) -// ); - -// assert_eq!( -// ctx.wrapping_clause_type, -// Some(crate::context::WrappingClause::Select) -// ); -// } -// } - -// #[test] -// fn does_not_fail_on_trailing_whitespace() { -// let query = format!("select * from {}", CURSOR_POS); - -// let (position, text) = get_text_and_position(query.as_str().into()); - -// let tree = get_tree(text.as_str()); - -// let params = SanitizedCompletionParams { -// position: (position as u32).into(), -// text, -// tree: std::borrow::Cow::Owned(tree), -// schema: &pgt_schema_cache::SchemaCache::default(), -// }; - -// let ctx = CompletionContext::new(¶ms); - -// let node = ctx.node_under_cursor.unwrap(); - -// assert_eq!( -// ctx.get_ts_node_content(node), -// Some(NodeText::Original("from")) -// ); -// } - -// #[test] -// fn does_not_fail_with_empty_statements() { -// let query = format!("{}", CURSOR_POS); - -// let (position, text) = get_text_and_position(query.as_str().into()); - -// let tree = get_tree(text.as_str()); - -// let params = SanitizedCompletionParams { -// position: (position as u32).into(), -// text, -// tree: std::borrow::Cow::Owned(tree), -// schema: &pgt_schema_cache::SchemaCache::default(), -// }; - -// let ctx = CompletionContext::new(¶ms); +#[cfg(test)] +mod tests { + use crate::{ + context::{CompletionContext, NodeText, WrappingClause}, + sanitization::SanitizedCompletionParams, + test_helper::{CURSOR_POS, get_text_and_position}, + }; + + use super::NodeUnderCursor; + + fn get_tree(input: &str) -> tree_sitter::Tree { + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Couldn't set language"); + + parser.parse(input, None).expect("Unable to parse tree") + } + + #[test] + fn identifies_clauses() { + let 
test_cases = vec![ + ( + format!("Select {}* from users;", CURSOR_POS), + WrappingClause::Select, + ), + ( + format!("Select * from u{};", CURSOR_POS), + WrappingClause::From, + ), + ( + format!("Select {}* from users where n = 1;", CURSOR_POS), + WrappingClause::Select, + ), + ( + format!("Select * from users where {}n = 1;", CURSOR_POS), + WrappingClause::Where, + ), + ( + format!("update users set u{} = 1 where n = 2;", CURSOR_POS), + WrappingClause::Update, + ), + ( + format!("update users set u = 1 where n{} = 2;", CURSOR_POS), + WrappingClause::Where, + ), + ( + format!("delete{} from users;", CURSOR_POS), + WrappingClause::Delete, + ), + ( + format!("delete from {}users;", CURSOR_POS), + WrappingClause::From, + ), + ( + format!("select name, age, location from public.u{}sers", CURSOR_POS), + WrappingClause::From, + ), + ]; + + for (query, expected_clause) in test_cases { + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + assert_eq!(ctx.wrapping_clause_type, Some(expected_clause)); + } + } + + #[test] + fn identifies_schema() { + let test_cases = vec![ + ( + format!("Select * from private.u{}", CURSOR_POS), + Some("private"), + ), + ( + format!("Select * from private.u{}sers()", CURSOR_POS), + Some("private"), + ), + (format!("Select * from u{}sers", CURSOR_POS), None), + (format!("Select * from u{}sers()", CURSOR_POS), None), + ]; + + for (query, expected_schema) in test_cases { + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + 
}; + + let ctx = CompletionContext::new(¶ms); + + assert_eq!( + ctx.schema_or_alias_name, + expected_schema.map(|f| f.to_string()) + ); + } + } + + #[test] + fn identifies_invocation() { + let test_cases = vec![ + (format!("Select * from u{}sers", CURSOR_POS), false), + (format!("Select * from u{}sers()", CURSOR_POS), true), + (format!("Select cool{};", CURSOR_POS), false), + (format!("Select cool{}();", CURSOR_POS), true), + ( + format!("Select upp{}ercase as title from users;", CURSOR_POS), + false, + ), + ( + format!("Select upp{}ercase(name) as title from users;", CURSOR_POS), + true, + ), + ]; + + for (query, is_invocation) in test_cases { + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + assert_eq!(ctx.is_invocation, is_invocation); + } + } + + #[test] + fn does_not_fail_on_leading_whitespace() { + let cases = vec![ + format!("{} select * from", CURSOR_POS), + format!(" {} select * from", CURSOR_POS), + ]; + + for query in cases { + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.node_under_cursor.as_ref().unwrap(); + + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("select")) + ); + + assert_eq!( + ctx.wrapping_clause_type, + Some(crate::context::WrappingClause::Select) + ); + } + _ => unreachable!(), + } + } + } -// let node = ctx.node_under_cursor.unwrap(); + #[test] + fn 
does_not_fail_on_trailing_whitespace() { + let query = format!("select * from {}", CURSOR_POS); -// assert_eq!(ctx.get_ts_node_content(&node), Some(NodeText::Original(""))); -// assert_eq!(ctx.wrapping_clause_type, None); -// } + let (position, text) = get_text_and_position(query.as_str().into()); -// #[test] -// fn does_not_fail_on_incomplete_keywords() { -// // Instead of autocompleting "FROM", we'll assume that the user -// // is selecting a certain column name, such as `frozen_account`. -// let query = format!("select * fro{}", CURSOR_POS); + let tree = get_tree(text.as_str()); -// let (position, text) = get_text_and_position(query.as_str().into()); + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; -// let tree = get_tree(text.as_str()); + let ctx = CompletionContext::new(¶ms); -// let params = SanitizedCompletionParams { -// position: (position as u32).into(), -// text, -// tree: std::borrow::Cow::Owned(tree), -// schema: &pgt_schema_cache::SchemaCache::default(), -// }; + let node = ctx.node_under_cursor.as_ref().unwrap(); -// let ctx = CompletionContext::new(¶ms); + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(&node), + Some(NodeText::Original("from")) + ); + } + _ => unreachable!(), + } + } + + #[test] + fn does_not_fail_with_empty_statements() { + let query = format!("{}", CURSOR_POS); + + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); -// let node = ctx.node_under_cursor.unwrap(); + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); -// assert_eq!( -// ctx.get_ts_node_content(node), -// Some(NodeText::Original("fro")) -// ); -// 
assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); -// } -// } + let node = ctx.node_under_cursor.as_ref().unwrap(); + + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!(ctx.get_ts_node_content(&node), Some(NodeText::Original(""))); + assert_eq!(ctx.wrapping_clause_type, None); + } + _ => unreachable!(), + } + } + + #[test] + fn does_not_fail_on_incomplete_keywords() { + // Instead of autocompleting "FROM", we'll assume that the user + // is selecting a certain column name, such as `frozen_account`. + let query = format!("select * fro{}", CURSOR_POS); + + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.node_under_cursor.as_ref().unwrap(); + + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(&node), + Some(NodeText::Original("fro")) + ); + assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); + } + _ => unreachable!(), + } + } +} diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index 8d700ddd..0a117180 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -1,6 +1,8 @@ -use std::{iter::Peekable, str::SplitAsciiWhitespace}; +use std::{env::current_exe, iter::Peekable}; -#[derive(Default)] +use pgt_text_size::{TextRange, TextSize}; + +#[derive(Default, Debug, PartialEq, Eq)] pub enum PolicyStmtKind { #[default] Create, @@ -9,42 +11,185 @@ pub enum PolicyStmtKind { Drop, } -#[derive(Default)] +#[derive(Clone, Debug, PartialEq, Eq)] +struct WordWithIndex { + word: String, + start: usize, + end: usize, +} + +impl WordWithIndex { + fn 
under_cursor(&self, cursor_pos: usize) -> bool { + self.start <= cursor_pos && self.end > cursor_pos + } + + fn get_range(&self) -> TextRange { + let start: u32 = self.start.try_into().expect("Text too long"); + let end: u32 = self.end.try_into().expect("Text too long"); + TextRange::new(TextSize::from(start), TextSize::from(end)) + } +} + +fn sql_to_words(sql: &str) -> Result, String> { + let mut words = vec![]; + + let mut start: Option = None; + let mut current_word = String::new(); + let mut in_quotation_marks = false; + + for (pos, c) in sql.char_indices() { + if (c.is_ascii_whitespace() || c == ';') + && !current_word.is_empty() + && start.is_some() + && !in_quotation_marks + { + words.push(WordWithIndex { + word: current_word, + start: start.unwrap(), + end: pos, + }); + current_word = String::new(); + start = None; + } else if (c.is_ascii_whitespace() || c == ';') && current_word.is_empty() { + // do nothing + } else if c == '"' && start.is_none() { + in_quotation_marks = true; + start = Some(pos); + current_word.push(c); + } else if c == '"' && start.is_some() { + current_word.push(c); + words.push(WordWithIndex { + word: current_word, + start: start.unwrap(), + end: pos + 1, + }); + in_quotation_marks = false; + start = None; + current_word = String::new() + } else if start.is_some() { + current_word.push(c) + } else { + start = Some(pos); + current_word.push(c); + } + } + + if !current_word.is_empty() && start.is_some() { + words.push(WordWithIndex { + word: current_word, + start: start.unwrap(), + end: sql.len(), + }); + } + + if in_quotation_marks { + Err("String was not closed properly.".into()) + } else { + Ok(words) + } +} + +#[derive(Default, Debug, PartialEq, Eq)] pub struct PolicyContext { - table_name: String, + policy_name: Option, + table_name: Option, schema_name: Option, statement_kind: PolicyStmtKind, + node_text: String, + node_range: TextRange, + node_kind: String, } -pub struct PolicyParser<'a> { - tokens: Peekable>, - sql: &'a str, +pub 
struct PolicyParser { + tokens: Peekable>, + previous_token: Option, + current_token: Option, context: PolicyContext, + cursor_position: usize, } -impl<'a> PolicyParser<'a> { - pub(crate) fn get_context(sql: &'a str, cursor_position: usize) -> PolicyContext { - let lower_cased = sql.to_ascii_lowercase(); - - let parser = PolicyParser { - tokens: lower_cased.split_ascii_whitespace().peekable(), - sql, - context: PolicyContext::default(), - }; +impl PolicyParser { + pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext { + match sql_to_words(sql) { + Ok(tokens) => { + let parser = PolicyParser { + tokens: tokens.into_iter().peekable(), + context: PolicyContext::default(), + previous_token: None, + current_token: None, + cursor_position, + }; - parser.parse() + parser.parse() + } + Err(_) => PolicyContext::default(), + } } fn parse(mut self) -> PolicyContext { - while let Some(token) = self.tokens.next() { - self.handle_token(token); + while let Some(token) = self.advance() { + if token.under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else { + self.handle_token(token); + } } self.context } - fn handle_token(&mut self, token: &'a str) { - match token { + fn handle_token_under_cursor(&mut self, token: WordWithIndex) { + if self.previous_token.is_none() { + return; + } + + let previous = self.previous_token.take().unwrap(); + + match previous.word.to_ascii_lowercase().as_str() { + "policy" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "policy_name".into(); + self.context.node_text = token.word; + } + "on" => { + if token.word.contains('.') { + let mut parts = token.word.split('.'); + + let schema_name: String = parts.next().unwrap().into(); + let schema_name_len = schema_name.len(); + self.context.schema_name = Some(schema_name); + + let offset: u32 = schema_name_len.try_into().expect("Text too long"); + let range_without_schema = token + .get_range() + .checked_expand_start( + 
TextSize::new(offset + 1), // kill the dot as well + ) + .expect("Text too long"); + + self.context.node_range = range_without_schema; + self.context.node_text = parts.next().unwrap().into(); + self.context.node_kind = "policy_table".into(); + } else { + self.context.node_range = token.get_range(); + self.context.node_text = token.word; + self.context.node_kind = "policy_table".into(); + } + } + "to" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "policy_role".into(); + self.context.node_text = token.word; + } + _ => { + self.context.node_range = token.get_range(); + self.context.node_text = token.word; + } + } + } + + fn handle_token(&mut self, token: WordWithIndex) { + match token.word.to_ascii_lowercase().as_str() { "create" if self.next_matches("policy") => { self.context.statement_kind = PolicyStmtKind::Create; } @@ -54,18 +199,392 @@ impl<'a> PolicyParser<'a> { "drop" if self.next_matches("policy") => { self.context.statement_kind = PolicyStmtKind::Drop; } - "on" => self.table_with_schema(), - _ => {} + // skip the "to" so we don't parse it as the TO rolename + "rename" if self.next_matches("to") => { + self.advance(); + } + + _ => { + if self.prev_matches("policy") { + self.context.policy_name = Some(token.word); + } + } } } fn next_matches(&mut self, it: &str) -> bool { - self.tokens.peek().is_some_and(|c| *c == it) + self.tokens.peek().is_some_and(|c| c.word.as_str() == it) + } + + fn prev_matches(&self, it: &str) -> bool { + self.previous_token.as_ref().is_some_and(|t| t.word == it) + } + + fn advance(&mut self) -> Option { + self.previous_token = self.current_token.take(); + self.current_token = self.tokens.next(); + self.current_token.clone() } fn table_with_schema(&mut self) { - let token = self.tokens.next(); + self.advance().map(|token| { + if token.under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else if token.word.contains('.') { + let (schema, maybe_table) = 
self.schema_and_table_name(token); + self.context.schema_name = Some(schema); + self.context.table_name = maybe_table; + } else { + self.context.table_name = Some(token.word); + } + }); + } + + fn schema_and_table_name(&self, token: WordWithIndex) -> (String, Option) { + let mut parts = token.word.split('.'); + + ( + parts.next().unwrap().into(), + parts.next().map(|tb| tb.into()), + ) + } +} + +#[cfg(test)] +mod tests { + use pgt_text_size::{TextRange, TextSize}; + + use crate::{ + context::policy_parser::{PolicyContext, PolicyStmtKind, WordWithIndex}, + test_helper::CURSOR_POS, + }; + + use super::{PolicyParser, sql_to_words}; + + fn with_pos(query: String) -> (usize, String) { + let mut pos: Option = None; + + for (p, c) in query.char_indices() { + if c == CURSOR_POS { + pos = Some(p); + break; + } + } + + return ( + pos.expect("Please add cursor position!"), + query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + ); + } + + #[test] + fn infers_progressively() { + let (pos, query) = with_pos(format!( + r#" + create policy {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: None, + table_name: None, + schema_name: None, + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(25), TextSize::new(39)), + node_kind: "policy_name".into() + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: None, + schema_name: None, + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "".into(), + node_range: TextRange::new(TextSize::new(42), TextSize::new(56)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my 
cool policy" on {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: None, + schema_name: None, + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "policy_table".into(), + node_range: TextRange::new(TextSize::new(45), TextSize::new(59)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on auth.{} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: None, + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "policy_table".into(), + node_range: TextRange::new(TextSize::new(50), TextSize::new(64)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on auth.users + as {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "".into(), + node_range: TextRange::new(TextSize::new(72), TextSize::new(86)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on auth.users + as permissive + {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "".into(), + node_range: 
TextRange::new(TextSize::new(95), TextSize::new(109)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on auth.users + as permissive + to {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "policy_role".into(), + node_range: TextRange::new(TextSize::new(98), TextSize::new(112)), + } + ); + } + + #[test] + fn determines_on_table_node() { + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" + on {} + to all + using (true); + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some(r#""my cool policy""#.into()), + table_name: None, + schema_name: None, + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(57), TextSize::new(71)), + node_kind: "policy_table".into() + } + ) + } + + #[test] + fn determines_on_table_node_after_schema() { + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" + on auth.{} + to all + using (true); + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some(r#""my cool policy""#.into()), + table_name: None, + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(62), TextSize::new(76)), + node_kind: "policy_table".into() + } + ) + } + + #[test] + fn determines_we_are_on_column_name() { + let (pos, query) = with_pos(format!( + r#" + drop policy {} on auth.users; + "#, + CURSOR_POS + )); + + let 
context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: None, + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Drop, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(23), TextSize::new(37)), + node_kind: "policy_name".into() + } + ); + + // cursor within quotation marks. + let (pos, query) = with_pos(format!( + r#" + drop policy "{}" on auth.users; + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: None, + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Drop, + node_text: "\"REPLACED_TOKEN\"".into(), + node_range: TextRange::new(TextSize::new(23), TextSize::new(39)), + node_kind: "policy_name".into() + } + ); + } + + #[test] + fn single_quotation_mark_does_not_fail() { + let (pos, query) = with_pos(format!( + r#" + drop policy "{} on auth.users; + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!(context, PolicyContext::default()); + } + + fn to_word(word: &str, start: usize, end: usize) -> WordWithIndex { + WordWithIndex { + word: word.into(), + start, + end, + } + } + + #[test] + fn determines_positions_correctly() { + let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string(); + + let words = sql_to_words(query.as_str()).unwrap(); + + assert_eq!(words[0], to_word("create", 1, 7)); + assert_eq!(words[1], to_word("policy", 8, 14)); + assert_eq!(words[2], to_word("\"my cool pol\"", 15, 28)); + assert_eq!(words[3], to_word("on", 30, 32)); + assert_eq!(words[4], to_word("auth.users", 33, 43)); + assert_eq!(words[5], to_word("as", 45, 47)); + assert_eq!(words[6], to_word("permissive", 48, 58)); + assert_eq!(words[7], to_word("for", 
60, 63)); + assert_eq!(words[8], to_word("select", 64, 70)); + assert_eq!(words[9], to_word("to", 73, 75)); + assert_eq!(words[10], to_word("public", 78, 84)); + assert_eq!(words[11], to_word("using", 87, 92)); + assert_eq!(words[12], to_word("(true)", 93, 99)); } } diff --git a/crates/pgt_lsp/src/adapters/mod.rs b/crates/pgt_lsp/src/adapters/mod.rs index 972dd576..a5375180 100644 --- a/crates/pgt_lsp/src/adapters/mod.rs +++ b/crates/pgt_lsp/src/adapters/mod.rs @@ -158,6 +158,27 @@ mod tests { assert!(offset.is_none()); } + #[test] + fn with_tabs() { + let line_index = LineIndex::new( + r#" +select + email, + id +from auth.users u +join public.client_identities c on u.id = c.user_id; +"# + .trim(), + ); + + // on `i` of `id` in the select + // 22 because of: + // selectemail,i = 13 + // 8 spaces, 2 newlines = 23 characters + // it's zero indexed => index 22 + check_conversion!(line_index: Position { line: 2, character: 4 } => TextSize::from(22)); + } + #[test] fn unicode() { let line_index = LineIndex::new("'Jan 1, 2018 – Jan 1, 2019'"); diff --git a/crates/pgt_text_size/src/range.rs b/crates/pgt_text_size/src/range.rs index 3cfc3c96..baab91e9 100644 --- a/crates/pgt_text_size/src/range.rs +++ b/crates/pgt_text_size/src/range.rs @@ -299,6 +299,39 @@ impl TextRange { end: self.end.checked_add(offset)?, }) } + + /// Expand the range's start by the given offset. + /// The start will never exceed the range's end. 
+ /// + /// # Examples + /// + /// ```rust + /// # use pgt_text_size::*; + /// assert_eq!( + /// TextRange::new(2.into(), 12.into()).checked_expand_start(4.into()).unwrap(), + /// TextRange::new(6.into(), 12.into()), + /// ); + /// + /// assert_eq!( + /// TextRange::new(2.into(), 12.into()).checked_expand_start(12.into()).unwrap(), + /// TextRange::new(12.into(), 12.into()), + /// ); + /// ``` + #[inline] + pub fn checked_expand_start(self, offset: TextSize) -> Option { + let new_start = self.start.checked_add(offset)?; + let end = self.end; + + if new_start > end { + Some(TextRange { start: end, end }) + } else { + Some(TextRange { + start: new_start, + end, + }) + } + } + /// Subtract an offset from this range. /// /// Note that this is not appropriate for changing where a `TextRange` is From 27cca0e417bc25dfe797f02584041c398d0cf889 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 9 May 2025 10:34:52 +0200 Subject: [PATCH 04/18] =?UTF-8?q?build=20context=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/pgt_completions/src/context/context.rs | 71 ++++++++++++------- .../src/context/policy_parser.rs | 29 +++++--- crates/pgt_completions/src/sanitization.rs | 26 ++++++- 3 files changed, 89 insertions(+), 37 deletions(-) diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs index f92f57da..72be1f83 100644 --- a/crates/pgt_completions/src/context/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -7,7 +7,9 @@ use pgt_treesitter_queries::{ queries::{self, QueryResult}, }; -use crate::sanitization::SanitizedCompletionParams; +use crate::{ + NodeText, context::policy_parser::PolicyParser, sanitization::SanitizedCompletionParams, +}; #[derive(Debug, PartialEq, Eq)] pub enum WrappingClause<'a> { @@ -19,12 +21,8 @@ pub enum WrappingClause<'a> { }, Update, Delete, -} - -#[derive(PartialEq, Eq, Debug)] -pub(crate) enum NodeText<'a> { - Replaced, - 
Original(&'a str), + PolicyName, + ToRole, } /// We can map a few nodes, such as the "update" node, to actual SQL clauses. @@ -45,7 +43,7 @@ pub enum WrappingNode { pub(crate) enum NodeUnderCursor<'a> { TsNode(tree_sitter::Node<'a>), CustomNode { - text: NodeText<'a>, + text: NodeText, range: TextRange, kind: String, }, @@ -172,14 +170,35 @@ impl<'a> CompletionContext<'a> { // policy handling is important to Supabase, but they are a PostgreSQL specific extension, // so the tree_sitter_sql language does not support it. // We infer the context manually. - // if params.text.to_lowercase().starts_with("create policy") - // || params.text.to_lowercase().starts_with("alter policy") - // || params.text.to_lowercase().starts_with("drop policy") - // { - // } else { - ctx.gather_tree_context(); - ctx.gather_info_from_ts_queries(); - // } + if params.text.to_lowercase().starts_with("create policy") + || params.text.to_lowercase().starts_with("alter policy") + || params.text.to_lowercase().starts_with("drop policy") + { + let policy_context = PolicyParser::get_context(&ctx.text, ctx.position); + + ctx.node_under_cursor = Some(NodeUnderCursor::CustomNode { + text: policy_context.node_text.into(), + range: policy_context.node_range, + kind: policy_context.node_kind.clone(), + }); + + if policy_context.table_name.is_some() { + let mut new = HashSet::new(); + new.insert(policy_context.table_name.unwrap()); + ctx.mentioned_relations + .insert(policy_context.schema_name, new); + } + + ctx.wrapping_clause_type = match policy_context.node_kind.as_str() { + "policy_name" => Some(WrappingClause::PolicyName), + "policy_role" => Some(WrappingClause::ToRole), + "policy_table" => Some(WrappingClause::From), + _ => None, + }; + } else { + ctx.gather_tree_context(); + ctx.gather_info_from_ts_queries(); + } tracing::warn!("sql: {}", ctx.text); tracing::warn!("position: {}", ctx.position); @@ -237,13 +256,13 @@ impl<'a> CompletionContext<'a> { } } - fn get_ts_node_content(&self, ts_node: 
&tree_sitter::Node<'a>) -> Option> { + fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option { let source = self.text; ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { if SanitizedCompletionParams::is_sanitized_token(txt) { NodeText::Replaced } else { - NodeText::Original(txt) + NodeText::Original(txt.into()) } }) } @@ -386,7 +405,7 @@ impl<'a> CompletionContext<'a> { NodeText::Original(txt) => Some(txt), NodeText::Replaced => None, }) { - match txt { + match txt.as_str() { "where" => return Some(WrappingClause::Where), "update" => return Some(WrappingClause::Update), "select" => return Some(WrappingClause::Select), @@ -436,7 +455,8 @@ impl<'a> CompletionContext<'a> { #[cfg(test)] mod tests { use crate::{ - context::{CompletionContext, NodeText, WrappingClause}, + NodeText, + context::{CompletionContext, WrappingClause}, sanitization::SanitizedCompletionParams, test_helper::{CURSOR_POS, get_text_and_position}, }; @@ -607,7 +627,7 @@ mod tests { NodeUnderCursor::TsNode(node) => { assert_eq!( ctx.get_ts_node_content(node), - Some(NodeText::Original("select")) + Some(NodeText::Original("select".into())) ); assert_eq!( @@ -643,7 +663,7 @@ mod tests { NodeUnderCursor::TsNode(node) => { assert_eq!( ctx.get_ts_node_content(&node), - Some(NodeText::Original("from")) + Some(NodeText::Original("from".into())) ); } _ => unreachable!(), @@ -671,7 +691,10 @@ mod tests { match node { NodeUnderCursor::TsNode(node) => { - assert_eq!(ctx.get_ts_node_content(&node), Some(NodeText::Original(""))); + assert_eq!( + ctx.get_ts_node_content(&node), + Some(NodeText::Original("".into())) + ); assert_eq!(ctx.wrapping_clause_type, None); } _ => unreachable!(), @@ -703,7 +726,7 @@ mod tests { NodeUnderCursor::TsNode(node) => { assert_eq!( ctx.get_ts_node_content(&node), - Some(NodeText::Original("fro")) + Some(NodeText::Original("fro".into())) ); assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); } diff --git 
a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index 0a117180..b492fc6d 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -1,9 +1,9 @@ -use std::{env::current_exe, iter::Peekable}; +use std::iter::Peekable; use pgt_text_size::{TextRange, TextSize}; #[derive(Default, Debug, PartialEq, Eq)] -pub enum PolicyStmtKind { +pub(crate) enum PolicyStmtKind { #[default] Create, @@ -90,17 +90,17 @@ fn sql_to_words(sql: &str) -> Result, String> { } #[derive(Default, Debug, PartialEq, Eq)] -pub struct PolicyContext { - policy_name: Option, - table_name: Option, - schema_name: Option, - statement_kind: PolicyStmtKind, - node_text: String, - node_range: TextRange, - node_kind: String, +pub(crate) struct PolicyContext { + pub policy_name: Option, + pub table_name: Option, + pub schema_name: Option, + pub statement_kind: PolicyStmtKind, + pub node_text: String, + pub node_range: TextRange, + pub node_kind: String, } -pub struct PolicyParser { +pub(crate) struct PolicyParser { tokens: Peekable>, previous_token: Option, current_token: Option, @@ -110,6 +110,13 @@ pub struct PolicyParser { impl PolicyParser { pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext { + assert!( + sql.starts_with("create policy") + || sql.starts_with("drop policy") + || sql.starts_with("alter policy"), + "PolicyParser should only be used for policy statements. Developer error!" 
+ ); + match sql_to_words(sql) { Ok(tokens) => { let parser = PolicyParser { diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 248a0ffa..0f5d2b1f 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -4,6 +4,8 @@ use pgt_text_size::TextSize; use crate::CompletionParams; +static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; + pub(crate) struct SanitizedCompletionParams<'a> { pub position: TextSize, pub text: String, @@ -16,6 +18,28 @@ pub fn benchmark_sanitization(params: CompletionParams) -> String { params.text } +#[derive(PartialEq, Eq, Debug)] +pub(crate) enum NodeText { + Replaced, + Original(String), +} + +impl From<&str> for NodeText { + fn from(value: &str) -> Self { + if value == SANITIZED_TOKEN { + NodeText::Replaced + } else { + NodeText::Original(value.into()) + } + } +} + +impl From for NodeText { + fn from(value: String) -> Self { + NodeText::from(value.as_str()) + } +} + impl<'larger, 'smaller> From> for SanitizedCompletionParams<'smaller> where 'larger: 'smaller, @@ -33,8 +57,6 @@ where } } -static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; - impl<'larger, 'smaller> SanitizedCompletionParams<'smaller> where 'larger: 'smaller, From 121431da7815f11d379007992fd36685eca8db11 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 9 May 2025 15:36:15 +0200 Subject: [PATCH 05/18] ok --- crates/pgt_completions/src/context/context.rs | 38 +++++++++------- .../pgt_completions/src/providers/helper.rs | 26 ++++++++--- .../pgt_completions/src/providers/policies.rs | 43 +++++++++++++++++-- .../src/relevance/filtering.rs | 26 +++++++---- .../pgt_completions/src/relevance/scoring.rs | 7 ++- crates/pgt_completions/src/sanitization.rs | 19 +++++++- 6 files changed, 122 insertions(+), 37 deletions(-) diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs index 72be1f83..b17522e6 100644 --- 
a/crates/pgt_completions/src/context/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -8,7 +8,9 @@ use pgt_treesitter_queries::{ }; use crate::{ - NodeText, context::policy_parser::PolicyParser, sanitization::SanitizedCompletionParams, + NodeText, + context::policy_parser::{PolicyParser, PolicyStmtKind}, + sanitization::SanitizedCompletionParams, }; #[derive(Debug, PartialEq, Eq)] @@ -40,6 +42,7 @@ pub enum WrappingNode { Assignment, } +#[derive(Debug)] pub(crate) enum NodeUnderCursor<'a> { TsNode(tree_sitter::Node<'a>), CustomNode { @@ -64,6 +67,13 @@ impl<'a> NodeUnderCursor<'a> { } } + pub fn range(&self) -> TextRange { + let start: u32 = self.start_byte().try_into().unwrap(); + let end: u32 = self.end_byte().try_into().unwrap(); + + TextRange::new(start.into(), end.into()) + } + pub fn kind(&self) -> &str { match self { NodeUnderCursor::TsNode(node) => node.kind(), @@ -182,6 +192,10 @@ impl<'a> CompletionContext<'a> { kind: policy_context.node_kind.clone(), }); + if policy_context.node_kind == "policy_table" { + ctx.schema_or_alias_name = policy_context.schema_name.clone(); + } + if policy_context.table_name.is_some() { let mut new = HashSet::new(); new.insert(policy_context.table_name.unwrap()); @@ -190,7 +204,9 @@ impl<'a> CompletionContext<'a> { } ctx.wrapping_clause_type = match policy_context.node_kind.as_str() { - "policy_name" => Some(WrappingClause::PolicyName), + "policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => { + Some(WrappingClause::PolicyName) + } "policy_role" => Some(WrappingClause::ToRole), "policy_table" => Some(WrappingClause::From), _ => None, @@ -200,19 +216,11 @@ impl<'a> CompletionContext<'a> { ctx.gather_info_from_ts_queries(); } - tracing::warn!("sql: {}", ctx.text); - tracing::warn!("position: {}", ctx.position); - tracing::warn!( - "node range: {} - {}", - ctx.node_under_cursor - .as_ref() - .map(|n| n.start_byte()) - .unwrap_or(0), - ctx.node_under_cursor - .as_ref() - .map(|n| n.end_byte()) 
- .unwrap_or(0) - ); + tracing::warn!("SQL: {}", ctx.text); + tracing::warn!("Position: {}", ctx.position); + tracing::warn!("Node: {:#?}", ctx.node_under_cursor); + tracing::warn!("Relations: {:#?}", ctx.mentioned_relations); + tracing::warn!("Clause: {:#?}", ctx.wrapping_clause_type); ctx } diff --git a/crates/pgt_completions/src/providers/helper.rs b/crates/pgt_completions/src/providers/helper.rs index a6bee236..39bf5ba2 100644 --- a/crates/pgt_completions/src/providers/helper.rs +++ b/crates/pgt_completions/src/providers/helper.rs @@ -14,6 +14,25 @@ pub(crate) fn find_matching_alias_for_table( None } +pub(crate) fn get_range_to_replace(ctx: &CompletionContext) -> TextRange { + let start = ctx + .node_under_cursor + .as_ref() + .map(|n| n.start_byte()) + .unwrap_or(0); + + let end = ctx + .get_node_under_cursor_content() + .unwrap_or("".into()) + .len() + + start; + + TextRange::new( + TextSize::new(start.try_into().unwrap()), + end.try_into().unwrap(), + ) +} + pub(crate) fn get_completion_text_with_schema_or_alias( ctx: &CompletionContext, item_name: &str, @@ -22,12 +41,7 @@ pub(crate) fn get_completion_text_with_schema_or_alias( if schema_or_alias_name == "public" || ctx.schema_or_alias_name.is_some() { None } else { - let node = ctx.node_under_cursor.as_ref().unwrap(); - - let range = TextRange::new( - TextSize::try_from(node.start_byte()).unwrap(), - TextSize::try_from(node.end_byte()).unwrap(), - ); + let range = get_range_to_replace(ctx); Some(CompletionText { text: format!("{}.{}", schema_or_alias_name, item_name), diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index 380746c7..bc42beb6 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -1,10 +1,12 @@ use crate::{ - CompletionItemKind, + CompletionItemKind, CompletionText, builder::{CompletionBuilder, PossibleCompletionItem}, context::CompletionContext, 
relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; +use super::helper::get_range_to_replace; + pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { let available_policies = &ctx.schema_cache.policies; @@ -12,14 +14,47 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi let relevance = CompletionRelevanceData::Policy(pol); let item = PossibleCompletionItem { - label: pol.name.clone(), + label: pol.name.chars().take(35).collect::(), score: CompletionScore::from(relevance.clone()), filter: CompletionFilter::from(relevance), - description: format!("Table: {}", pol.table_name), + description: format!("{}", pol.table_name), kind: CompletionItemKind::Policy, - completion_text: None, + completion_text: Some(CompletionText { + text: format!("\"{}\"", pol.name), + range: get_range_to_replace(ctx), + }), }; builder.add_item(item); } } + +#[cfg(test)] +mod tests { + use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; + + #[tokio::test] + async fn completes_within_quotation_marks() { + let setup = r#" + create table users ( + id serial primary key, + email text + ); + + create policy "should never have access" on users + as restrictive + for all + to public + using (false); + "#; + + assert_complete_results( + format!("alter policy \"{}\" on users;", CURSOR_POS).as_str(), + vec![CompletionAssertion::Label( + "should never have access".into(), + )], + setup, + ) + .await; + } +} diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 4c9fa139..071e2994 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -67,18 +67,19 @@ impl CompletionFilter<'_> { fn check_clause(&self, ctx: &CompletionContext) -> Option<()> { let clause = ctx.wrapping_clause_type.as_ref(); + let in_clause = |compare: 
WrappingClause| clause.is_some_and(|c| c == &compare); + match self.data { CompletionRelevanceData::Table(_) => { - let in_select_clause = clause.is_some_and(|c| c == &WrappingClause::Select); - let in_where_clause = clause.is_some_and(|c| c == &WrappingClause::Where); - - if in_select_clause || in_where_clause { + if in_clause(WrappingClause::Select) + || in_clause(WrappingClause::Where) + || in_clause(WrappingClause::PolicyName) + { return None; }; } CompletionRelevanceData::Column(_) => { - let in_from_clause = clause.is_some_and(|c| c == &WrappingClause::From); - if in_from_clause { + if in_clause(WrappingClause::From) || in_clause(WrappingClause::PolicyName) { return None; } @@ -100,7 +101,16 @@ impl CompletionFilter<'_> { return None; } } - _ => {} + CompletionRelevanceData::Policy(_) => { + if clause.is_none_or(|c| c != &WrappingClause::PolicyName) { + return None; + } + } + _ => { + if in_clause(WrappingClause::PolicyName) { + return None; + } + } } Some(()) @@ -140,7 +150,7 @@ impl CompletionFilter<'_> { } // no aliases and schemas for policies - CompletionRelevanceData::Policy(_) => false, + CompletionRelevanceData::Policy(p) => p.schema_name == p.schema_name, }; if !matches { diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 0b0933e5..31dfd96c 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -119,7 +119,10 @@ impl CompletionScore<'_> { WrappingClause::Delete if !has_mentioned_schema => 15, _ => -50, }, - CompletionRelevanceData::Policy(_) => 0, + CompletionRelevanceData::Policy(_) => match clause_type { + WrappingClause::PolicyName => 25, + _ => -50, + }, } } @@ -187,7 +190,7 @@ impl CompletionScore<'_> { CompletionRelevanceData::Table(t) => t.schema.as_str(), CompletionRelevanceData::Column(c) => c.schema_name.as_str(), CompletionRelevanceData::Schema(s) => s.name.as_str(), - CompletionRelevanceData::Policy(p) => 
p.name.as_str(), + CompletionRelevanceData::Policy(p) => p.schema_name.as_str(), } } diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 0f5d2b1f..3e3fbebf 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -49,6 +49,7 @@ where || cursor_prepared_to_write_token_after_last_node(params.tree, params.position) || cursor_before_semicolon(params.tree, params.position) || cursor_on_a_dot(¶ms.text, params.position) + || cursor_between_double_quotes(¶ms.text, params.position) { SanitizedCompletionParams::with_adjusted_sql(params) } else { @@ -178,6 +179,12 @@ fn cursor_on_a_dot(sql: &str, position: TextSize) -> bool { sql.chars().nth(position - 1).is_some_and(|c| c == '.') } +fn cursor_between_double_quotes(sql: &str, position: TextSize) -> bool { + let position: usize = position.into(); + let mut chars = sql.chars(); + chars.nth(position - 1).is_some_and(|c| c == '"') && chars.next().is_some_and(|c| c == '"') +} + fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool { let mut cursor = tree.walk(); let mut leaf_node = tree.root_node(); @@ -227,8 +234,8 @@ mod tests { use pgt_text_size::TextSize; use crate::sanitization::{ - cursor_before_semicolon, cursor_inbetween_nodes, cursor_on_a_dot, - cursor_prepared_to_write_token_after_last_node, + cursor_before_semicolon, cursor_between_double_quotes, cursor_inbetween_nodes, + cursor_on_a_dot, cursor_prepared_to_write_token_after_last_node, }; #[test] @@ -339,4 +346,12 @@ mod tests { assert!(cursor_before_semicolon(&tree, TextSize::new(16))); assert!(cursor_before_semicolon(&tree, TextSize::new(17))); } + + #[test] + fn between_quotations() { + let input = "select * from \"\""; + + // select * from "|" <-- between quotations + assert!(cursor_between_double_quotes(input, TextSize::new(15))); + } } From bf3d96cbfea8a8e347ce23b70c3497deba3d1939 Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 12 
May 2025 17:40:18 +0200 Subject: [PATCH 06/18] arairga --- crates/pgt_completions/src/context/context.rs | 13 --- .../src/context/policy_parser.rs | 7 +- .../pgt_completions/src/providers/policies.rs | 67 ++++++++++++--- .../src/relevance/filtering.rs | 11 ++- .../pgt_completions/src/relevance/scoring.rs | 16 ++-- crates/pgt_completions/src/sanitization.rs | 83 ++++++------------- crates/pgt_lsp/src/handlers/completions.rs | 4 +- crates/pgt_workspace/src/workspace/server.rs | 7 +- 8 files changed, 110 insertions(+), 98 deletions(-) diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs index b17522e6..abf479f7 100644 --- a/crates/pgt_completions/src/context/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -67,13 +67,6 @@ impl<'a> NodeUnderCursor<'a> { } } - pub fn range(&self) -> TextRange { - let start: u32 = self.start_byte().try_into().unwrap(); - let end: u32 = self.end_byte().try_into().unwrap(); - - TextRange::new(start.into(), end.into()) - } - pub fn kind(&self) -> &str { match self { NodeUnderCursor::TsNode(node) => node.kind(), @@ -216,12 +209,6 @@ impl<'a> CompletionContext<'a> { ctx.gather_info_from_ts_queries(); } - tracing::warn!("SQL: {}", ctx.text); - tracing::warn!("Position: {}", ctx.position); - tracing::warn!("Node: {:#?}", ctx.node_under_cursor); - tracing::warn!("Relations: {:#?}", ctx.mentioned_relations); - tracing::warn!("Clause: {:#?}", ctx.wrapping_clause_type); - ctx } diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index b492fc6d..04523261 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -110,10 +110,11 @@ pub(crate) struct PolicyParser { impl PolicyParser { pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext { + let trimmed = sql.trim(); assert!( - sql.starts_with("create policy") - || 
sql.starts_with("drop policy") - || sql.starts_with("alter policy"), + trimmed.starts_with("create policy") + || trimmed.starts_with("drop policy") + || trimmed.starts_with("alter policy"), "PolicyParser should only be used for policy statements. Developer error!" ); diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index bc42beb6..81130d84 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -10,6 +10,10 @@ use super::helper::get_range_to_replace; pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { let available_policies = &ctx.schema_cache.policies; + let has_quotes = ctx + .get_node_under_cursor_content() + .is_some_and(|c| c.starts_with('"') && c.ends_with('"')); + for pol in available_policies { let relevance = CompletionRelevanceData::Policy(pol); @@ -19,10 +23,14 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi filter: CompletionFilter::from(relevance), description: format!("{}", pol.table_name), kind: CompletionItemKind::Policy, - completion_text: Some(CompletionText { - text: format!("\"{}\"", pol.name), - range: get_range_to_replace(ctx), - }), + completion_text: if !has_quotes { + Some(CompletionText { + text: format!("\"{}\"", pol.name), + range: get_range_to_replace(ctx), + }) + } else { + None + }, }; builder.add_item(item); @@ -31,30 +39,69 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi #[cfg(test)] mod tests { - use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; + use crate::{ + complete, + test_helper::{ + CURSOR_POS, CompletionAssertion, assert_complete_results, get_test_params, + test_against_connection_string, + }, + }; #[tokio::test] async fn completes_within_quotation_marks() { let setup = r#" - create table users ( + create schema private; + + create table 
private.users ( id serial primary key, email text ); - create policy "should never have access" on users + create policy "read for public users disallowed" on private.users as restrictive - for all + for select to public using (false); + + create policy "write for public users allowed" on private.users + as restrictive + for insert + to public + with check (true); "#; assert_complete_results( - format!("alter policy \"{}\" on users;", CURSOR_POS).as_str(), + format!("alter policy \"{}\" on private.users;", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("read for public users disallowed".into()), + CompletionAssertion::Label("write for public users allowed".into()), + ], + setup, + ) + .await; + + assert_complete_results( + format!("alter policy \"w{}\" on private.users;", CURSOR_POS).as_str(), vec![CompletionAssertion::Label( - "should never have access".into(), + "write for public users allowed".into(), )], setup, ) .await; } + + #[tokio::test] + async fn sb_test() { + let input = format!("alter policy \"u{}\" on public.fcm_tokens;", CURSOR_POS); + + let (tree, cache) = test_against_connection_string( + "postgresql://postgres:postgres@127.0.0.1:54322/postgres", + input.as_str().into(), + ) + .await; + + let result = complete(get_test_params(&tree, &cache, input.as_str().into())); + + println!("{:#?}", result); + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 071e2994..c625d200 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -71,9 +71,8 @@ impl CompletionFilter<'_> { match self.data { CompletionRelevanceData::Table(_) => { - if in_clause(WrappingClause::Select) - || in_clause(WrappingClause::Where) - || in_clause(WrappingClause::PolicyName) + if in_clause(WrappingClause::Select) || in_clause(WrappingClause::Where) + // || in_clause(WrappingClause::PolicyName) { return None; }; @@ -107,9 +106,9 @@ impl 
CompletionFilter<'_> { } } _ => { - if in_clause(WrappingClause::PolicyName) { - return None; - } + // if in_clause(WrappingClause::PolicyName) { + // return None; + // } } } diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 31dfd96c..4a554be3 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -36,21 +36,23 @@ impl CompletionScore<'_> { fn check_matches_query_input(&mut self, ctx: &CompletionContext) { let content = match ctx.get_node_under_cursor_content() { - Some(c) => c, + Some(c) => c.replace('"', ""), None => return, }; let name = match self.data { - CompletionRelevanceData::Function(f) => f.name.as_str(), - CompletionRelevanceData::Table(t) => t.name.as_str(), - CompletionRelevanceData::Column(c) => c.name.as_str(), - CompletionRelevanceData::Schema(s) => s.name.as_str(), - CompletionRelevanceData::Policy(p) => p.name.as_str(), + CompletionRelevanceData::Function(f) => f.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Table(t) => t.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Column(c) => c.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Schema(s) => s.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Policy(p) => p.name.as_str().to_ascii_lowercase(), }; let fz_matcher = SkimMatcherV2::default(); - if let Some(score) = fz_matcher.fuzzy_match(name, content.as_str()) { + if let Some(score) = + fz_matcher.fuzzy_match(name.as_str(), content.to_ascii_lowercase().as_str()) + { let scorei32: i32 = score .try_into() .expect("The length of the input exceeds i32 capacity"); diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 3e3fbebf..b6696219 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -45,8 +45,8 @@ where 'larger: 'smaller, { fn from(params: 
CompletionParams<'larger>) -> Self { - if cursor_inbetween_nodes(params.tree, params.position) - || cursor_prepared_to_write_token_after_last_node(params.tree, params.position) + if cursor_inbetween_nodes(¶ms.text, params.position) + || cursor_prepared_to_write_token_after_last_node(¶ms.text, params.position) || cursor_before_semicolon(params.tree, params.position) || cursor_on_a_dot(¶ms.text, params.position) || cursor_between_double_quotes(¶ms.text, params.position) @@ -125,37 +125,17 @@ where /// select |from users; -- cursor "touches" from node. returns false. /// select | from users; -- cursor is between select and from nodes. returns true. /// ``` -fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool { - let mut cursor = tree.walk(); - let mut leaf_node = tree.root_node(); - - let byte = position.into(); - - // if the cursor escapes the root node, it can't be between nodes. - if byte < leaf_node.start_byte() || byte >= leaf_node.end_byte() { - return false; - } +fn cursor_inbetween_nodes(sql: &str, position: TextSize) -> bool { + let position: usize = position.into(); + let mut chars = sql.chars(); - /* - * Get closer and closer to the leaf node, until - * a) there is no more child *for the node* or - * b) there is no more child *under the cursor*. - */ - loop { - let child_idx = cursor.goto_first_child_for_byte(position.into()); - if child_idx.is_none() { - break; - } - leaf_node = cursor.node(); - } + let previous_whitespace = chars + .nth(position - 1) + .is_some_and(|c| c.is_ascii_whitespace()); - let cursor_on_leafnode = byte >= leaf_node.start_byte() && leaf_node.end_byte() >= byte; + let current_whitespace = chars.next().is_some_and(|c| c.is_ascii_whitespace()); - /* - * The cursor is inbetween nodes if it is not within the range - * of a leaf node. 
- */ - !cursor_on_leafnode + previous_whitespace && current_whitespace } /// Checks if the cursor is positioned after the last node, @@ -166,12 +146,9 @@ fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool /// select * from| -- user still needs to type a space /// select * from | -- too far off. /// ``` -fn cursor_prepared_to_write_token_after_last_node( - tree: &tree_sitter::Tree, - position: TextSize, -) -> bool { +fn cursor_prepared_to_write_token_after_last_node(sql: &str, position: TextSize) -> bool { let cursor_pos: usize = position.into(); - cursor_pos == tree.root_node().end_byte() + 1 + cursor_pos == sql.len() + 1 } fn cursor_on_a_dot(sql: &str, position: TextSize) -> bool { @@ -243,58 +220,44 @@ mod tests { // note: two spaces between select and from. let input = "select from users;"; - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let tree = parser.parse(input, None).unwrap(); - // select | from users; <-- just right, one space after select token, one space before from - assert!(cursor_inbetween_nodes(&tree, TextSize::new(7))); + assert!(cursor_inbetween_nodes(input, TextSize::new(7))); // select| from users; <-- still on select token - assert!(!cursor_inbetween_nodes(&tree, TextSize::new(6))); + assert!(!cursor_inbetween_nodes(input, TextSize::new(6))); // select |from users; <-- already on from token - assert!(!cursor_inbetween_nodes(&tree, TextSize::new(8))); + assert!(!cursor_inbetween_nodes(input, TextSize::new(8))); // select from users;| - assert!(!cursor_inbetween_nodes(&tree, TextSize::new(19))); + assert!(!cursor_inbetween_nodes(input, TextSize::new(19))); } #[test] fn test_cursor_after_nodes() { let input = "select * from"; - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let tree = parser.parse(input, None).unwrap(); - // 
select * from| <-- still on previous token assert!(!cursor_prepared_to_write_token_after_last_node( - &tree, + input, TextSize::new(13) )); // select * from | <-- too far off, two spaces afterward assert!(!cursor_prepared_to_write_token_after_last_node( - &tree, + input, TextSize::new(15) )); // select * |from <-- it's within assert!(!cursor_prepared_to_write_token_after_last_node( - &tree, + input, TextSize::new(9) )); // select * from | <-- just right assert!(cursor_prepared_to_write_token_after_last_node( - &tree, + input, TextSize::new(14) )); } @@ -353,5 +316,11 @@ mod tests { // select * from "|" <-- between quotations assert!(cursor_between_double_quotes(input, TextSize::new(15))); + + // select * from "r|" <-- between quotations, but there's + // a letter inside + let input = "select * from \"r\""; + + assert!(!cursor_between_double_quotes(input, TextSize::new(16))); } } diff --git a/crates/pgt_lsp/src/handlers/completions.rs b/crates/pgt_lsp/src/handlers/completions.rs index 7da4fdf2..33d8ab1d 100644 --- a/crates/pgt_lsp/src/handlers/completions.rs +++ b/crates/pgt_lsp/src/handlers/completions.rs @@ -54,6 +54,8 @@ pub fn get_completions( }) .collect(); + tracing::warn!("{:#?}", items); + Ok(lsp_types::CompletionResponse::Array(items)) } @@ -65,6 +67,6 @@ fn to_lsp_types_completion_item_kind( pgt_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, pgt_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, pgt_completions::CompletionItemKind::Schema => lsp_types::CompletionItemKind::CLASS, - pgt_completions::CompletionItemKind::Policy => lsp_types::CompletionItemKind::VALUE, + pgt_completions::CompletionItemKind::Policy => lsp_types::CompletionItemKind::CONSTANT, } } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 5a7bfc44..bd71c4f2 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ 
-481,6 +481,7 @@ impl Workspace for WorkspaceServer { Some(pool) => pool, None => { tracing::debug!("No connection to database. Skipping completions."); + tracing::warn!("No connection to database."); return Ok(CompletionsResult::default()); } }; @@ -488,8 +489,12 @@ impl Workspace for WorkspaceServer { let schema_cache = self.schema_cache.load(pool)?; match get_statement_for_completions(&parsed_doc, params.position) { - None => Ok(CompletionsResult::default()), + None => { + tracing::warn!("No statement found."); + Ok(CompletionsResult::default()) + } Some((_id, range, content, cst)) => { + tracing::warn!("found matching statement, content: {}", content); let position = params.position - range.start(); let items = pgt_completions::complete(pgt_completions::CompletionParams { From 5dd2fdefe98028299569f9ab44ac892c4cdc07df Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 12 May 2025 18:43:05 +0200 Subject: [PATCH 07/18] it works! --- crates/pgt_completions/src/context/context.rs | 2 ++ .../pgt_completions/src/providers/helper.rs | 26 ++++++++----------- .../pgt_completions/src/providers/policies.rs | 19 +++++++++++--- crates/pgt_completions/src/sanitization.rs | 5 ++++ crates/pgt_lsp/src/handlers/completions.rs | 2 -- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs index abf479f7..9b1b475e 100644 --- a/crates/pgt_completions/src/context/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -209,6 +209,8 @@ impl<'a> CompletionContext<'a> { ctx.gather_info_from_ts_queries(); } + tracing::warn!("{:#?}", ctx.get_node_under_cursor_content()); + ctx } diff --git a/crates/pgt_completions/src/providers/helper.rs b/crates/pgt_completions/src/providers/helper.rs index 39bf5ba2..eacb8314 100644 --- a/crates/pgt_completions/src/providers/helper.rs +++ b/crates/pgt_completions/src/providers/helper.rs @@ -1,6 +1,6 @@ use pgt_text_size::{TextRange, TextSize}; 
-use crate::{CompletionText, context::CompletionContext}; +use crate::{CompletionText, context::CompletionContext, remove_sanitized_token}; pub(crate) fn find_matching_alias_for_table( ctx: &CompletionContext, @@ -15,22 +15,18 @@ pub(crate) fn find_matching_alias_for_table( } pub(crate) fn get_range_to_replace(ctx: &CompletionContext) -> TextRange { - let start = ctx - .node_under_cursor - .as_ref() - .map(|n| n.start_byte()) - .unwrap_or(0); + match ctx.node_under_cursor.as_ref() { + Some(node) => { + let content = ctx.get_node_under_cursor_content().unwrap_or("".into()); + let length = remove_sanitized_token(content.as_str()).len(); - let end = ctx - .get_node_under_cursor_content() - .unwrap_or("".into()) - .len() - + start; + let start = node.start_byte(); + let end = start + length; - TextRange::new( - TextSize::new(start.try_into().unwrap()), - end.try_into().unwrap(), - ) + TextRange::new(start.try_into().unwrap(), end.try_into().unwrap()) + } + None => TextRange::empty(TextSize::new(0)), + } } pub(crate) fn get_completion_text_with_schema_or_alias( diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index 81130d84..7e6544b2 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -1,3 +1,5 @@ +use pgt_text_size::{TextRange, TextSize}; + use crate::{ CompletionItemKind, CompletionText, builder::{CompletionBuilder, PossibleCompletionItem}, @@ -10,9 +12,9 @@ use super::helper::get_range_to_replace; pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { let available_policies = &ctx.schema_cache.policies; - let has_quotes = ctx + let surrounded_by_quotes = ctx .get_node_under_cursor_content() - .is_some_and(|c| c.starts_with('"') && c.ends_with('"')); + .is_some_and(|c| c.starts_with('"') && c.ends_with('"') && c != "\"\""); for pol in available_policies { let relevance = 
CompletionRelevanceData::Policy(pol); @@ -23,13 +25,22 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi filter: CompletionFilter::from(relevance), description: format!("{}", pol.table_name), kind: CompletionItemKind::Policy, - completion_text: if !has_quotes { + completion_text: if !surrounded_by_quotes { Some(CompletionText { text: format!("\"{}\"", pol.name), range: get_range_to_replace(ctx), }) } else { - None + let range = get_range_to_replace(ctx); + Some(CompletionText { + text: pol.name.clone(), + + // trim the quotes. + range: TextRange::new( + range.start() + TextSize::new(1), + range.end() - TextSize::new(1), + ), + }) }, }; diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index b6696219..1adf9d95 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -18,6 +18,10 @@ pub fn benchmark_sanitization(params: CompletionParams) -> String { params.text } +pub(crate) fn remove_sanitized_token(it: &str) -> String { + it.replace(SANITIZED_TOKEN, "") +} + #[derive(PartialEq, Eq, Debug)] pub(crate) enum NodeText { Replaced, @@ -157,6 +161,7 @@ fn cursor_on_a_dot(sql: &str, position: TextSize) -> bool { } fn cursor_between_double_quotes(sql: &str, position: TextSize) -> bool { + return false; let position: usize = position.into(); let mut chars = sql.chars(); chars.nth(position - 1).is_some_and(|c| c == '"') && chars.next().is_some_and(|c| c == '"') diff --git a/crates/pgt_lsp/src/handlers/completions.rs b/crates/pgt_lsp/src/handlers/completions.rs index 33d8ab1d..ee13b26e 100644 --- a/crates/pgt_lsp/src/handlers/completions.rs +++ b/crates/pgt_lsp/src/handlers/completions.rs @@ -54,8 +54,6 @@ pub fn get_completions( }) .collect(); - tracing::warn!("{:#?}", items); - Ok(lsp_types::CompletionResponse::Array(items)) } From d29d16eda731e29f8dd1ebf874a1c1927122caf3 Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 12 May 2025 18:48:12 
+0200 Subject: [PATCH 08/18] cool --- crates/pgt_completions/src/context/context.rs | 4 ++-- crates/pgt_workspace/src/workspace/server.rs | 12 ++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs index 9b1b475e..ea9df3f4 100644 --- a/crates/pgt_completions/src/context/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -24,7 +24,7 @@ pub enum WrappingClause<'a> { Update, Delete, PolicyName, - ToRole, + ToRoleAssignment, } /// We can map a few nodes, such as the "update" node, to actual SQL clauses. @@ -200,7 +200,7 @@ impl<'a> CompletionContext<'a> { "policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => { Some(WrappingClause::PolicyName) } - "policy_role" => Some(WrappingClause::ToRole), + "policy_role" => Some(WrappingClause::ToRoleAssignment), "policy_table" => Some(WrappingClause::From), _ => None, }; diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index bd71c4f2..2c0f2b75 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -481,7 +481,6 @@ impl Workspace for WorkspaceServer { Some(pool) => pool, None => { tracing::debug!("No connection to database. 
Skipping completions."); - tracing::warn!("No connection to database."); return Ok(CompletionsResult::default()); } }; @@ -490,11 +489,10 @@ impl Workspace for WorkspaceServer { match get_statement_for_completions(&parsed_doc, params.position) { None => { - tracing::warn!("No statement found."); + tracing::debug!("No statement found."); Ok(CompletionsResult::default()) } - Some((_id, range, content, cst)) => { - tracing::warn!("found matching statement, content: {}", content); + Some((id, range, content, cst)) => { let position = params.position - range.start(); let items = pgt_completions::complete(pgt_completions::CompletionParams { @@ -504,6 +502,12 @@ impl Workspace for WorkspaceServer { text: content, }); + tracing::debug!( + "Found {} completion items for statement with id {}", + items.len(), + id.raw() + ); + Ok(CompletionsResult { items }) } } From 809aa84249b7c0de2be042e42c2598ee0ef2b990 Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 12 May 2025 18:56:07 +0200 Subject: [PATCH 09/18] unnecessary --- crates/pgt_completions/src/sanitization.rs | 26 ++-------------------- 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 1adf9d95..6aa75a16 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -53,7 +53,6 @@ where || cursor_prepared_to_write_token_after_last_node(&params.text, params.position) || cursor_before_semicolon(params.tree, params.position) || cursor_on_a_dot(&params.text, params.position) - || cursor_between_double_quotes(&params.text, params.position) { SanitizedCompletionParams::with_adjusted_sql(params) } else { @@ -160,13 +159,6 @@ fn cursor_on_a_dot(sql: &str, position: TextSize) -> bool { sql.chars().nth(position - 1).is_some_and(|c| c == '.') } -fn cursor_between_double_quotes(sql: &str, position: TextSize) -> bool { - return false; - let position: usize = position.into(); - let mut chars =
sql.chars(); - chars.nth(position - 1).is_some_and(|c| c == '"') && chars.next().is_some_and(|c| c == '"') -} - fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool { let mut cursor = tree.walk(); let mut leaf_node = tree.root_node(); @@ -216,8 +208,8 @@ mod tests { use pgt_text_size::TextSize; use crate::sanitization::{ - cursor_before_semicolon, cursor_between_double_quotes, cursor_inbetween_nodes, - cursor_on_a_dot, cursor_prepared_to_write_token_after_last_node, + cursor_before_semicolon, cursor_inbetween_nodes, cursor_on_a_dot, + cursor_prepared_to_write_token_after_last_node, }; #[test] @@ -314,18 +306,4 @@ mod tests { assert!(cursor_before_semicolon(&tree, TextSize::new(16))); assert!(cursor_before_semicolon(&tree, TextSize::new(17))); } - - #[test] - fn between_quotations() { - let input = "select * from \"\""; - - // select * from "|" <-- between quotations - assert!(cursor_between_double_quotes(input, TextSize::new(15))); - - // select * from "r|" <-- between quotations, but there's - // a letter inside - let input = "select * from \"r\""; - - assert!(!cursor_between_double_quotes(input, TextSize::new(16))); - } } From 1950b4cda491f666a81a5b6f0abe8526452e2e20 Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 12 May 2025 19:01:59 +0200 Subject: [PATCH 10/18] better --- .../pgt_completions/src/providers/policies.rs | 63 +++++++------------ .../src/relevance/filtering.rs | 11 ++-- 2 files changed, 30 insertions(+), 44 deletions(-) diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index 7e6544b2..488affdc 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -17,6 +17,28 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi .is_some_and(|c| c.starts_with('"') && c.ends_with('"') && c != "\"\""); for pol in available_policies { + let completion_text = if surrounded_by_quotes { 
+ // If we're within quotes, we want to change the content + // *within* the quotes. + // If we attempt to replace outside the quotes, the VSCode + // client won't show the suggestions. + let range = get_range_to_replace(ctx); + Some(CompletionText { + text: pol.name.clone(), + range: TextRange::new( + range.start() + TextSize::new(1), + range.end() - TextSize::new(1), + ), + }) + } else { + // If we aren't within quotes, we want to complete the + // full policy including quotation marks. + Some(CompletionText { + text: format!("\"{}\"", pol.name), + range: get_range_to_replace(ctx), + }) + }; + let relevance = CompletionRelevanceData::Policy(pol); let item = PossibleCompletionItem { @@ -25,23 +47,7 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi filter: CompletionFilter::from(relevance), description: format!("{}", pol.table_name), kind: CompletionItemKind::Policy, - completion_text: if !surrounded_by_quotes { - Some(CompletionText { - text: format!("\"{}\"", pol.name), - range: get_range_to_replace(ctx), - }) - } else { - let range = get_range_to_replace(ctx); - Some(CompletionText { - text: pol.name.clone(), - - // trim the quotes. 
- range: TextRange::new( - range.start() + TextSize::new(1), - range.end() - TextSize::new(1), - ), - }) - }, + completion_text, }; builder.add_item(item); @@ -50,13 +56,7 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi #[cfg(test)] mod tests { - use crate::{ - complete, - test_helper::{ - CURSOR_POS, CompletionAssertion, assert_complete_results, get_test_params, - test_against_connection_string, - }, - }; + use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; #[tokio::test] async fn completes_within_quotation_marks() { @@ -100,19 +100,4 @@ mod tests { ) .await; } - - #[tokio::test] - async fn sb_test() { - let input = format!("alter policy \"u{}\" on public.fcm_tokens;", CURSOR_POS); - - let (tree, cache) = test_against_connection_string( - "postgresql://postgres:postgres@127.0.0.1:54322/postgres", - input.as_str().into(), - ) - .await; - - let result = complete(get_test_params(&tree, &cache, input.as_str().into())); - - println!("{:#?}", result); - } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index c625d200..071e2994 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -71,8 +71,9 @@ impl CompletionFilter<'_> { match self.data { CompletionRelevanceData::Table(_) => { - if in_clause(WrappingClause::Select) || in_clause(WrappingClause::Where) - // || in_clause(WrappingClause::PolicyName) + if in_clause(WrappingClause::Select) + || in_clause(WrappingClause::Where) + || in_clause(WrappingClause::PolicyName) { return None; }; @@ -106,9 +107,9 @@ impl CompletionFilter<'_> { } } _ => { - // if in_clause(WrappingClause::PolicyName) { - // return None; - // } + if in_clause(WrappingClause::PolicyName) { + return None; + } } } From 5c12f8cae7ce62e36bf893c54ccc916691e47f07 Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 12 May 2025 19:20:59 +0200 Subject: [PATCH 
11/18] add comments to parser --- .../src/context/policy_parser.rs | 78 +++++++++++-------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index 04523261..c8e67791 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -19,7 +19,7 @@ struct WordWithIndex { } impl WordWithIndex { - fn under_cursor(&self, cursor_pos: usize) -> bool { + fn is_under_cursor(&self, cursor_pos: usize) -> bool { self.start <= cursor_pos && self.end > cursor_pos } @@ -30,54 +30,58 @@ impl WordWithIndex { } } +/// Note: A policy name within quotation marks will be considered a single word. fn sql_to_words(sql: &str) -> Result, String> { let mut words = vec![]; - let mut start: Option = None; + let mut start_of_word: Option = None; let mut current_word = String::new(); let mut in_quotation_marks = false; - for (pos, c) in sql.char_indices() { - if (c.is_ascii_whitespace() || c == ';') + for (current_position, current_char) in sql.char_indices() { + if (current_char.is_ascii_whitespace() || current_char == ';') && !current_word.is_empty() - && start.is_some() + && start_of_word.is_some() && !in_quotation_marks { words.push(WordWithIndex { word: current_word, - start: start.unwrap(), - end: pos, + start: start_of_word.unwrap(), + end: current_position, }); + current_word = String::new(); - start = None; - } else if (c.is_ascii_whitespace() || c == ';') && current_word.is_empty() { + start_of_word = None; + } else if (current_char.is_ascii_whitespace() || current_char == ';') + && current_word.is_empty() + { // do nothing - } else if c == '"' && start.is_none() { + } else if current_char == '"' && start_of_word.is_none() { in_quotation_marks = true; - start = Some(pos); - current_word.push(c); - } else if c == '"' && start.is_some() { - current_word.push(c); + current_word.push(current_char); + 
start_of_word = Some(current_position); + } else if current_char == '"' && start_of_word.is_some() { + current_word.push(current_char); words.push(WordWithIndex { word: current_word, - start: start.unwrap(), - end: pos + 1, + start: start_of_word.unwrap(), + end: current_position + 1, }); in_quotation_marks = false; - start = None; + start_of_word = None; current_word = String::new() - } else if start.is_some() { - current_word.push(c) + } else if start_of_word.is_some() { + current_word.push(current_char) } else { - start = Some(pos); - current_word.push(c); + start_of_word = Some(current_position); + current_word.push(current_char); } } - if !current_word.is_empty() && start.is_some() { + if !current_word.is_empty() && start_of_word.is_some() { words.push(WordWithIndex { word: current_word, - start: start.unwrap(), + start: start_of_word.unwrap(), end: sql.len(), }); } @@ -100,6 +104,10 @@ pub(crate) struct PolicyContext { pub node_kind: String, } +/// Simple parser that'll turn a policy-related statement into a context object required for +/// completions. +/// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. +/// It can only parse policy statements. 
pub(crate) struct PolicyParser { tokens: Peekable>, previous_token: Option, @@ -136,7 +144,7 @@ impl PolicyParser { fn parse(mut self) -> PolicyContext { while let Some(token) = self.advance() { - if token.under_cursor(self.cursor_position) { + if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); } else { self.handle_token(token); @@ -161,9 +169,8 @@ impl PolicyParser { } "on" => { if token.word.contains('.') { - let mut parts = token.word.split('.'); + let (schema_name, table_name) = self.schema_and_table_name(&token); - let schema_name: String = parts.next().unwrap().into(); let schema_name_len = schema_name.len(); self.context.schema_name = Some(schema_name); @@ -176,8 +183,16 @@ impl PolicyParser { .expect("Text too long"); self.context.node_range = range_without_schema; - self.context.node_text = parts.next().unwrap().into(); self.context.node_kind = "policy_table".into(); + + self.context.node_text = match table_name { + Some(node_text) => node_text, + + // In practice, this should never happen. + // The completion sanitization will add a word after a `.` if nothing follows it; + // the token_text will then look like `schema.REPLACED_TOKEN`. + None => String::new(), + }; } else { self.context.node_range = token.get_range(); self.context.node_text = token.word; @@ -209,7 +224,7 @@ impl PolicyParser { } "on" => self.table_with_schema(), - // skip the "to" so we don't parse it as the TO rolename + // skip the "to" so we don't parse it as the TO rolename when it's under the cursor "rename" if self.next_matches("to") => { self.advance(); } @@ -231,6 +246,7 @@ impl PolicyParser { } fn advance(&mut self) -> Option { + // we can't peek back n an iterator, so we'll have to keep track manually. 
self.previous_token = self.current_token.take(); self.current_token = self.tokens.next(); self.current_token.clone() @@ -238,10 +254,10 @@ impl PolicyParser { fn table_with_schema(&mut self) { self.advance().map(|token| { - if token.under_cursor(self.cursor_position) { + if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); } else if token.word.contains('.') { - let (schema, maybe_table) = self.schema_and_table_name(token); + let (schema, maybe_table) = self.schema_and_table_name(&token); self.context.schema_name = Some(schema); self.context.table_name = maybe_table; } else { @@ -250,7 +266,7 @@ impl PolicyParser { }); } - fn schema_and_table_name(&self, token: WordWithIndex) -> (String, Option) { + fn schema_and_table_name(&self, token: &WordWithIndex) -> (String, Option) { let mut parts = token.word.split('.'); ( From f3d3705206e5a4160067691323a431e8cd90b418 Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 12 May 2025 19:29:18 +0200 Subject: [PATCH 12/18] refactorio --- crates/pgt_completions/src/context/context.rs | 65 +++++++++---------- .../src/context/policy_parser.rs | 13 ++-- 2 files changed, 41 insertions(+), 37 deletions(-) diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs index ea9df3f4..118616ea 100644 --- a/crates/pgt_completions/src/context/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -173,47 +173,46 @@ impl<'a> CompletionContext<'a> { // policy handling is important to Supabase, but they are a PostgreSQL specific extension, // so the tree_sitter_sql language does not support it. // We infer the context manually. 
- if params.text.to_lowercase().starts_with("create policy") - || params.text.to_lowercase().starts_with("alter policy") - || params.text.to_lowercase().starts_with("drop policy") - { - let policy_context = PolicyParser::get_context(&ctx.text, ctx.position); - - ctx.node_under_cursor = Some(NodeUnderCursor::CustomNode { - text: policy_context.node_text.into(), - range: policy_context.node_range, - kind: policy_context.node_kind.clone(), - }); - - if policy_context.node_kind == "policy_table" { - ctx.schema_or_alias_name = policy_context.schema_name.clone(); - } - - if policy_context.table_name.is_some() { - let mut new = HashSet::new(); - new.insert(policy_context.table_name.unwrap()); - ctx.mentioned_relations - .insert(policy_context.schema_name, new); - } - - ctx.wrapping_clause_type = match policy_context.node_kind.as_str() { - "policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => { - Some(WrappingClause::PolicyName) - } - "policy_role" => Some(WrappingClause::ToRoleAssignment), - "policy_table" => Some(WrappingClause::From), - _ => None, - }; + if PolicyParser::looks_like_policy_stmt(&params.text) { + ctx.gather_policy_context(); } else { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); } - tracing::warn!("{:#?}", ctx.get_node_under_cursor_content()); - ctx } + fn gather_policy_context(&mut self) { + let policy_context = PolicyParser::get_context(&self.text, self.position); + + self.node_under_cursor = Some(NodeUnderCursor::CustomNode { + text: policy_context.node_text.into(), + range: policy_context.node_range, + kind: policy_context.node_kind.clone(), + }); + + if policy_context.node_kind == "policy_table" { + self.schema_or_alias_name = policy_context.schema_name.clone(); + } + + if policy_context.table_name.is_some() { + let mut new = HashSet::new(); + new.insert(policy_context.table_name.unwrap()); + self.mentioned_relations + .insert(policy_context.schema_name, new); + } + + self.wrapping_clause_type = match
policy_context.node_kind.as_str() { + "policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => { + Some(WrappingClause::PolicyName) + } + "policy_role" => Some(WrappingClause::ToRoleAssignment), + "policy_table" => Some(WrappingClause::From), + _ => None, + }; + } + fn gather_info_from_ts_queries(&mut self) { let stmt_range = self.wrapping_statement_range.as_ref(); let sql = self.text; diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index c8e67791..6ee77c77 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -117,12 +117,17 @@ pub(crate) struct PolicyParser { } impl PolicyParser { + pub(crate) fn looks_like_policy_stmt(sql: &str) -> bool { + let lowercased = sql.to_ascii_lowercase(); + let trimmed = lowercased.trim(); + trimmed.starts_with("create policy") + || trimmed.starts_with("drop policy") + || trimmed.starts_with("alter policy") + } + pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext { - let trimmed = sql.trim(); assert!( - trimmed.starts_with("create policy") - || trimmed.starts_with("drop policy") - || trimmed.starts_with("alter policy"), + Self::looks_like_policy_stmt(sql), "PolicyParser should only be used for policy statements. Developer error!" 
); From 3d04a598006f5e5ee4c25e277f730951fe986d19 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 13 May 2025 09:43:20 +0200 Subject: [PATCH 13/18] lint fixes --- crates/pgt_completions/src/context/context.rs | 10 +++++----- crates/pgt_completions/src/context/policy_parser.rs | 4 ++-- crates/pgt_completions/src/providers/policies.rs | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs index 118616ea..8e0c71fd 100644 --- a/crates/pgt_completions/src/context/context.rs +++ b/crates/pgt_completions/src/context/context.rs @@ -52,7 +52,7 @@ pub(crate) enum NodeUnderCursor<'a> { }, } -impl<'a> NodeUnderCursor<'a> { +impl NodeUnderCursor<'_> { pub fn start_byte(&self) -> usize { match self { NodeUnderCursor::TsNode(node) => node.start_byte(), @@ -184,7 +184,7 @@ impl<'a> CompletionContext<'a> { } fn gather_policy_context(&mut self) { - let policy_context = PolicyParser::get_context(&self.text, self.position); + let policy_context = PolicyParser::get_context(self.text, self.position); self.node_under_cursor = Some(NodeUnderCursor::CustomNode { text: policy_context.node_text.into(), @@ -658,7 +658,7 @@ mod tests { match node { NodeUnderCursor::TsNode(node) => { assert_eq!( - ctx.get_ts_node_content(&node), + ctx.get_ts_node_content(node), Some(NodeText::Original("from".into())) ); } @@ -688,7 +688,7 @@ mod tests { match node { NodeUnderCursor::TsNode(node) => { assert_eq!( - ctx.get_ts_node_content(&node), + ctx.get_ts_node_content(node), Some(NodeText::Original("".into())) ); assert_eq!(ctx.wrapping_clause_type, None); @@ -721,7 +721,7 @@ mod tests { match node { NodeUnderCursor::TsNode(node) => { assert_eq!( - ctx.get_ts_node_content(&node), + ctx.get_ts_node_content(node), Some(NodeText::Original("fro".into())) ); assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); diff --git a/crates/pgt_completions/src/context/policy_parser.rs 
b/crates/pgt_completions/src/context/policy_parser.rs index 6ee77c77..600a08a6 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -302,10 +302,10 @@ mod tests { } } - return ( + ( pos.expect("Please add cursor position!"), query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), - ); + ) } #[test] diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index 488affdc..2421f1f1 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -45,7 +45,7 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi label: pol.name.chars().take(35).collect::(), score: CompletionScore::from(relevance.clone()), filter: CompletionFilter::from(relevance), - description: format!("{}", pol.table_name), + description: pol.table_name.to_string(), kind: CompletionItemKind::Policy, completion_text, }; From e1163326b675b9f83b8f483f0bb51bbe540d36d7 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 13 May 2025 09:49:20 +0200 Subject: [PATCH 14/18] lints, even catched a bug --- crates/pgt_completions/src/context/context.rs | 732 ----------------- crates/pgt_completions/src/context/mod.rs | 734 +++++++++++++++++- .../src/context/policy_parser.rs | 30 +- .../src/relevance/filtering.rs | 2 +- 4 files changed, 747 insertions(+), 751 deletions(-) delete mode 100644 crates/pgt_completions/src/context/context.rs diff --git a/crates/pgt_completions/src/context/context.rs b/crates/pgt_completions/src/context/context.rs deleted file mode 100644 index 8e0c71fd..00000000 --- a/crates/pgt_completions/src/context/context.rs +++ /dev/null @@ -1,732 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use pgt_schema_cache::SchemaCache; -use pgt_text_size::TextRange; -use pgt_treesitter_queries::{ - TreeSitterQueriesExecutor, - queries::{self, QueryResult}, -}; - -use crate::{ - NodeText, - 
context::policy_parser::{PolicyParser, PolicyStmtKind}, - sanitization::SanitizedCompletionParams, -}; - -#[derive(Debug, PartialEq, Eq)] -pub enum WrappingClause<'a> { - Select, - Where, - From, - Join { - on_node: Option>, - }, - Update, - Delete, - PolicyName, - ToRoleAssignment, -} - -/// We can map a few nodes, such as the "update" node, to actual SQL clauses. -/// That gives us a lot of insight for completions. -/// Other nodes, such as the "relation" node, gives us less but still -/// relevant information. -/// `WrappingNode` maps to such nodes. -/// -/// Note: This is not the direct parent of the `node_under_cursor`, but the closest -/// *relevant* parent. -#[derive(Debug, PartialEq, Eq)] -pub enum WrappingNode { - Relation, - BinaryExpression, - Assignment, -} - -#[derive(Debug)] -pub(crate) enum NodeUnderCursor<'a> { - TsNode(tree_sitter::Node<'a>), - CustomNode { - text: NodeText, - range: TextRange, - kind: String, - }, -} - -impl NodeUnderCursor<'_> { - pub fn start_byte(&self) -> usize { - match self { - NodeUnderCursor::TsNode(node) => node.start_byte(), - NodeUnderCursor::CustomNode { range, .. } => range.start().into(), - } - } - - pub fn end_byte(&self) -> usize { - match self { - NodeUnderCursor::TsNode(node) => node.end_byte(), - NodeUnderCursor::CustomNode { range, .. } => range.end().into(), - } - } - - pub fn kind(&self) -> &str { - match self { - NodeUnderCursor::TsNode(node) => node.kind(), - NodeUnderCursor::CustomNode { kind, .. 
} => kind.as_str(), - } - } -} - -impl<'a> From> for NodeUnderCursor<'a> { - fn from(node: tree_sitter::Node<'a>) -> Self { - NodeUnderCursor::TsNode(node) - } -} - -impl TryFrom<&str> for WrappingNode { - type Error = String; - - fn try_from(value: &str) -> Result { - match value { - "relation" => Ok(Self::Relation), - "assignment" => Ok(Self::Assignment), - "binary_expression" => Ok(Self::BinaryExpression), - _ => { - let message = format!("Unimplemented Relation: {}", value); - - // Err on tests, so we notice that we're lacking an implementation immediately. - if cfg!(test) { - panic!("{}", message); - } - - Err(message) - } - } - } -} - -impl TryFrom for WrappingNode { - type Error = String; - fn try_from(value: String) -> Result { - Self::try_from(value.as_str()) - } -} - -pub(crate) struct CompletionContext<'a> { - pub node_under_cursor: Option>, - - pub tree: &'a tree_sitter::Tree, - pub text: &'a str, - pub schema_cache: &'a SchemaCache, - pub position: usize, - - /// If the cursor is on a node that uses dot notation - /// to specify an alias or schema, this will hold the schema's or - /// alias's name. - /// - /// Here, `auth` is a schema name: - /// ```sql - /// select * from auth.users; - /// ``` - /// - /// Here, `u` is an alias name: - /// ```sql - /// select - /// * - /// from - /// auth.users u - /// left join identities i - /// on u.id = i.user_id; - /// ``` - pub schema_or_alias_name: Option, - pub wrapping_clause_type: Option>, - - pub wrapping_node_kind: Option, - - pub is_invocation: bool, - pub wrapping_statement_range: Option, - - /// Some incomplete statements can't be correctly parsed by TreeSitter. 
- pub is_in_error_node: bool, - - pub mentioned_relations: HashMap, HashSet>, - - pub mentioned_table_aliases: HashMap, -} - -impl<'a> CompletionContext<'a> { - pub fn new(params: &'a SanitizedCompletionParams) -> Self { - let mut ctx = Self { - tree: params.tree.as_ref(), - text: &params.text, - schema_cache: params.schema, - position: usize::from(params.position), - node_under_cursor: None, - schema_or_alias_name: None, - wrapping_clause_type: None, - wrapping_node_kind: None, - wrapping_statement_range: None, - is_invocation: false, - mentioned_relations: HashMap::new(), - mentioned_table_aliases: HashMap::new(), - is_in_error_node: false, - }; - - // policy handling is important to Supabase, but they are a PostgreSQL specific extension, - // so the tree_sitter_sql language does not support it. - // We infer the context manually. - if PolicyParser::looks_like_policy_stmt(&params.text) { - ctx.gather_policy_context(); - } else { - ctx.gather_tree_context(); - ctx.gather_info_from_ts_queries(); - } - - ctx - } - - fn gather_policy_context(&mut self) { - let policy_context = PolicyParser::get_context(self.text, self.position); - - self.node_under_cursor = Some(NodeUnderCursor::CustomNode { - text: policy_context.node_text.into(), - range: policy_context.node_range, - kind: policy_context.node_kind.clone(), - }); - - if policy_context.node_kind == "policy_table" { - self.schema_or_alias_name = policy_context.schema_name.clone(); - } - - if policy_context.table_name.is_some() { - let mut new = HashSet::new(); - new.insert(policy_context.table_name.unwrap()); - self.mentioned_relations - .insert(policy_context.schema_name, new); - } - - self.wrapping_clause_type = match policy_context.node_kind.as_str() { - "policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => { - Some(WrappingClause::PolicyName) - } - "policy_role" => Some(WrappingClause::ToRoleAssignment), - "policy_table" => Some(WrappingClause::From), - _ => None, - }; - } - - fn
gather_info_from_ts_queries(&mut self) { - let stmt_range = self.wrapping_statement_range.as_ref(); - let sql = self.text; - - let mut executor = TreeSitterQueriesExecutor::new(self.tree.root_node(), sql); - - executor.add_query_results::(); - executor.add_query_results::(); - - for relation_match in executor.get_iter(stmt_range) { - match relation_match { - QueryResult::Relation(r) => { - let schema_name = r.get_schema(sql); - let table_name = r.get_table(sql); - - let current = self.mentioned_relations.get_mut(&schema_name); - - match current { - Some(c) => { - c.insert(table_name); - } - None => { - let mut new = HashSet::new(); - new.insert(table_name); - self.mentioned_relations.insert(schema_name, new); - } - }; - } - - QueryResult::TableAliases(table_alias_match) => { - self.mentioned_table_aliases.insert( - table_alias_match.get_alias(sql), - table_alias_match.get_table(sql), - ); - } - }; - } - } - - fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option { - let source = self.text; - ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { - if SanitizedCompletionParams::is_sanitized_token(txt) { - NodeText::Replaced - } else { - NodeText::Original(txt.into()) - } - }) - } - - pub fn get_node_under_cursor_content(&self) -> Option { - match self.node_under_cursor.as_ref()? { - NodeUnderCursor::TsNode(node) => { - self.get_ts_node_content(node).and_then(|nt| match nt { - NodeText::Replaced => None, - NodeText::Original(c) => Some(c.to_string()), - }) - } - NodeUnderCursor::CustomNode { text, .. } => match text { - NodeText::Replaced => None, - NodeText::Original(c) => Some(c.to_string()), - }, - } - } - - fn gather_tree_context(&mut self) { - let mut cursor = self.tree.root_node().walk(); - - /* - * The head node of any treesitter tree is always the "PROGRAM" node. - * - * We want to enter the next layer and focus on the child node that matches the user's cursor position. 
- * If there is no node under the users position, however, the cursor won't enter the next level – it - * will stay on the Program node. - * - * This might lead to an unexpected context or infinite recursion. - * - * We'll therefore adjust the cursor position such that it meets the last node of the AST. - * `select * from use {}` becomes `select * from use{}`. - */ - let current_node = cursor.node(); - while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { - self.position -= 1; - } - - self.gather_context_from_node(cursor, current_node); - } - - fn gather_context_from_node( - &mut self, - mut cursor: tree_sitter::TreeCursor<'a>, - parent_node: tree_sitter::Node<'a>, - ) { - let current_node = cursor.node(); - - let parent_node_kind = parent_node.kind(); - let current_node_kind = current_node.kind(); - - // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node_kind == parent_node_kind { - self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); - return; - } - - match parent_node_kind { - "statement" | "subquery" => { - self.wrapping_clause_type = - self.get_wrapping_clause_from_current_node(current_node, &mut cursor); - - self.wrapping_statement_range = Some(parent_node.range()); - } - "invocation" => self.is_invocation = true, - _ => {} - } - - // try to gather context from the siblings if we're within an error node. 
- if self.is_in_error_node { - let mut next_sibling = current_node.next_named_sibling(); - while let Some(n) = next_sibling { - if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { - self.wrapping_clause_type = Some(clause_type); - break; - } else { - next_sibling = n.next_named_sibling(); - } - } - let mut prev_sibling = current_node.prev_named_sibling(); - while let Some(n) = prev_sibling { - if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { - self.wrapping_clause_type = Some(clause_type); - break; - } else { - prev_sibling = n.prev_named_sibling(); - } - } - } - - match current_node_kind { - "object_reference" | "field" => { - let content = self.get_ts_node_content(¤t_node); - if let Some(node_txt) = content { - match node_txt { - NodeText::Original(txt) => { - let parts: Vec<&str> = txt.split('.').collect(); - if parts.len() == 2 { - self.schema_or_alias_name = Some(parts[0].to_string()); - } - } - NodeText::Replaced => {} - } - } - } - - "where" | "update" | "select" | "delete" | "from" | "join" => { - self.wrapping_clause_type = - self.get_wrapping_clause_from_current_node(current_node, &mut cursor); - } - - "relation" | "binary_expression" | "assignment" => { - self.wrapping_node_kind = current_node_kind.try_into().ok(); - } - - "ERROR" => { - self.is_in_error_node = true; - } - - _ => {} - } - - // We have arrived at the leaf node - if current_node.child_count() == 0 { - self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); - return; - } - - cursor.goto_first_child_for_byte(self.position); - self.gather_context_from_node(cursor, current_node); - } - - fn get_wrapping_clause_from_keyword_node( - &self, - node: tree_sitter::Node<'a>, - ) -> Option> { - if node.kind().starts_with("keyword_") { - if let Some(txt) = self.get_ts_node_content(&node).and_then(|txt| match txt { - NodeText::Original(txt) => Some(txt), - NodeText::Replaced => None, - }) { - match txt.as_str() { - "where" => return 
Some(WrappingClause::Where), - "update" => return Some(WrappingClause::Update), - "select" => return Some(WrappingClause::Select), - "delete" => return Some(WrappingClause::Delete), - "from" => return Some(WrappingClause::From), - "join" => { - // TODO: not sure if we can infer it here. - return Some(WrappingClause::Join { on_node: None }); - } - _ => {} - } - }; - } - - None - } - - fn get_wrapping_clause_from_current_node( - &self, - node: tree_sitter::Node<'a>, - cursor: &mut tree_sitter::TreeCursor<'a>, - ) -> Option> { - match node.kind() { - "where" => Some(WrappingClause::Where), - "update" => Some(WrappingClause::Update), - "select" => Some(WrappingClause::Select), - "delete" => Some(WrappingClause::Delete), - "from" => Some(WrappingClause::From), - "join" => { - // sadly, we need to manually iterate over the children – - // `node.child_by_field_id(..)` does not work as expected - let mut on_node = None; - for child in node.children(cursor) { - // 28 is the id for "keyword_on" - if child.kind_id() == 28 { - on_node = Some(child); - } - } - cursor.goto_parent(); - Some(WrappingClause::Join { on_node }) - } - _ => None, - } - } -} - -#[cfg(test)] -mod tests { - use crate::{ - NodeText, - context::{CompletionContext, WrappingClause}, - sanitization::SanitizedCompletionParams, - test_helper::{CURSOR_POS, get_text_and_position}, - }; - - use super::NodeUnderCursor; - - fn get_tree(input: &str) -> tree_sitter::Tree { - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Couldn't set language"); - - parser.parse(input, None).expect("Unable to parse tree") - } - - #[test] - fn identifies_clauses() { - let test_cases = vec![ - ( - format!("Select {}* from users;", CURSOR_POS), - WrappingClause::Select, - ), - ( - format!("Select * from u{};", CURSOR_POS), - WrappingClause::From, - ), - ( - format!("Select {}* from users where n = 1;", CURSOR_POS), - WrappingClause::Select, - ), - ( - format!("Select * from 
users where {}n = 1;", CURSOR_POS), - WrappingClause::Where, - ), - ( - format!("update users set u{} = 1 where n = 2;", CURSOR_POS), - WrappingClause::Update, - ), - ( - format!("update users set u = 1 where n{} = 2;", CURSOR_POS), - WrappingClause::Where, - ), - ( - format!("delete{} from users;", CURSOR_POS), - WrappingClause::Delete, - ), - ( - format!("delete from {}users;", CURSOR_POS), - WrappingClause::From, - ), - ( - format!("select name, age, location from public.u{}sers", CURSOR_POS), - WrappingClause::From, - ), - ]; - - for (query, expected_clause) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - assert_eq!(ctx.wrapping_clause_type, Some(expected_clause)); - } - } - - #[test] - fn identifies_schema() { - let test_cases = vec![ - ( - format!("Select * from private.u{}", CURSOR_POS), - Some("private"), - ), - ( - format!("Select * from private.u{}sers()", CURSOR_POS), - Some("private"), - ), - (format!("Select * from u{}sers", CURSOR_POS), None), - (format!("Select * from u{}sers()", CURSOR_POS), None), - ]; - - for (query, expected_schema) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - assert_eq!( - ctx.schema_or_alias_name, - expected_schema.map(|f| f.to_string()) - ); - } - } - - #[test] - fn identifies_invocation() { - let test_cases = vec![ - (format!("Select * from u{}sers", CURSOR_POS), false), - (format!("Select * from u{}sers()", 
CURSOR_POS), true), - (format!("Select cool{};", CURSOR_POS), false), - (format!("Select cool{}();", CURSOR_POS), true), - ( - format!("Select upp{}ercase as title from users;", CURSOR_POS), - false, - ), - ( - format!("Select upp{}ercase(name) as title from users;", CURSOR_POS), - true, - ), - ]; - - for (query, is_invocation) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - assert_eq!(ctx.is_invocation, is_invocation); - } - } - - #[test] - fn does_not_fail_on_leading_whitespace() { - let cases = vec![ - format!("{} select * from", CURSOR_POS), - format!(" {} select * from", CURSOR_POS), - ]; - - for query in cases { - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - let node = ctx.node_under_cursor.as_ref().unwrap(); - - match node { - NodeUnderCursor::TsNode(node) => { - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("select".into())) - ); - - assert_eq!( - ctx.wrapping_clause_type, - Some(crate::context::WrappingClause::Select) - ); - } - _ => unreachable!(), - } - } - } - - #[test] - fn does_not_fail_on_trailing_whitespace() { - let query = format!("select * from {}", CURSOR_POS); - - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: 
&pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - let node = ctx.node_under_cursor.as_ref().unwrap(); - - match node { - NodeUnderCursor::TsNode(node) => { - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("from".into())) - ); - } - _ => unreachable!(), - } - } - - #[test] - fn does_not_fail_with_empty_statements() { - let query = format!("{}", CURSOR_POS); - - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - let node = ctx.node_under_cursor.as_ref().unwrap(); - - match node { - NodeUnderCursor::TsNode(node) => { - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("".into())) - ); - assert_eq!(ctx.wrapping_clause_type, None); - } - _ => unreachable!(), - } - } - - #[test] - fn does_not_fail_on_incomplete_keywords() { - // Instead of autocompleting "FROM", we'll assume that the user - // is selecting a certain column name, such as `frozen_account`. 
- let query = format!("select * fro{}", CURSOR_POS); - - let (position, text) = get_text_and_position(query.as_str().into()); - - let tree = get_tree(text.as_str()); - - let params = SanitizedCompletionParams { - position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), - }; - - let ctx = CompletionContext::new(¶ms); - - let node = ctx.node_under_cursor.as_ref().unwrap(); - - match node { - NodeUnderCursor::TsNode(node) => { - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("fro".into())) - ); - assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); - } - _ => unreachable!(), - } - } -} diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 828b6477..c77d092c 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -1,4 +1,734 @@ -mod context; mod policy_parser; -pub use context::*; +use std::collections::{HashMap, HashSet}; + +use pgt_schema_cache::SchemaCache; +use pgt_text_size::TextRange; +use pgt_treesitter_queries::{ + TreeSitterQueriesExecutor, + queries::{self, QueryResult}, +}; + +use crate::{ + NodeText, + context::policy_parser::{PolicyParser, PolicyStmtKind}, + sanitization::SanitizedCompletionParams, +}; + +#[derive(Debug, PartialEq, Eq)] +pub enum WrappingClause<'a> { + Select, + Where, + From, + Join { + on_node: Option>, + }, + Update, + Delete, + PolicyName, + ToRoleAssignment, +} + +/// We can map a few nodes, such as the "update" node, to actual SQL clauses. +/// That gives us a lot of insight for completions. +/// Other nodes, such as the "relation" node, gives us less but still +/// relevant information. +/// `WrappingNode` maps to such nodes. +/// +/// Note: This is not the direct parent of the `node_under_cursor`, but the closest +/// *relevant* parent. 
+#[derive(Debug, PartialEq, Eq)] +pub enum WrappingNode { + Relation, + BinaryExpression, + Assignment, +} + +#[derive(Debug)] +pub(crate) enum NodeUnderCursor<'a> { + TsNode(tree_sitter::Node<'a>), + CustomNode { + text: NodeText, + range: TextRange, + kind: String, + }, +} + +impl NodeUnderCursor<'_> { + pub fn start_byte(&self) -> usize { + match self { + NodeUnderCursor::TsNode(node) => node.start_byte(), + NodeUnderCursor::CustomNode { range, .. } => range.start().into(), + } + } + + pub fn end_byte(&self) -> usize { + match self { + NodeUnderCursor::TsNode(node) => node.end_byte(), + NodeUnderCursor::CustomNode { range, .. } => range.end().into(), + } + } + + pub fn kind(&self) -> &str { + match self { + NodeUnderCursor::TsNode(node) => node.kind(), + NodeUnderCursor::CustomNode { kind, .. } => kind.as_str(), + } + } +} + +impl<'a> From> for NodeUnderCursor<'a> { + fn from(node: tree_sitter::Node<'a>) -> Self { + NodeUnderCursor::TsNode(node) + } +} + +impl TryFrom<&str> for WrappingNode { + type Error = String; + + fn try_from(value: &str) -> Result { + match value { + "relation" => Ok(Self::Relation), + "assignment" => Ok(Self::Assignment), + "binary_expression" => Ok(Self::BinaryExpression), + _ => { + let message = format!("Unimplemented Relation: {}", value); + + // Err on tests, so we notice that we're lacking an implementation immediately. + if cfg!(test) { + panic!("{}", message); + } + + Err(message) + } + } + } +} + +impl TryFrom for WrappingNode { + type Error = String; + fn try_from(value: String) -> Result { + Self::try_from(value.as_str()) + } +} + +pub(crate) struct CompletionContext<'a> { + pub node_under_cursor: Option>, + + pub tree: &'a tree_sitter::Tree, + pub text: &'a str, + pub schema_cache: &'a SchemaCache, + pub position: usize, + + /// If the cursor is on a node that uses dot notation + /// to specify an alias or schema, this will hold the schema's or + /// alias's name. 
+ /// + /// Here, `auth` is a schema name: + /// ```sql + /// select * from auth.users; + /// ``` + /// + /// Here, `u` is an alias name: + /// ```sql + /// select + /// * + /// from + /// auth.users u + /// left join identities i + /// on u.id = i.user_id; + /// ``` + pub schema_or_alias_name: Option, + pub wrapping_clause_type: Option>, + + pub wrapping_node_kind: Option, + + pub is_invocation: bool, + pub wrapping_statement_range: Option, + + /// Some incomplete statements can't be correctly parsed by TreeSitter. + pub is_in_error_node: bool, + + pub mentioned_relations: HashMap, HashSet>, + + pub mentioned_table_aliases: HashMap, +} + +impl<'a> CompletionContext<'a> { + pub fn new(params: &'a SanitizedCompletionParams) -> Self { + let mut ctx = Self { + tree: params.tree.as_ref(), + text: ¶ms.text, + schema_cache: params.schema, + position: usize::from(params.position), + node_under_cursor: None, + schema_or_alias_name: None, + wrapping_clause_type: None, + wrapping_node_kind: None, + wrapping_statement_range: None, + is_invocation: false, + mentioned_relations: HashMap::new(), + mentioned_table_aliases: HashMap::new(), + is_in_error_node: false, + }; + + // policy handling is important to Supabase, but they are a PostgreSQL specific extension, + // so the tree_sitter_sql language does not support it. + // We infer the context manually. 
+ if PolicyParser::looks_like_policy_stmt(¶ms.text) { + ctx.gather_policy_context(); + } else { + ctx.gather_tree_context(); + ctx.gather_info_from_ts_queries(); + } + + ctx + } + + fn gather_policy_context(&mut self) { + let policy_context = PolicyParser::get_context(self.text, self.position); + + self.node_under_cursor = Some(NodeUnderCursor::CustomNode { + text: policy_context.node_text.into(), + range: policy_context.node_range, + kind: policy_context.node_kind.clone(), + }); + + if policy_context.node_kind == "policy_table" { + self.schema_or_alias_name = policy_context.schema_name.clone(); + } + + if policy_context.table_name.is_some() { + let mut new = HashSet::new(); + new.insert(policy_context.table_name.unwrap()); + self.mentioned_relations + .insert(policy_context.schema_name, new); + } + + self.wrapping_clause_type = match policy_context.node_kind.as_str() { + "policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => { + Some(WrappingClause::PolicyName) + } + "policy_role" => Some(WrappingClause::ToRoleAssignment), + "policy_table" => Some(WrappingClause::From), + _ => None, + }; + } + + fn gather_info_from_ts_queries(&mut self) { + let stmt_range = self.wrapping_statement_range.as_ref(); + let sql = self.text; + + let mut executor = TreeSitterQueriesExecutor::new(self.tree.root_node(), sql); + + executor.add_query_results::(); + executor.add_query_results::(); + + for relation_match in executor.get_iter(stmt_range) { + match relation_match { + QueryResult::Relation(r) => { + let schema_name = r.get_schema(sql); + let table_name = r.get_table(sql); + + let current = self.mentioned_relations.get_mut(&schema_name); + + match current { + Some(c) => { + c.insert(table_name); + } + None => { + let mut new = HashSet::new(); + new.insert(table_name); + self.mentioned_relations.insert(schema_name, new); + } + }; + } + + QueryResult::TableAliases(table_alias_match) => { + self.mentioned_table_aliases.insert( + table_alias_match.get_alias(sql), 
+ table_alias_match.get_table(sql), + ); + } + }; + } + } + + fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option { + let source = self.text; + ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { + if SanitizedCompletionParams::is_sanitized_token(txt) { + NodeText::Replaced + } else { + NodeText::Original(txt.into()) + } + }) + } + + pub fn get_node_under_cursor_content(&self) -> Option { + match self.node_under_cursor.as_ref()? { + NodeUnderCursor::TsNode(node) => { + self.get_ts_node_content(node).and_then(|nt| match nt { + NodeText::Replaced => None, + NodeText::Original(c) => Some(c.to_string()), + }) + } + NodeUnderCursor::CustomNode { text, .. } => match text { + NodeText::Replaced => None, + NodeText::Original(c) => Some(c.to_string()), + }, + } + } + + fn gather_tree_context(&mut self) { + let mut cursor = self.tree.root_node().walk(); + + /* + * The head node of any treesitter tree is always the "PROGRAM" node. + * + * We want to enter the next layer and focus on the child node that matches the user's cursor position. + * If there is no node under the users position, however, the cursor won't enter the next level – it + * will stay on the Program node. + * + * This might lead to an unexpected context or infinite recursion. + * + * We'll therefore adjust the cursor position such that it meets the last node of the AST. + * `select * from use {}` becomes `select * from use{}`. 
+ */ + let current_node = cursor.node(); + while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { + self.position -= 1; + } + + self.gather_context_from_node(cursor, current_node); + } + + fn gather_context_from_node( + &mut self, + mut cursor: tree_sitter::TreeCursor<'a>, + parent_node: tree_sitter::Node<'a>, + ) { + let current_node = cursor.node(); + + let parent_node_kind = parent_node.kind(); + let current_node_kind = current_node.kind(); + + // prevent infinite recursion – this can happen if we only have a PROGRAM node + if current_node_kind == parent_node_kind { + self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); + return; + } + + match parent_node_kind { + "statement" | "subquery" => { + self.wrapping_clause_type = + self.get_wrapping_clause_from_current_node(current_node, &mut cursor); + + self.wrapping_statement_range = Some(parent_node.range()); + } + "invocation" => self.is_invocation = true, + _ => {} + } + + // try to gather context from the siblings if we're within an error node. 
+ if self.is_in_error_node { + let mut next_sibling = current_node.next_named_sibling(); + while let Some(n) = next_sibling { + if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { + self.wrapping_clause_type = Some(clause_type); + break; + } else { + next_sibling = n.next_named_sibling(); + } + } + let mut prev_sibling = current_node.prev_named_sibling(); + while let Some(n) = prev_sibling { + if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { + self.wrapping_clause_type = Some(clause_type); + break; + } else { + prev_sibling = n.prev_named_sibling(); + } + } + } + + match current_node_kind { + "object_reference" | "field" => { + let content = self.get_ts_node_content(¤t_node); + if let Some(node_txt) = content { + match node_txt { + NodeText::Original(txt) => { + let parts: Vec<&str> = txt.split('.').collect(); + if parts.len() == 2 { + self.schema_or_alias_name = Some(parts[0].to_string()); + } + } + NodeText::Replaced => {} + } + } + } + + "where" | "update" | "select" | "delete" | "from" | "join" => { + self.wrapping_clause_type = + self.get_wrapping_clause_from_current_node(current_node, &mut cursor); + } + + "relation" | "binary_expression" | "assignment" => { + self.wrapping_node_kind = current_node_kind.try_into().ok(); + } + + "ERROR" => { + self.is_in_error_node = true; + } + + _ => {} + } + + // We have arrived at the leaf node + if current_node.child_count() == 0 { + self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); + return; + } + + cursor.goto_first_child_for_byte(self.position); + self.gather_context_from_node(cursor, current_node); + } + + fn get_wrapping_clause_from_keyword_node( + &self, + node: tree_sitter::Node<'a>, + ) -> Option> { + if node.kind().starts_with("keyword_") { + if let Some(txt) = self.get_ts_node_content(&node).and_then(|txt| match txt { + NodeText::Original(txt) => Some(txt), + NodeText::Replaced => None, + }) { + match txt.as_str() { + "where" => return 
Some(WrappingClause::Where), + "update" => return Some(WrappingClause::Update), + "select" => return Some(WrappingClause::Select), + "delete" => return Some(WrappingClause::Delete), + "from" => return Some(WrappingClause::From), + "join" => { + // TODO: not sure if we can infer it here. + return Some(WrappingClause::Join { on_node: None }); + } + _ => {} + } + }; + } + + None + } + + fn get_wrapping_clause_from_current_node( + &self, + node: tree_sitter::Node<'a>, + cursor: &mut tree_sitter::TreeCursor<'a>, + ) -> Option> { + match node.kind() { + "where" => Some(WrappingClause::Where), + "update" => Some(WrappingClause::Update), + "select" => Some(WrappingClause::Select), + "delete" => Some(WrappingClause::Delete), + "from" => Some(WrappingClause::From), + "join" => { + // sadly, we need to manually iterate over the children – + // `node.child_by_field_id(..)` does not work as expected + let mut on_node = None; + for child in node.children(cursor) { + // 28 is the id for "keyword_on" + if child.kind_id() == 28 { + on_node = Some(child); + } + } + cursor.goto_parent(); + Some(WrappingClause::Join { on_node }) + } + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + NodeText, + context::{CompletionContext, WrappingClause}, + sanitization::SanitizedCompletionParams, + test_helper::{CURSOR_POS, get_text_and_position}, + }; + + use super::NodeUnderCursor; + + fn get_tree(input: &str) -> tree_sitter::Tree { + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Couldn't set language"); + + parser.parse(input, None).expect("Unable to parse tree") + } + + #[test] + fn identifies_clauses() { + let test_cases = vec![ + ( + format!("Select {}* from users;", CURSOR_POS), + WrappingClause::Select, + ), + ( + format!("Select * from u{};", CURSOR_POS), + WrappingClause::From, + ), + ( + format!("Select {}* from users where n = 1;", CURSOR_POS), + WrappingClause::Select, + ), + ( + format!("Select * from 
users where {}n = 1;", CURSOR_POS), + WrappingClause::Where, + ), + ( + format!("update users set u{} = 1 where n = 2;", CURSOR_POS), + WrappingClause::Update, + ), + ( + format!("update users set u = 1 where n{} = 2;", CURSOR_POS), + WrappingClause::Where, + ), + ( + format!("delete{} from users;", CURSOR_POS), + WrappingClause::Delete, + ), + ( + format!("delete from {}users;", CURSOR_POS), + WrappingClause::From, + ), + ( + format!("select name, age, location from public.u{}sers", CURSOR_POS), + WrappingClause::From, + ), + ]; + + for (query, expected_clause) in test_cases { + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + assert_eq!(ctx.wrapping_clause_type, Some(expected_clause)); + } + } + + #[test] + fn identifies_schema() { + let test_cases = vec![ + ( + format!("Select * from private.u{}", CURSOR_POS), + Some("private"), + ), + ( + format!("Select * from private.u{}sers()", CURSOR_POS), + Some("private"), + ), + (format!("Select * from u{}sers", CURSOR_POS), None), + (format!("Select * from u{}sers()", CURSOR_POS), None), + ]; + + for (query, expected_schema) in test_cases { + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + assert_eq!( + ctx.schema_or_alias_name, + expected_schema.map(|f| f.to_string()) + ); + } + } + + #[test] + fn identifies_invocation() { + let test_cases = vec![ + (format!("Select * from u{}sers", CURSOR_POS), false), + (format!("Select * from u{}sers()", 
CURSOR_POS), true), + (format!("Select cool{};", CURSOR_POS), false), + (format!("Select cool{}();", CURSOR_POS), true), + ( + format!("Select upp{}ercase as title from users;", CURSOR_POS), + false, + ), + ( + format!("Select upp{}ercase(name) as title from users;", CURSOR_POS), + true, + ), + ]; + + for (query, is_invocation) in test_cases { + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + assert_eq!(ctx.is_invocation, is_invocation); + } + } + + #[test] + fn does_not_fail_on_leading_whitespace() { + let cases = vec![ + format!("{} select * from", CURSOR_POS), + format!(" {} select * from", CURSOR_POS), + ]; + + for query in cases { + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.node_under_cursor.as_ref().unwrap(); + + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("select".into())) + ); + + assert_eq!( + ctx.wrapping_clause_type, + Some(crate::context::WrappingClause::Select) + ); + } + _ => unreachable!(), + } + } + } + + #[test] + fn does_not_fail_on_trailing_whitespace() { + let query = format!("select * from {}", CURSOR_POS); + + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: 
&pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.node_under_cursor.as_ref().unwrap(); + + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("from".into())) + ); + } + _ => unreachable!(), + } + } + + #[test] + fn does_not_fail_with_empty_statements() { + let query = format!("{}", CURSOR_POS); + + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.node_under_cursor.as_ref().unwrap(); + + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("".into())) + ); + assert_eq!(ctx.wrapping_clause_type, None); + } + _ => unreachable!(), + } + } + + #[test] + fn does_not_fail_on_incomplete_keywords() { + // Instead of autocompleting "FROM", we'll assume that the user + // is selecting a certain column name, such as `frozen_account`. 
+ let query = format!("select * fro{}", CURSOR_POS); + + let (position, text) = get_text_and_position(query.as_str().into()); + + let tree = get_tree(text.as_str()); + + let params = SanitizedCompletionParams { + position: (position as u32).into(), + text, + tree: std::borrow::Cow::Owned(tree), + schema: &pgt_schema_cache::SchemaCache::default(), + }; + + let ctx = CompletionContext::new(¶ms); + + let node = ctx.node_under_cursor.as_ref().unwrap(); + + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("fro".into())) + ); + assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); + } + _ => unreachable!(), + } + } +} diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index 600a08a6..db37a13f 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -78,12 +78,14 @@ fn sql_to_words(sql: &str) -> Result, String> { } } - if !current_word.is_empty() && start_of_word.is_some() { - words.push(WordWithIndex { - word: current_word, - start: start_of_word.unwrap(), - end: sql.len(), - }); + if let Some(start_of_word) = start_of_word { + if !current_word.is_empty() { + words.push(WordWithIndex { + word: current_word, + start: start_of_word, + end: sql.len(), + }); + } } if in_quotation_marks { @@ -190,14 +192,10 @@ impl PolicyParser { self.context.node_range = range_without_schema; self.context.node_kind = "policy_table".into(); - self.context.node_text = match table_name { - Some(node_text) => node_text, - - // In practice, this should never happen. - // The completion sanitization will add a word after a `.` if nothing follows it; - // the token_text will then look like `schema.REPLACED_TOKEN`. - None => String::new(), - }; + // In practice, we should always have a table name. 
+ // The completion sanitization will add a word after a `.` if nothing follows it; + // the token_text will then look like `schema.REPLACED_TOKEN`. + self.context.node_text = table_name.unwrap_or_default(); } else { self.context.node_range = token.get_range(); self.context.node_text = token.word; @@ -258,7 +256,7 @@ impl PolicyParser { } fn table_with_schema(&mut self) { - self.advance().map(|token| { + if let Some(token) = self.advance() { if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); } else if token.word.contains('.') { @@ -268,7 +266,7 @@ impl PolicyParser { } else { self.context.table_name = Some(token.word); } - }); + }; } fn schema_and_table_name(&self, token: &WordWithIndex) -> (String, Option) { diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 071e2994..da770301 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -150,7 +150,7 @@ impl CompletionFilter<'_> { } // no aliases and schemas for policies - CompletionRelevanceData::Policy(p) => p.schema_name == p.schema_name, + CompletionRelevanceData::Policy(p) => &p.schema_name == schema_or_alias, }; if !matches { From 1971f6c88d1b8ed64ad3e1ba8b9c35d99ecc3453 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 13 May 2025 09:50:59 +0200 Subject: [PATCH 15/18] this is the way --- crates/pgt_completions/src/relevance/filtering.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index da770301..066145c4 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -144,13 +144,10 @@ impl CompletionFilter<'_> { .get(schema_or_alias) .is_some_and(|t| t == &col.table_name), - CompletionRelevanceData::Schema(_) => { - // we should never allow schema 
suggestions if there already was one. - false - } - + // we should never allow schema suggestions if there already was one. + CompletionRelevanceData::Schema(_) => false, // no aliases and schemas for policies - CompletionRelevanceData::Policy(p) => &p.schema_name == schema_or_alias, + CompletionRelevanceData::Policy(_) => false, }; if !matches { From 74cbacfee3605a58fad0bd31550ddca6de4945e6 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 13 May 2025 09:52:01 +0200 Subject: [PATCH 16/18] comment --- crates/pgt_completions/src/relevance/filtering.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 066145c4..3b148336 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -146,7 +146,7 @@ impl CompletionFilter<'_> { // we should never allow schema suggestions if there already was one. CompletionRelevanceData::Schema(_) => false, - // no aliases and schemas for policies + // no policy completion if user typed a schema node first. CompletionRelevanceData::Policy(_) => false, }; From 4fbec6e7b14a41835b5a2427925fe25e98c19d87 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 20 May 2025 09:25:28 +0200 Subject: [PATCH 17/18] ???
--- .github/workflows/pull_request.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 4600ac92..1662471e 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -220,8 +220,11 @@ jobs: uses: moonrepo/setup-rust@v1 with: cache-base: main + components: rustfmt env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: echo toolchain + run: rustup show - name: Run the analyser codegen run: cargo run -p xtask_codegen -- analyser - name: Run the configuration codegen From 602c02b7faa8d3876ca3e874dfed220f65d1a937 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 22 May 2025 08:57:39 +0200 Subject: [PATCH 18/18] like so? --- .github/workflows/pull_request.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 1662471e..f79392b7 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -184,6 +184,8 @@ jobs: uses: ./.github/actions/free-disk-space - name: Install toolchain uses: moonrepo/setup-rust@v1 + with: + cache-base: main - name: Build main binary run: cargo build -p pgt_cli --release - name: Setup Bun @@ -220,9 +222,10 @@ jobs: uses: moonrepo/setup-rust@v1 with: cache-base: main - components: rustfmt env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Ensure RustFMT on nightly toolchain + run: rustup component add rustfmt --toolchain nightly - name: echo toolchain run: rustup show - name: Run the analyser codegen