Skip to content

feat(completions): ts_query package, column autocompletion #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pg_schema_cache = { path = "./crates/pg_schema_cache", version = "0.0.
pg_statement_splitter = { path = "./crates/pg_statement_splitter", version = "0.0.0" }
pg_syntax = { path = "./crates/pg_syntax", version = "0.0.0" }
pg_text_edit = { path = "./crates/pg_text_edit", version = "0.0.0" }
pg_treesitter_queries = { path = "./crates/pg_treesitter_queries", version = "0.0.0" }
pg_type_resolver = { path = "./crates/pg_type_resolver", version = "0.0.0" }
pg_typecheck = { path = "./crates/pg_typecheck", version = "0.0.0" }
pg_workspace = { path = "./crates/pg_workspace", version = "0.0.0" }
Expand Down
1 change: 1 addition & 0 deletions crates/pg_completions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ serde_json = { workspace = true }
pg_schema_cache.workspace = true
tree-sitter.workspace = true
tree_sitter_sql.workspace = true
pg_treesitter_queries.workspace = true

sqlx.workspace = true

Expand Down
5 changes: 3 additions & 2 deletions crates/pg_completions/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
builder::CompletionBuilder,
context::CompletionContext,
item::CompletionItem,
providers::{complete_functions, complete_tables},
providers::{complete_columns, complete_functions, complete_tables},
};

pub const LIMIT: usize = 50;
Expand All @@ -31,13 +31,14 @@ impl IntoIterator for CompletionResult {
}
}

pub fn complete(params: CompletionParams) -> CompletionResult {
pub fn complete<'a>(params: CompletionParams<'a>) -> CompletionResult {
let ctx = CompletionContext::new(&params);

let mut builder = CompletionBuilder::new();

complete_tables(&ctx, &mut builder);
complete_functions(&ctx, &mut builder);
complete_columns(&ctx, &mut builder);

builder.finish()
}
74 changes: 64 additions & 10 deletions crates/pg_completions/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
use std::{
collections::{HashMap, HashSet},
hash::Hash,
ops::Range,
};

use pg_schema_cache::SchemaCache;
use pg_treesitter_queries::{
queries::{self, QueryResult},
TreeSitterQueriesExecutor,
};

use crate::CompletionParams;

Expand Down Expand Up @@ -52,10 +62,13 @@ pub(crate) struct CompletionContext<'a> {
pub schema_name: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,
pub is_invocation: bool,
pub wrapping_statement_range: Option<Range<usize>>,

pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
}

impl<'a> CompletionContext<'a> {
pub fn new(params: &'a CompletionParams) -> Self {
pub fn new(params: &'a CompletionParams<'a>) -> Self {
let mut ctx = Self {
tree: params.tree,
text: &params.text,
Expand All @@ -65,14 +78,53 @@ impl<'a> CompletionContext<'a> {
ts_node: None,
schema_name: None,
wrapping_clause_type: None,
wrapping_statement_range: None,
is_invocation: false,
mentioned_relations: HashMap::new(),
};

ctx.gather_tree_context();
ctx.gather_info_from_ts_queries();

ctx
}

fn gather_info_from_ts_queries(&mut self) {
let tree = match self.tree.as_ref() {
None => return,
Some(t) => t,
};

let stmt_range = self.wrapping_statement_range.as_ref();
let sql = self.text;

let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), self.text);

executor.add_query_results::<queries::RelationMatch>();

for relation_match in executor.get_iter(stmt_range) {
match relation_match {
QueryResult::Relation(r) => {
let schema_name = r.get_schema(sql);
let table_name = r.get_table(sql);

let current = self.mentioned_relations.get_mut(&schema_name);

match current {
Some(c) => {
c.insert(table_name);
}
None => {
let mut new = HashSet::new();
new.insert(table_name);
self.mentioned_relations.insert(schema_name, new);
}
};
}
};
}
}

pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> {
let source = self.text;
match ts_node.utf8_text(source.as_bytes()) {
Expand Down Expand Up @@ -100,36 +152,38 @@ impl<'a> CompletionContext<'a> {
* We'll therefore adjust the cursor position such that it meets the last node of the AST.
* `select * from use {}` becomes `select * from use{}`.
*/
let current_node_kind = cursor.node().kind();
let current_node = cursor.node();
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
self.position -= 1;
}

self.gather_context_from_node(cursor, current_node_kind);
self.gather_context_from_node(cursor, current_node);
}

fn gather_context_from_node(
&mut self,
mut cursor: tree_sitter::TreeCursor<'a>,
previous_node_kind: &str,
previous_node: tree_sitter::Node<'a>,
) {
let current_node = cursor.node();
let current_node_kind = current_node.kind();

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node_kind == previous_node_kind {
if current_node.kind() == previous_node.kind() {
self.ts_node = Some(current_node);
return;
}

match previous_node_kind {
"statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(),
match previous_node.kind() {
"statement" => {
self.wrapping_clause_type = current_node.kind().try_into().ok();
self.wrapping_statement_range = Some(previous_node.byte_range());
}
"invocation" => self.is_invocation = true,

_ => {}
}

match current_node_kind {
match current_node.kind() {
"object_reference" => {
let txt = self.get_ts_node_content(current_node);
if let Some(txt) = txt {
Expand Down Expand Up @@ -159,7 +213,7 @@ impl<'a> CompletionContext<'a> {
}

cursor.goto_first_child_for_byte(self.position);
self.gather_context_from_node(cursor, current_node_kind);
self.gather_context_from_node(cursor, current_node);
}
}

Expand Down
20 changes: 20 additions & 0 deletions crates/pg_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use crate::{
builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData,
CompletionItem, CompletionItemKind,
};

pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder) {
let available_columns = &ctx.schema_cache.columns;

for col in available_columns {
let item = CompletionItem {
label: col.name.clone(),
score: CompletionRelevanceData::Column(col).get_score(ctx),
description: format!("Table: {}.{}", col.schema_name, col.table_name),
preselected: false,
kind: CompletionItemKind::Function,
};

builder.add_item(item);
}
}
2 changes: 2 additions & 0 deletions crates/pg_completions/src/providers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
mod columns;
mod functions;
mod tables;

pub use columns::*;
pub use functions::*;
pub use tables::*;
57 changes: 53 additions & 4 deletions crates/pg_completions/src/relevance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::context::{ClauseType, CompletionContext};
pub(crate) enum CompletionRelevanceData<'a> {
Table(&'a pg_schema_cache::Table),
Function(&'a pg_schema_cache::Function),
Column(&'a pg_schema_cache::Column),
}

impl<'a> CompletionRelevanceData<'a> {
Expand Down Expand Up @@ -34,6 +35,7 @@ impl<'a> CompletionRelevance<'a> {
self.check_if_catalog(ctx);
self.check_is_invocation(ctx);
self.check_matching_clause_type(ctx);
self.check_relations_in_stmt(ctx);

self.score
}
Expand All @@ -49,6 +51,7 @@ impl<'a> CompletionRelevance<'a> {
let name = match self.data {
CompletionRelevanceData::Function(f) => f.name.as_str(),
CompletionRelevanceData::Table(t) => t.name.as_str(),
CompletionRelevanceData::Column(c) => c.name.as_str(),
};

if name.starts_with(content) {
Expand Down Expand Up @@ -79,6 +82,11 @@ impl<'a> CompletionRelevance<'a> {
ClauseType::From => 0,
_ => -50,
},
CompletionRelevanceData::Column(_) => match clause_type {
ClauseType::Select => 15,
ClauseType::Where => 15,
_ => -15,
},
}
}

Expand Down Expand Up @@ -107,10 +115,7 @@ impl<'a> CompletionRelevance<'a> {
Some(n) => n,
};

let data_schema = match self.data {
CompletionRelevanceData::Function(f) => f.schema.as_str(),
CompletionRelevanceData::Table(t) => t.schema.as_str(),
};
let data_schema = self.get_schema_name();

if schema_name == data_schema {
self.score += 25;
Expand All @@ -119,11 +124,55 @@ impl<'a> CompletionRelevance<'a> {
}
}

fn get_schema_name(&self) -> &str {
match self.data {
CompletionRelevanceData::Function(f) => f.schema.as_str(),
CompletionRelevanceData::Table(t) => t.schema.as_str(),
CompletionRelevanceData::Column(c) => c.schema_name.as_str(),
}
}

fn get_table_name(&self) -> Option<&str> {
match self.data {
CompletionRelevanceData::Column(c) => Some(c.table_name.as_str()),
CompletionRelevanceData::Table(t) => Some(t.name.as_str()),
_ => None,
}
}

fn check_if_catalog(&mut self, ctx: &CompletionContext) {
if ctx.schema_name.as_ref().is_some_and(|n| n == "pg_catalog") {
return;
}

self.score -= 5; // unlikely that the user wants schema data
}

fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) {
match self.data {
CompletionRelevanceData::Table(_) => return,
CompletionRelevanceData::Function(_) => return,
_ => {}
}

let schema = self.get_schema_name().to_string();
let table_name = match self.get_table_name() {
Some(t) => t,
None => return,
};

if ctx
.mentioned_relations
.get(&Some(schema.to_string()))
.is_some_and(|tables| tables.contains(table_name))
{
self.score += 45;
} else if ctx
.mentioned_relations
.get(&None)
.is_some_and(|tables| tables.contains(table_name))
{
self.score += 30;
}
}
}
1 change: 1 addition & 0 deletions crates/pg_schema_cache/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod tables;
mod types;
mod versions;

pub use columns::*;
pub use functions::{Behavior, Function, FunctionArg, FunctionArgs};
pub use schema_cache::SchemaCache;
pub use tables::{ReplicaIdentity, Table};
4 changes: 4 additions & 0 deletions crates/pg_test_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ version = "0.0.0"
name = "tree_print"
path = "src/bin/tree_print.rs"

[[bin]]
name = "query_debug"
path = "src/bin/tree_query_debug.rs"

[dependencies]
anyhow = "1.0.81"
clap = { version = "4.5.23", features = ["derive"] }
Expand Down
Loading
Loading