diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 4600ac92..f79392b7 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -184,6 +184,8 @@ jobs: uses: ./.github/actions/free-disk-space - name: Install toolchain uses: moonrepo/setup-rust@v1 + with: + cache-base: main - name: Build main binary run: cargo build -p pgt_cli --release - name: Setup Bun @@ -222,6 +224,10 @@ jobs: cache-base: main env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Ensure RustFMT on nightly toolchain + run: rustup component add rustfmt --toolchain nightly + - name: echo toolchain + run: rustup show - name: Run the analyser codegen run: cargo run -p xtask_codegen -- analyser - name: Run the configuration codegen diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index 442ee546..5bc5d41c 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -4,7 +4,9 @@ use crate::{ builder::CompletionBuilder, context::CompletionContext, item::CompletionItem, - providers::{complete_columns, complete_functions, complete_schemas, complete_tables}, + providers::{ + complete_columns, complete_functions, complete_policies, complete_schemas, complete_tables, + }, sanitization::SanitizedCompletionParams, }; @@ -33,6 +35,7 @@ pub fn complete(params: CompletionParams) -> Vec { complete_functions(&ctx, &mut builder); complete_columns(&ctx, &mut builder); complete_schemas(&ctx, &mut builder); + complete_policies(&ctx, &mut builder); builder.finish() } diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context/mod.rs similarity index 77% rename from crates/pgt_completions/src/context.rs rename to crates/pgt_completions/src/context/mod.rs index a17cafa2..23a6fcae 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -1,12 +1,19 @@ +mod policy_parser; + use std::collections::{HashMap, HashSet}; use pgt_schema_cache::SchemaCache; +use pgt_text_size::TextRange; use pgt_treesitter_queries::{ TreeSitterQueriesExecutor, queries::{self, QueryResult}, }; -use crate::sanitization::SanitizedCompletionParams; +use crate::{ + NodeText, + context::policy_parser::{PolicyParser, PolicyStmtKind}, + sanitization::SanitizedCompletionParams, +}; #[derive(Debug, PartialEq, Eq, Hash)] pub enum WrappingClause<'a> { @@ -18,12 +25,8 @@ pub enum WrappingClause<'a> { }, Update, Delete, -} - -#[derive(PartialEq, Eq, Debug)] -pub(crate) enum NodeText<'a> { - Replaced, - Original(&'a str), + PolicyName, + ToRoleAssignment, } #[derive(PartialEq, Eq, Hash, Debug)] @@ -47,6 +50,45 @@ pub enum WrappingNode { Assignment, } +#[derive(Debug)] +pub(crate) enum NodeUnderCursor<'a> { + TsNode(tree_sitter::Node<'a>), + CustomNode { + text: NodeText, + range: TextRange, + kind: String, + }, +} + +impl NodeUnderCursor<'_> { + pub fn start_byte(&self) -> usize { + match self { + NodeUnderCursor::TsNode(node) => node.start_byte(), + NodeUnderCursor::CustomNode { range, .. } => range.start().into(), + } + } + + pub fn end_byte(&self) -> usize { + match self { + NodeUnderCursor::TsNode(node) => node.end_byte(), + NodeUnderCursor::CustomNode { range, .. } => range.end().into(), + } + } + + pub fn kind(&self) -> &str { + match self { + NodeUnderCursor::TsNode(node) => node.kind(), + NodeUnderCursor::CustomNode { kind, .. } => kind.as_str(), + } + } +} + +impl<'a> From> for NodeUnderCursor<'a> { + fn from(node: tree_sitter::Node<'a>) -> Self { + NodeUnderCursor::TsNode(node) + } +} + impl TryFrom<&str> for WrappingNode { type Error = String; @@ -77,7 +119,7 @@ impl TryFrom for WrappingNode { } pub(crate) struct CompletionContext<'a> { - pub node_under_cursor: Option>, + pub node_under_cursor: Option>, pub tree: &'a tree_sitter::Tree, pub text: &'a str, @@ -137,12 +179,49 @@ impl<'a> CompletionContext<'a> { is_in_error_node: false, }; - ctx.gather_tree_context(); - ctx.gather_info_from_ts_queries(); + // policy handling is important to Supabase, but they are a PostgreSQL specific extension, + // so the tree_sitter_sql language does not support it. + // We infer the context manually. + if PolicyParser::looks_like_policy_stmt(¶ms.text) { + ctx.gather_policy_context(); + } else { + ctx.gather_tree_context(); + ctx.gather_info_from_ts_queries(); + } ctx } + fn gather_policy_context(&mut self) { + let policy_context = PolicyParser::get_context(self.text, self.position); + + self.node_under_cursor = Some(NodeUnderCursor::CustomNode { + text: policy_context.node_text.into(), + range: policy_context.node_range, + kind: policy_context.node_kind.clone(), + }); + + if policy_context.node_kind == "policy_table" { + self.schema_or_alias_name = policy_context.schema_name.clone(); + } + + if policy_context.table_name.is_some() { + let mut new = HashSet::new(); + new.insert(policy_context.table_name.unwrap()); + self.mentioned_relations + .insert(policy_context.schema_name, new); + } + + self.wrapping_clause_type = match policy_context.node_kind.as_str() { + "policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => { + Some(WrappingClause::PolicyName) + } + "policy_role" => Some(WrappingClause::ToRoleAssignment), + "policy_table" => Some(WrappingClause::From), + _ => None, + }; + } + fn gather_info_from_ts_queries(&mut self) { let stmt_range = self.wrapping_statement_range.as_ref(); let sql = self.text; @@ -195,24 +274,30 @@ impl<'a> CompletionContext<'a> { } } - pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option> { + fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option { let source = self.text; ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { if SanitizedCompletionParams::is_sanitized_token(txt) { NodeText::Replaced } else { - NodeText::Original(txt) + NodeText::Original(txt.into()) } }) } pub fn get_node_under_cursor_content(&self) -> Option { - self.node_under_cursor - .and_then(|n| self.get_ts_node_content(n)) - .and_then(|txt| match txt { + match self.node_under_cursor.as_ref()? { + NodeUnderCursor::TsNode(node) => { + self.get_ts_node_content(node).and_then(|nt| match nt { + NodeText::Replaced => None, + NodeText::Original(c) => Some(c.to_string()), + }) + } + NodeUnderCursor::CustomNode { text, .. } => match text { NodeText::Replaced => None, NodeText::Original(c) => Some(c.to_string()), - }) + }, + } } fn gather_tree_context(&mut self) { @@ -250,7 +335,7 @@ impl<'a> CompletionContext<'a> { // prevent infinite recursion – this can happen if we only have a PROGRAM node if current_node_kind == parent_node_kind { - self.node_under_cursor = Some(current_node); + self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); return; } @@ -289,7 +374,7 @@ impl<'a> CompletionContext<'a> { match current_node_kind { "object_reference" | "field" => { - let content = self.get_ts_node_content(current_node); + let content = self.get_ts_node_content(¤t_node); if let Some(node_txt) = content { match node_txt { NodeText::Original(txt) => { @@ -321,7 +406,7 @@ impl<'a> CompletionContext<'a> { // We have arrived at the leaf node if current_node.child_count() == 0 { - self.node_under_cursor = Some(current_node); + self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); return; } @@ -334,11 +419,11 @@ impl<'a> CompletionContext<'a> { node: tree_sitter::Node<'a>, ) -> Option> { if node.kind().starts_with("keyword_") { - if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt { + if let Some(txt) = self.get_ts_node_content(&node).and_then(|txt| match txt { NodeText::Original(txt) => Some(txt), NodeText::Replaced => None, }) { - match txt { + match txt.as_str() { "where" => return Some(WrappingClause::Where), "update" => return Some(WrappingClause::Update), "select" => return Some(WrappingClause::Select), @@ -388,11 +473,14 @@ impl<'a> CompletionContext<'a> { #[cfg(test)] mod tests { use crate::{ - context::{CompletionContext, NodeText, WrappingClause}, + NodeText, + context::{CompletionContext, WrappingClause}, sanitization::SanitizedCompletionParams, test_helper::{CURSOR_POS, get_text_and_position}, }; + use super::NodeUnderCursor; + fn get_tree(input: &str) -> tree_sitter::Tree { let mut parser = tree_sitter::Parser::new(); parser @@ -551,17 +639,22 @@ mod tests { let ctx = CompletionContext::new(¶ms); - let node = ctx.node_under_cursor.unwrap(); + let node = ctx.node_under_cursor.as_ref().unwrap(); - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("select")) - ); + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("select".into())) + ); - assert_eq!( - ctx.wrapping_clause_type, - Some(crate::context::WrappingClause::Select) - ); + assert_eq!( + ctx.wrapping_clause_type, + Some(crate::context::WrappingClause::Select) + ); + } + _ => unreachable!(), + } } } @@ -582,12 +675,17 @@ mod tests { let ctx = CompletionContext::new(¶ms); - let node = ctx.node_under_cursor.unwrap(); + let node = ctx.node_under_cursor.as_ref().unwrap(); - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("from")) - ); + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("from".into())) + ); + } + _ => unreachable!(), + } } #[test] @@ -607,10 +705,18 @@ mod tests { let ctx = CompletionContext::new(¶ms); - let node = ctx.node_under_cursor.unwrap(); + let node = ctx.node_under_cursor.as_ref().unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some(NodeText::Original(""))); - assert_eq!(ctx.wrapping_clause_type, None); + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("".into())) + ); + assert_eq!(ctx.wrapping_clause_type, None); + } + _ => unreachable!(), + } } #[test] @@ -632,12 +738,17 @@ mod tests { let ctx = CompletionContext::new(¶ms); - let node = ctx.node_under_cursor.unwrap(); + let node = ctx.node_under_cursor.as_ref().unwrap(); - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("fro")) - ); - assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); + match node { + NodeUnderCursor::TsNode(node) => { + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("fro".into())) + ); + assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); + } + _ => unreachable!(), + } } } diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs new file mode 100644 index 00000000..db37a13f --- /dev/null +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -0,0 +1,617 @@ +use std::iter::Peekable; + +use pgt_text_size::{TextRange, TextSize}; + +#[derive(Default, Debug, PartialEq, Eq)] +pub(crate) enum PolicyStmtKind { + #[default] + Create, + + Alter, + Drop, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct WordWithIndex { + word: String, + start: usize, + end: usize, +} + +impl WordWithIndex { + fn is_under_cursor(&self, cursor_pos: usize) -> bool { + self.start <= cursor_pos && self.end > cursor_pos + } + + fn get_range(&self) -> TextRange { + let start: u32 = self.start.try_into().expect("Text too long"); + let end: u32 = self.end.try_into().expect("Text too long"); + TextRange::new(TextSize::from(start), TextSize::from(end)) + } +} + +/// Note: A policy name within quotation marks will be considered a single word. +fn sql_to_words(sql: &str) -> Result, String> { + let mut words = vec![]; + + let mut start_of_word: Option = None; + let mut current_word = String::new(); + let mut in_quotation_marks = false; + + for (current_position, current_char) in sql.char_indices() { + if (current_char.is_ascii_whitespace() || current_char == ';') + && !current_word.is_empty() + && start_of_word.is_some() + && !in_quotation_marks + { + words.push(WordWithIndex { + word: current_word, + start: start_of_word.unwrap(), + end: current_position, + }); + + current_word = String::new(); + start_of_word = None; + } else if (current_char.is_ascii_whitespace() || current_char == ';') + && current_word.is_empty() + { + // do nothing + } else if current_char == '"' && start_of_word.is_none() { + in_quotation_marks = true; + current_word.push(current_char); + start_of_word = Some(current_position); + } else if current_char == '"' && start_of_word.is_some() { + current_word.push(current_char); + words.push(WordWithIndex { + word: current_word, + start: start_of_word.unwrap(), + end: current_position + 1, + }); + in_quotation_marks = false; + start_of_word = None; + current_word = String::new() + } else if start_of_word.is_some() { + current_word.push(current_char) + } else { + start_of_word = Some(current_position); + current_word.push(current_char); + } + } + + if let Some(start_of_word) = start_of_word { + if !current_word.is_empty() { + words.push(WordWithIndex { + word: current_word, + start: start_of_word, + end: sql.len(), + }); + } + } + + if in_quotation_marks { + Err("String was not closed properly.".into()) + } else { + Ok(words) + } +} + +#[derive(Default, Debug, PartialEq, Eq)] +pub(crate) struct PolicyContext { + pub policy_name: Option, + pub table_name: Option, + pub schema_name: Option, + pub statement_kind: PolicyStmtKind, + pub node_text: String, + pub node_range: TextRange, + pub node_kind: String, +} + +/// Simple parser that'll turn a policy-related statement into a context object required for +/// completions. +/// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. +/// It can only parse policy statements. +pub(crate) struct PolicyParser { + tokens: Peekable>, + previous_token: Option, + current_token: Option, + context: PolicyContext, + cursor_position: usize, +} + +impl PolicyParser { + pub(crate) fn looks_like_policy_stmt(sql: &str) -> bool { + let lowercased = sql.to_ascii_lowercase(); + let trimmed = lowercased.trim(); + trimmed.starts_with("create policy") + || trimmed.starts_with("drop policy") + || trimmed.starts_with("alter policy") + } + + pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext { + assert!( + Self::looks_like_policy_stmt(sql), + "PolicyParser should only be used for policy statements. Developer error!" + ); + + match sql_to_words(sql) { + Ok(tokens) => { + let parser = PolicyParser { + tokens: tokens.into_iter().peekable(), + context: PolicyContext::default(), + previous_token: None, + current_token: None, + cursor_position, + }; + + parser.parse() + } + Err(_) => PolicyContext::default(), + } + } + + fn parse(mut self) -> PolicyContext { + while let Some(token) = self.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else { + self.handle_token(token); + } + } + + self.context + } + + fn handle_token_under_cursor(&mut self, token: WordWithIndex) { + if self.previous_token.is_none() { + return; + } + + let previous = self.previous_token.take().unwrap(); + + match previous.word.to_ascii_lowercase().as_str() { + "policy" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "policy_name".into(); + self.context.node_text = token.word; + } + "on" => { + if token.word.contains('.') { + let (schema_name, table_name) = self.schema_and_table_name(&token); + + let schema_name_len = schema_name.len(); + self.context.schema_name = Some(schema_name); + + let offset: u32 = schema_name_len.try_into().expect("Text too long"); + let range_without_schema = token + .get_range() + .checked_expand_start( + TextSize::new(offset + 1), // kill the dot as well + ) + .expect("Text too long"); + + self.context.node_range = range_without_schema; + self.context.node_kind = "policy_table".into(); + + // In practice, we should always have a table name. + // The completion sanitization will add a word after a `.` if nothing follows it; + // the token_text will then look like `schema.REPLACED_TOKEN`. + self.context.node_text = table_name.unwrap_or_default(); + } else { + self.context.node_range = token.get_range(); + self.context.node_text = token.word; + self.context.node_kind = "policy_table".into(); + } + } + "to" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "policy_role".into(); + self.context.node_text = token.word; + } + _ => { + self.context.node_range = token.get_range(); + self.context.node_text = token.word; + } + } + } + + fn handle_token(&mut self, token: WordWithIndex) { + match token.word.to_ascii_lowercase().as_str() { + "create" if self.next_matches("policy") => { + self.context.statement_kind = PolicyStmtKind::Create; + } + "alter" if self.next_matches("policy") => { + self.context.statement_kind = PolicyStmtKind::Alter; + } + "drop" if self.next_matches("policy") => { + self.context.statement_kind = PolicyStmtKind::Drop; + } + "on" => self.table_with_schema(), + + // skip the "to" so we don't parse it as the TO rolename when it's under the cursor + "rename" if self.next_matches("to") => { + self.advance(); + } + + _ => { + if self.prev_matches("policy") { + self.context.policy_name = Some(token.word); + } + } + } + } + + fn next_matches(&mut self, it: &str) -> bool { + self.tokens.peek().is_some_and(|c| c.word.as_str() == it) + } + + fn prev_matches(&self, it: &str) -> bool { + self.previous_token.as_ref().is_some_and(|t| t.word == it) + } + + fn advance(&mut self) -> Option { + // we can't peek back n an iterator, so we'll have to keep track manually. + self.previous_token = self.current_token.take(); + self.current_token = self.tokens.next(); + self.current_token.clone() + } + + fn table_with_schema(&mut self) { + if let Some(token) = self.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else if token.word.contains('.') { + let (schema, maybe_table) = self.schema_and_table_name(&token); + self.context.schema_name = Some(schema); + self.context.table_name = maybe_table; + } else { + self.context.table_name = Some(token.word); + } + }; + } + + fn schema_and_table_name(&self, token: &WordWithIndex) -> (String, Option) { + let mut parts = token.word.split('.'); + + ( + parts.next().unwrap().into(), + parts.next().map(|tb| tb.into()), + ) + } +} + +#[cfg(test)] +mod tests { + use pgt_text_size::{TextRange, TextSize}; + + use crate::{ + context::policy_parser::{PolicyContext, PolicyStmtKind, WordWithIndex}, + test_helper::CURSOR_POS, + }; + + use super::{PolicyParser, sql_to_words}; + + fn with_pos(query: String) -> (usize, String) { + let mut pos: Option = None; + + for (p, c) in query.char_indices() { + if c == CURSOR_POS { + pos = Some(p); + break; + } + } + + ( + pos.expect("Please add cursor position!"), + query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + ) + } + + #[test] + fn infers_progressively() { + let (pos, query) = with_pos(format!( + r#" + create policy {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: None, + table_name: None, + schema_name: None, + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(25), TextSize::new(39)), + node_kind: "policy_name".into() + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: None, + schema_name: None, + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "".into(), + node_range: TextRange::new(TextSize::new(42), TextSize::new(56)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: None, + schema_name: None, + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "policy_table".into(), + node_range: TextRange::new(TextSize::new(45), TextSize::new(59)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on auth.{} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: None, + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "policy_table".into(), + node_range: TextRange::new(TextSize::new(50), TextSize::new(64)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on auth.users + as {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "".into(), + node_range: TextRange::new(TextSize::new(72), TextSize::new(86)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on auth.users + as permissive + {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "".into(), + node_range: TextRange::new(TextSize::new(95), TextSize::new(109)), + } + ); + + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" on auth.users + as permissive + to {} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some("\"my cool policy\"".into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_kind: "policy_role".into(), + node_range: TextRange::new(TextSize::new(98), TextSize::new(112)), + } + ); + } + + #[test] + fn determines_on_table_node() { + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" + on {} + to all + using (true); + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some(r#""my cool policy""#.into()), + table_name: None, + schema_name: None, + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(57), TextSize::new(71)), + node_kind: "policy_table".into() + } + ) + } + + #[test] + fn determines_on_table_node_after_schema() { + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" + on auth.{} + to all + using (true); + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some(r#""my cool policy""#.into()), + table_name: None, + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(62), TextSize::new(76)), + node_kind: "policy_table".into() + } + ) + } + + #[test] + fn determines_we_are_on_column_name() { + let (pos, query) = with_pos(format!( + r#" + drop policy {} on auth.users; + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: None, + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Drop, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(23), TextSize::new(37)), + node_kind: "policy_name".into() + } + ); + + // cursor within quotation marks. + let (pos, query) = with_pos(format!( + r#" + drop policy "{}" on auth.users; + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: None, + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Drop, + node_text: "\"REPLACED_TOKEN\"".into(), + node_range: TextRange::new(TextSize::new(23), TextSize::new(39)), + node_kind: "policy_name".into() + } + ); + } + + #[test] + fn single_quotation_mark_does_not_fail() { + let (pos, query) = with_pos(format!( + r#" + drop policy "{} on auth.users; + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!(context, PolicyContext::default()); + } + + fn to_word(word: &str, start: usize, end: usize) -> WordWithIndex { + WordWithIndex { + word: word.into(), + start, + end, + } + } + + #[test] + fn determines_positions_correctly() { + let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string(); + + let words = sql_to_words(query.as_str()).unwrap(); + + assert_eq!(words[0], to_word("create", 1, 7)); + assert_eq!(words[1], to_word("policy", 8, 14)); + assert_eq!(words[2], to_word("\"my cool pol\"", 15, 28)); + assert_eq!(words[3], to_word("on", 30, 32)); + assert_eq!(words[4], to_word("auth.users", 33, 43)); + assert_eq!(words[5], to_word("as", 45, 47)); + assert_eq!(words[6], to_word("permissive", 48, 58)); + assert_eq!(words[7], to_word("for", 60, 63)); + assert_eq!(words[8], to_word("select", 64, 70)); + assert_eq!(words[9], to_word("to", 73, 75)); + assert_eq!(words[10], to_word("public", 78, 84)); + assert_eq!(words[11], to_word("using", 87, 92)); + assert_eq!(words[12], to_word("(true)", 93, 99)); + } +} diff --git a/crates/pgt_completions/src/item.rs b/crates/pgt_completions/src/item.rs index f37d0efb..702fc766 100644 --- a/crates/pgt_completions/src/item.rs +++ b/crates/pgt_completions/src/item.rs @@ -11,6 +11,7 @@ pub enum CompletionItemKind { Function, Column, Schema, + Policy, } impl Display for CompletionItemKind { @@ -20,6 +21,7 @@ impl Display for CompletionItemKind { CompletionItemKind::Function => "Function", CompletionItemKind::Column => "Column", CompletionItemKind::Schema => "Schema", + CompletionItemKind::Policy => "Policy", }; write!(f, "{txt}") diff --git a/crates/pgt_completions/src/providers/helper.rs b/crates/pgt_completions/src/providers/helper.rs index 999d6b37..eacb8314 100644 --- a/crates/pgt_completions/src/providers/helper.rs +++ b/crates/pgt_completions/src/providers/helper.rs @@ -1,6 +1,6 @@ use pgt_text_size::{TextRange, TextSize}; -use crate::{CompletionText, context::CompletionContext}; +use crate::{CompletionText, context::CompletionContext, remove_sanitized_token}; pub(crate) fn find_matching_alias_for_table( ctx: &CompletionContext, @@ -14,6 +14,21 @@ pub(crate) fn find_matching_alias_for_table( None } +pub(crate) fn get_range_to_replace(ctx: &CompletionContext) -> TextRange { + match ctx.node_under_cursor.as_ref() { + Some(node) => { + let content = ctx.get_node_under_cursor_content().unwrap_or("".into()); + let length = remove_sanitized_token(content.as_str()).len(); + + let start = node.start_byte(); + let end = start + length; + + TextRange::new(start.try_into().unwrap(), end.try_into().unwrap()) + } + None => TextRange::empty(TextSize::new(0)), + } +} + pub(crate) fn get_completion_text_with_schema_or_alias( ctx: &CompletionContext, item_name: &str, @@ -22,12 +37,7 @@ pub(crate) fn get_completion_text_with_schema_or_alias( if schema_or_alias_name == "public" || ctx.schema_or_alias_name.is_some() { None } else { - let node = ctx.node_under_cursor.unwrap(); - - let range = TextRange::new( - TextSize::try_from(node.start_byte()).unwrap(), - TextSize::try_from(node.end_byte()).unwrap(), - ); + let range = get_range_to_replace(ctx); Some(CompletionText { text: format!("{}.{}", schema_or_alias_name, item_name), diff --git a/crates/pgt_completions/src/providers/mod.rs b/crates/pgt_completions/src/providers/mod.rs index 82e32cdf..7b07cee8 100644 --- a/crates/pgt_completions/src/providers/mod.rs +++ b/crates/pgt_completions/src/providers/mod.rs @@ -1,10 +1,12 @@ mod columns; mod functions; mod helper; +mod policies; mod schemas; mod tables; pub use columns::*; pub use functions::*; +pub use policies::*; pub use schemas::*; pub use tables::*; diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs new file mode 100644 index 00000000..2421f1f1 --- /dev/null +++ b/crates/pgt_completions/src/providers/policies.rs @@ -0,0 +1,103 @@ +use pgt_text_size::{TextRange, TextSize}; + +use crate::{ + CompletionItemKind, CompletionText, + builder::{CompletionBuilder, PossibleCompletionItem}, + context::CompletionContext, + relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, +}; + +use super::helper::get_range_to_replace; + +pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { + let available_policies = &ctx.schema_cache.policies; + + let surrounded_by_quotes = ctx + .get_node_under_cursor_content() + .is_some_and(|c| c.starts_with('"') && c.ends_with('"') && c != "\"\""); + + for pol in available_policies { + let completion_text = if surrounded_by_quotes { + // If we're within quotes, we want to change the content + // *within* the quotes. + // If we attempt to replace outside the quotes, the VSCode + // client won't show the suggestions. + let range = get_range_to_replace(ctx); + Some(CompletionText { + text: pol.name.clone(), + range: TextRange::new( + range.start() + TextSize::new(1), + range.end() - TextSize::new(1), + ), + }) + } else { + // If we aren't within quotes, we want to complete the + // full policy including quotation marks. + Some(CompletionText { + text: format!("\"{}\"", pol.name), + range: get_range_to_replace(ctx), + }) + }; + + let relevance = CompletionRelevanceData::Policy(pol); + + let item = PossibleCompletionItem { + label: pol.name.chars().take(35).collect::(), + score: CompletionScore::from(relevance.clone()), + filter: CompletionFilter::from(relevance), + description: pol.table_name.to_string(), + kind: CompletionItemKind::Policy, + completion_text, + }; + + builder.add_item(item); + } +} + +#[cfg(test)] +mod tests { + use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; + + #[tokio::test] + async fn completes_within_quotation_marks() { + let setup = r#" + create schema private; + + create table private.users ( + id serial primary key, + email text + ); + + create policy "read for public users disallowed" on private.users + as restrictive + for select + to public + using (false); + + create policy "write for public users allowed" on private.users + as restrictive + for insert + to public + with check (true); + "#; + + assert_complete_results( + format!("alter policy \"{}\" on private.users;", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("read for public users disallowed".into()), + CompletionAssertion::Label("write for public users allowed".into()), + ], + setup, + ) + .await; + + assert_complete_results( + format!("alter policy \"w{}\" on private.users;", CURSOR_POS).as_str(), + vec![CompletionAssertion::Label( + "write for public users allowed".into(), + )], + setup, + ) + .await; + } +} diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index 911a6433..f51c3c52 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -7,4 +7,5 @@ pub(crate) enum CompletionRelevanceData<'a> { Function(&'a pgt_schema_cache::Function), Column(&'a pgt_schema_cache::Column), Schema(&'a pgt_schema_cache::Schema), + Policy(&'a pgt_schema_cache::Policy), } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index ec12201c..3b148336 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,4 +1,4 @@ -use crate::context::{CompletionContext, WrappingClause}; +use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause}; use super::CompletionRelevanceData; @@ -24,7 +24,11 @@ impl CompletionFilter<'_> { } fn completable_context(&self, ctx: &CompletionContext) -> Option<()> { - let current_node_kind = ctx.node_under_cursor.map(|n| n.kind()).unwrap_or(""); + let current_node_kind = ctx + .node_under_cursor + .as_ref() + .map(|n| n.kind()) + .unwrap_or(""); if current_node_kind.starts_with("keyword_") || current_node_kind == "=" @@ -36,20 +40,23 @@ impl CompletionFilter<'_> { } // No autocompletions if there are two identifiers without a separator. - if ctx.node_under_cursor.is_some_and(|n| { - n.prev_sibling().is_some_and(|p| { + if ctx.node_under_cursor.as_ref().is_some_and(|n| match n { + NodeUnderCursor::TsNode(node) => node.prev_sibling().is_some_and(|p| { (p.kind() == "identifier" || p.kind() == "object_reference") && n.kind() == "identifier" - }) + }), + NodeUnderCursor::CustomNode { .. } => false, }) { return None; } // no completions if we're right after an asterisk: // `select * {}` - if ctx.node_under_cursor.is_some_and(|n| { - n.prev_sibling() - .is_some_and(|p| (p.kind() == "all_fields") && n.kind() == "identifier") + if ctx.node_under_cursor.as_ref().is_some_and(|n| match n { + NodeUnderCursor::TsNode(node) => node + .prev_sibling() + .is_some_and(|p| (p.kind() == "all_fields") && n.kind() == "identifier"), + NodeUnderCursor::CustomNode { .. } => false, }) { return None; } @@ -60,18 +67,19 @@ impl CompletionFilter<'_> { fn check_clause(&self, ctx: &CompletionContext) -> Option<()> { let clause = ctx.wrapping_clause_type.as_ref(); + let in_clause = |compare: WrappingClause| clause.is_some_and(|c| c == &compare); + match self.data { CompletionRelevanceData::Table(_) => { - let in_select_clause = clause.is_some_and(|c| c == &WrappingClause::Select); - let in_where_clause = clause.is_some_and(|c| c == &WrappingClause::Where); - - if in_select_clause || in_where_clause { + if in_clause(WrappingClause::Select) + || in_clause(WrappingClause::Where) + || in_clause(WrappingClause::PolicyName) + { return None; }; } CompletionRelevanceData::Column(_) => { - let in_from_clause = clause.is_some_and(|c| c == &WrappingClause::From); - if in_from_clause { + if in_clause(WrappingClause::From) || in_clause(WrappingClause::PolicyName) { return None; } @@ -83,6 +91,7 @@ impl CompletionFilter<'_> { WrappingClause::Join { on_node: Some(on) } => ctx .node_under_cursor + .as_ref() .is_some_and(|n| n.end_byte() < on.start_byte()), _ => false, @@ -92,7 +101,16 @@ impl CompletionFilter<'_> { return None; } } - _ => {} + CompletionRelevanceData::Policy(_) => { + if clause.is_none_or(|c| c != &WrappingClause::PolicyName) { + return None; + } + } + _ => { + if in_clause(WrappingClause::PolicyName) { + return None; + } + } } Some(()) @@ -126,10 +144,10 @@ impl CompletionFilter<'_> { .get(schema_or_alias) .is_some_and(|t| t == &col.table_name), - CompletionRelevanceData::Schema(_) => { - // we should never allow schema suggestions if there already was one. - false - } + // we should never allow schema suggestions if there already was one. + CompletionRelevanceData::Schema(_) => false, + // no policy comletion if user typed a schema node first. + CompletionRelevanceData::Policy(_) => false, }; if !matches { diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index b0b0bf63..2fe12511 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -37,20 +37,23 @@ impl CompletionScore<'_> { fn check_matches_query_input(&mut self, ctx: &CompletionContext) { let content = match ctx.get_node_under_cursor_content() { - Some(c) => c, + Some(c) => c.replace('"', ""), None => return, }; let name = match self.data { - CompletionRelevanceData::Function(f) => f.name.as_str(), - CompletionRelevanceData::Table(t) => t.name.as_str(), - CompletionRelevanceData::Column(c) => c.name.as_str(), - CompletionRelevanceData::Schema(s) => s.name.as_str(), + CompletionRelevanceData::Function(f) => f.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Table(t) => t.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Column(c) => c.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Schema(s) => s.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Policy(p) => p.name.as_str().to_ascii_lowercase(), }; let fz_matcher = SkimMatcherV2::default(); - if let Some(score) = fz_matcher.fuzzy_match(name, content.as_str()) { + if let Some(score) = + fz_matcher.fuzzy_match(name.as_str(), content.to_ascii_lowercase().as_str()) + { let scorei32: i32 = score .try_into() .expect("The length of the input exceeds i32 capacity"); @@ -82,6 +85,7 @@ impl CompletionScore<'_> { WrappingClause::Join { on_node } if on_node.is_none_or(|on| { ctx.node_under_cursor + .as_ref() .is_none_or(|n| n.end_byte() < on.start_byte()) }) => { @@ -102,6 +106,7 @@ impl CompletionScore<'_> { WrappingClause::Join { on_node } if on_node.is_some_and(|on| { ctx.node_under_cursor + .as_ref() .is_some_and(|n| n.start_byte() > on.end_byte()) }) => { @@ -117,6 +122,10 @@ impl CompletionScore<'_> { WrappingClause::Delete if !has_mentioned_schema => 15, _ => -50, }, + CompletionRelevanceData::Policy(_) => match clause_type { + WrappingClause::PolicyName => 25, + _ => -50, + }, } } @@ -150,6 +159,7 @@ impl CompletionScore<'_> { WrappingNode::Relation if !has_mentioned_schema && has_node_text => 0, _ => -50, }, + CompletionRelevanceData::Policy(_) => 0, } } @@ -183,6 +193,7 @@ impl CompletionScore<'_> { CompletionRelevanceData::Table(t) => t.schema.as_str(), CompletionRelevanceData::Column(c) => c.schema_name.as_str(), CompletionRelevanceData::Schema(s) => s.name.as_str(), + CompletionRelevanceData::Policy(p) => p.schema_name.as_str(), } } @@ -190,6 +201,7 @@ impl CompletionScore<'_> { match self.data { CompletionRelevanceData::Column(c) => Some(c.table_name.as_str()), CompletionRelevanceData::Table(t) => Some(t.name.as_str()), + CompletionRelevanceData::Policy(p) => Some(p.table_name.as_str()), _ => None, } } diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 248a0ffa..6aa75a16 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -4,6 +4,8 @@ use pgt_text_size::TextSize; use crate::CompletionParams; +static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; + pub(crate) struct SanitizedCompletionParams<'a> { pub position: TextSize, pub text: String, @@ -16,13 +18,39 @@ pub fn benchmark_sanitization(params: CompletionParams) -> String { params.text } +pub(crate) fn remove_sanitized_token(it: &str) -> String { + it.replace(SANITIZED_TOKEN, "") +} + +#[derive(PartialEq, Eq, Debug)] +pub(crate) enum NodeText { + Replaced, + Original(String), +} + +impl From<&str> for NodeText { + fn from(value: &str) -> Self { + if value == SANITIZED_TOKEN { + NodeText::Replaced + } else { + NodeText::Original(value.into()) + } + } +} + +impl From for NodeText { + fn from(value: String) -> Self { + NodeText::from(value.as_str()) + } +} + impl<'larger, 'smaller> From> for SanitizedCompletionParams<'smaller> where 'larger: 'smaller, { fn from(params: CompletionParams<'larger>) -> Self { - if cursor_inbetween_nodes(params.tree, params.position) - || cursor_prepared_to_write_token_after_last_node(params.tree, params.position) + if cursor_inbetween_nodes(¶ms.text, params.position) + || cursor_prepared_to_write_token_after_last_node(¶ms.text, params.position) || cursor_before_semicolon(params.tree, params.position) || cursor_on_a_dot(¶ms.text, params.position) { @@ -33,8 +61,6 @@ where } } -static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; - impl<'larger, 'smaller> SanitizedCompletionParams<'smaller> where 'larger: 'smaller, @@ -102,37 +128,17 @@ where /// select |from users; -- cursor "touches" from node. returns false. /// select | from users; -- cursor is between select and from nodes. returns true. /// ``` -fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool { - let mut cursor = tree.walk(); - let mut leaf_node = tree.root_node(); - - let byte = position.into(); - - // if the cursor escapes the root node, it can't be between nodes. - if byte < leaf_node.start_byte() || byte >= leaf_node.end_byte() { - return false; - } +fn cursor_inbetween_nodes(sql: &str, position: TextSize) -> bool { + let position: usize = position.into(); + let mut chars = sql.chars(); - /* - * Get closer and closer to the leaf node, until - * a) there is no more child *for the node* or - * b) there is no more child *under the cursor*. - */ - loop { - let child_idx = cursor.goto_first_child_for_byte(position.into()); - if child_idx.is_none() { - break; - } - leaf_node = cursor.node(); - } + let previous_whitespace = chars + .nth(position - 1) + .is_some_and(|c| c.is_ascii_whitespace()); - let cursor_on_leafnode = byte >= leaf_node.start_byte() && leaf_node.end_byte() >= byte; + let current_whitespace = chars.next().is_some_and(|c| c.is_ascii_whitespace()); - /* - * The cursor is inbetween nodes if it is not within the range - * of a leaf node. - */ - !cursor_on_leafnode + previous_whitespace && current_whitespace } /// Checks if the cursor is positioned after the last node, @@ -143,12 +149,9 @@ fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool /// select * from| -- user still needs to type a space /// select * from | -- too far off. /// ``` -fn cursor_prepared_to_write_token_after_last_node( - tree: &tree_sitter::Tree, - position: TextSize, -) -> bool { +fn cursor_prepared_to_write_token_after_last_node(sql: &str, position: TextSize) -> bool { let cursor_pos: usize = position.into(); - cursor_pos == tree.root_node().end_byte() + 1 + cursor_pos == sql.len() + 1 } fn cursor_on_a_dot(sql: &str, position: TextSize) -> bool { @@ -214,58 +217,44 @@ mod tests { // note: two spaces between select and from. let input = "select from users;"; - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let tree = parser.parse(input, None).unwrap(); - // select | from users; <-- just right, one space after select token, one space before from - assert!(cursor_inbetween_nodes(&tree, TextSize::new(7))); + assert!(cursor_inbetween_nodes(input, TextSize::new(7))); // select| from users; <-- still on select token - assert!(!cursor_inbetween_nodes(&tree, TextSize::new(6))); + assert!(!cursor_inbetween_nodes(input, TextSize::new(6))); // select |from users; <-- already on from token - assert!(!cursor_inbetween_nodes(&tree, TextSize::new(8))); + assert!(!cursor_inbetween_nodes(input, TextSize::new(8))); // select from users;| - assert!(!cursor_inbetween_nodes(&tree, TextSize::new(19))); + assert!(!cursor_inbetween_nodes(input, TextSize::new(19))); } #[test] fn test_cursor_after_nodes() { let input = "select * from"; - let mut parser = tree_sitter::Parser::new(); - parser - .set_language(tree_sitter_sql::language()) - .expect("Error loading sql language"); - - let tree = parser.parse(input, None).unwrap(); - // select * from| <-- still on previous token assert!(!cursor_prepared_to_write_token_after_last_node( - &tree, + input, TextSize::new(13) )); // select * from | <-- too far off, two spaces afterward assert!(!cursor_prepared_to_write_token_after_last_node( - &tree, + input, TextSize::new(15) )); // select * |from <-- it's within assert!(!cursor_prepared_to_write_token_after_last_node( - &tree, + input, TextSize::new(9) )); // select * from | <-- just right assert!(cursor_prepared_to_write_token_after_last_node( - &tree, + input, TextSize::new(14) )); } diff --git a/crates/pgt_lsp/src/adapters/mod.rs b/crates/pgt_lsp/src/adapters/mod.rs index 972dd576..a5375180 100644 --- a/crates/pgt_lsp/src/adapters/mod.rs +++ b/crates/pgt_lsp/src/adapters/mod.rs @@ -158,6 +158,27 @@ mod tests { assert!(offset.is_none()); } + #[test] + fn with_tabs() { + let line_index = LineIndex::new( + r#" +select + email, + id +from auth.users u +join public.client_identities c on u.id = c.user_id; +"# + .trim(), + ); + + // on `i` of `id` in the select + // 22 because of: + // selectemail,i = 13 + // 8 spaces, 2 newlines = 23 characters + // it's zero indexed => index 22 + check_conversion!(line_index: Position { line: 2, character: 4 } => TextSize::from(22)); + } + #[test] fn unicode() { let line_index = LineIndex::new("'Jan 1, 2018 – Jan 1, 2019'"); diff --git a/crates/pgt_lsp/src/handlers/completions.rs b/crates/pgt_lsp/src/handlers/completions.rs index e1a7508c..ee13b26e 100644 --- a/crates/pgt_lsp/src/handlers/completions.rs +++ b/crates/pgt_lsp/src/handlers/completions.rs @@ -65,5 +65,6 @@ fn to_lsp_types_completion_item_kind( pgt_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, pgt_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, pgt_completions::CompletionItemKind::Schema => lsp_types::CompletionItemKind::CLASS, + pgt_completions::CompletionItemKind::Policy => lsp_types::CompletionItemKind::CONSTANT, } } diff --git a/crates/pgt_schema_cache/src/lib.rs b/crates/pgt_schema_cache/src/lib.rs index d978a94b..e73901d0 100644 --- a/crates/pgt_schema_cache/src/lib.rs +++ b/crates/pgt_schema_cache/src/lib.rs @@ -14,6 +14,7 @@ mod versions; pub use columns::*; pub use functions::{Behavior, Function, FunctionArg, FunctionArgs}; +pub use policies::{Policy, PolicyCommand}; pub use schema_cache::SchemaCache; pub use schemas::Schema; pub use tables::{ReplicaIdentity, Table}; diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs index 641dad12..85cd7821 100644 --- a/crates/pgt_schema_cache/src/policies.rs +++ b/crates/pgt_schema_cache/src/policies.rs @@ -56,14 +56,14 @@ impl From for Policy { #[derive(Debug, PartialEq, Eq)] pub struct Policy { - name: String, - table_name: String, - schema_name: String, - is_permissive: bool, - command: PolicyCommand, - role_names: Vec, - security_qualification: Option, - with_check: Option, + pub name: String, + pub table_name: String, + pub schema_name: String, + pub is_permissive: bool, + pub command: PolicyCommand, + pub role_names: Vec, + pub security_qualification: Option, + pub with_check: Option, } impl SchemaCacheItem for Policy { diff --git a/crates/pgt_text_size/src/range.rs b/crates/pgt_text_size/src/range.rs index 3cfc3c96..baab91e9 100644 --- a/crates/pgt_text_size/src/range.rs +++ b/crates/pgt_text_size/src/range.rs @@ -299,6 +299,39 @@ impl TextRange { end: self.end.checked_add(offset)?, }) } + + /// Expand the range's start by the given offset. + /// The start will never exceed the range's end. + /// + /// # Examples + /// + /// ```rust + /// # use pgt_text_size::*; + /// assert_eq!( + /// TextRange::new(2.into(), 12.into()).checked_expand_start(4.into()).unwrap(), + /// TextRange::new(6.into(), 12.into()), + /// ); + /// + /// assert_eq!( + /// TextRange::new(2.into(), 12.into()).checked_expand_start(12.into()).unwrap(), + /// TextRange::new(12.into(), 12.into()), + /// ); + /// ``` + #[inline] + pub fn checked_expand_start(self, offset: TextSize) -> Option { + let new_start = self.start.checked_add(offset)?; + let end = self.end; + + if new_start > end { + Some(TextRange { start: end, end }) + } else { + Some(TextRange { + start: new_start, + end, + }) + } + } + /// Subtract an offset from this range. /// /// Note that this is not appropriate for changing where a `TextRange` is diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 5a7bfc44..2c0f2b75 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -488,8 +488,11 @@ impl Workspace for WorkspaceServer { let schema_cache = self.schema_cache.load(pool)?; match get_statement_for_completions(&parsed_doc, params.position) { - None => Ok(CompletionsResult::default()), - Some((_id, range, content, cst)) => { + None => { + tracing::debug!("No statement found."); + Ok(CompletionsResult::default()) + } + Some((id, range, content, cst)) => { let position = params.position - range.start(); let items = pgt_completions::complete(pgt_completions::CompletionParams { @@ -499,6 +502,12 @@ impl Workspace for WorkspaceServer { text: content, }); + tracing::debug!( + "Found {} completion items for statement with id {}", + items.len(), + id.raw() + ); + Ok(CompletionsResult { items }) } }