Skip to content

feat(completions): complete policies #397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crates/pgt_completions/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use crate::{
builder::CompletionBuilder,
context::CompletionContext,
item::CompletionItem,
providers::{complete_columns, complete_functions, complete_schemas, complete_tables},
providers::{
complete_columns, complete_functions, complete_policies, complete_schemas, complete_tables,
},
sanitization::SanitizedCompletionParams,
};

Expand Down Expand Up @@ -33,6 +35,7 @@ pub fn complete(params: CompletionParams) -> Vec<CompletionItem> {
complete_functions(&ctx, &mut builder);
complete_columns(&ctx, &mut builder);
complete_schemas(&ctx, &mut builder);
complete_policies(&ctx, &mut builder);

builder.finish()
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
mod policy_parser;

use std::collections::{HashMap, HashSet};

use pgt_schema_cache::SchemaCache;
use pgt_text_size::TextRange;
use pgt_treesitter_queries::{
TreeSitterQueriesExecutor,
queries::{self, QueryResult},
};

use crate::sanitization::SanitizedCompletionParams;
use crate::{
NodeText,
context::policy_parser::{PolicyParser, PolicyStmtKind},
sanitization::SanitizedCompletionParams,
};

#[derive(Debug, PartialEq, Eq)]
pub enum WrappingClause<'a> {
Expand All @@ -18,12 +25,8 @@ pub enum WrappingClause<'a> {
},
Update,
Delete,
}

#[derive(PartialEq, Eq, Debug)]
pub(crate) enum NodeText<'a> {
Replaced,
Original(&'a str),
PolicyName,
ToRoleAssignment,
}

/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
Expand All @@ -41,6 +44,45 @@ pub enum WrappingNode {
Assignment,
}

#[derive(Debug)]
pub(crate) enum NodeUnderCursor<'a> {
TsNode(tree_sitter::Node<'a>),
CustomNode {
text: NodeText,
range: TextRange,
kind: String,
},
}

impl NodeUnderCursor<'_> {
pub fn start_byte(&self) -> usize {
match self {
NodeUnderCursor::TsNode(node) => node.start_byte(),
NodeUnderCursor::CustomNode { range, .. } => range.start().into(),
}
}

pub fn end_byte(&self) -> usize {
match self {
NodeUnderCursor::TsNode(node) => node.end_byte(),
NodeUnderCursor::CustomNode { range, .. } => range.end().into(),
}
}

pub fn kind(&self) -> &str {
match self {
NodeUnderCursor::TsNode(node) => node.kind(),
NodeUnderCursor::CustomNode { kind, .. } => kind.as_str(),
}
}
}

impl<'a> From<tree_sitter::Node<'a>> for NodeUnderCursor<'a> {
fn from(node: tree_sitter::Node<'a>) -> Self {
NodeUnderCursor::TsNode(node)
}
}

impl TryFrom<&str> for WrappingNode {
type Error = String;

Expand Down Expand Up @@ -71,7 +113,7 @@ impl TryFrom<String> for WrappingNode {
}

pub(crate) struct CompletionContext<'a> {
pub node_under_cursor: Option<tree_sitter::Node<'a>>,
pub node_under_cursor: Option<NodeUnderCursor<'a>>,

pub tree: &'a tree_sitter::Tree,
pub text: &'a str,
Expand Down Expand Up @@ -130,12 +172,49 @@ impl<'a> CompletionContext<'a> {
is_in_error_node: false,
};

ctx.gather_tree_context();
ctx.gather_info_from_ts_queries();
// policy handling is important to Supabase, but they are a PostgreSQL specific extension,
// so the tree_sitter_sql language does not support it.
// We infer the context manually.
if PolicyParser::looks_like_policy_stmt(&params.text) {
ctx.gather_policy_context();
} else {
ctx.gather_tree_context();
ctx.gather_info_from_ts_queries();
}

ctx
}

fn gather_policy_context(&mut self) {
let policy_context = PolicyParser::get_context(self.text, self.position);

self.node_under_cursor = Some(NodeUnderCursor::CustomNode {
text: policy_context.node_text.into(),
range: policy_context.node_range,
kind: policy_context.node_kind.clone(),
});

if policy_context.node_kind == "policy_table" {
self.schema_or_alias_name = policy_context.schema_name.clone();
}

if policy_context.table_name.is_some() {
let mut new = HashSet::new();
new.insert(policy_context.table_name.unwrap());
self.mentioned_relations
.insert(policy_context.schema_name, new);
}

self.wrapping_clause_type = match policy_context.node_kind.as_str() {
"policy_name" if policy_context.statement_kind != PolicyStmtKind::Create => {
Some(WrappingClause::PolicyName)
}
"policy_role" => Some(WrappingClause::ToRoleAssignment),
"policy_table" => Some(WrappingClause::From),
_ => None,
};
}

fn gather_info_from_ts_queries(&mut self) {
let stmt_range = self.wrapping_statement_range.as_ref();
let sql = self.text;
Expand Down Expand Up @@ -175,24 +254,30 @@ impl<'a> CompletionContext<'a> {
}
}

pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<NodeText<'a>> {
fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option<NodeText> {
let source = self.text;
ts_node.utf8_text(source.as_bytes()).ok().map(|txt| {
if SanitizedCompletionParams::is_sanitized_token(txt) {
NodeText::Replaced
} else {
NodeText::Original(txt)
NodeText::Original(txt.into())
}
})
}

pub fn get_node_under_cursor_content(&self) -> Option<String> {
self.node_under_cursor
.and_then(|n| self.get_ts_node_content(n))
.and_then(|txt| match txt {
match self.node_under_cursor.as_ref()? {
NodeUnderCursor::TsNode(node) => {
self.get_ts_node_content(node).and_then(|nt| match nt {
NodeText::Replaced => None,
NodeText::Original(c) => Some(c.to_string()),
})
}
NodeUnderCursor::CustomNode { text, .. } => match text {
NodeText::Replaced => None,
NodeText::Original(c) => Some(c.to_string()),
})
},
}
}

fn gather_tree_context(&mut self) {
Expand Down Expand Up @@ -230,7 +315,7 @@ impl<'a> CompletionContext<'a> {

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node_kind == parent_node_kind {
self.node_under_cursor = Some(current_node);
self.node_under_cursor = Some(NodeUnderCursor::from(current_node));
return;
}

Expand Down Expand Up @@ -269,7 +354,7 @@ impl<'a> CompletionContext<'a> {

match current_node_kind {
"object_reference" | "field" => {
let content = self.get_ts_node_content(current_node);
let content = self.get_ts_node_content(&current_node);
if let Some(node_txt) = content {
match node_txt {
NodeText::Original(txt) => {
Expand Down Expand Up @@ -301,7 +386,7 @@ impl<'a> CompletionContext<'a> {

// We have arrived at the leaf node
if current_node.child_count() == 0 {
self.node_under_cursor = Some(current_node);
self.node_under_cursor = Some(NodeUnderCursor::from(current_node));
return;
}

Expand All @@ -314,11 +399,11 @@ impl<'a> CompletionContext<'a> {
node: tree_sitter::Node<'a>,
) -> Option<WrappingClause<'a>> {
if node.kind().starts_with("keyword_") {
if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt {
if let Some(txt) = self.get_ts_node_content(&node).and_then(|txt| match txt {
NodeText::Original(txt) => Some(txt),
NodeText::Replaced => None,
}) {
match txt {
match txt.as_str() {
"where" => return Some(WrappingClause::Where),
"update" => return Some(WrappingClause::Update),
"select" => return Some(WrappingClause::Select),
Expand Down Expand Up @@ -368,11 +453,14 @@ impl<'a> CompletionContext<'a> {
#[cfg(test)]
mod tests {
use crate::{
context::{CompletionContext, NodeText, WrappingClause},
NodeText,
context::{CompletionContext, WrappingClause},
sanitization::SanitizedCompletionParams,
test_helper::{CURSOR_POS, get_text_and_position},
};

use super::NodeUnderCursor;

fn get_tree(input: &str) -> tree_sitter::Tree {
let mut parser = tree_sitter::Parser::new();
parser
Expand Down Expand Up @@ -531,17 +619,22 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.node_under_cursor.unwrap();
let node = ctx.node_under_cursor.as_ref().unwrap();

assert_eq!(
ctx.get_ts_node_content(node),
Some(NodeText::Original("select"))
);
match node {
NodeUnderCursor::TsNode(node) => {
assert_eq!(
ctx.get_ts_node_content(node),
Some(NodeText::Original("select".into()))
);

assert_eq!(
ctx.wrapping_clause_type,
Some(crate::context::WrappingClause::Select)
);
assert_eq!(
ctx.wrapping_clause_type,
Some(crate::context::WrappingClause::Select)
);
}
_ => unreachable!(),
}
}
}

Expand All @@ -562,12 +655,17 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.node_under_cursor.unwrap();
let node = ctx.node_under_cursor.as_ref().unwrap();

assert_eq!(
ctx.get_ts_node_content(node),
Some(NodeText::Original("from"))
);
match node {
NodeUnderCursor::TsNode(node) => {
assert_eq!(
ctx.get_ts_node_content(node),
Some(NodeText::Original("from".into()))
);
}
_ => unreachable!(),
}
}

#[test]
Expand All @@ -587,10 +685,18 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.node_under_cursor.unwrap();
let node = ctx.node_under_cursor.as_ref().unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some(NodeText::Original("")));
assert_eq!(ctx.wrapping_clause_type, None);
match node {
NodeUnderCursor::TsNode(node) => {
assert_eq!(
ctx.get_ts_node_content(node),
Some(NodeText::Original("".into()))
);
assert_eq!(ctx.wrapping_clause_type, None);
}
_ => unreachable!(),
}
}

#[test]
Expand All @@ -612,12 +718,17 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.node_under_cursor.unwrap();
let node = ctx.node_under_cursor.as_ref().unwrap();

assert_eq!(
ctx.get_ts_node_content(node),
Some(NodeText::Original("fro"))
);
assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select));
match node {
NodeUnderCursor::TsNode(node) => {
assert_eq!(
ctx.get_ts_node_content(node),
Some(NodeText::Original("fro".into()))
);
assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select));
}
_ => unreachable!(),
}
}
}
Loading
Loading