1
- use std:: ops:: Range ;
1
+ use std:: {
2
+ collections:: { HashMap , HashSet } ,
3
+ hash:: Hash ,
4
+ ops:: Range ,
5
+ } ;
2
6
3
7
use pg_schema_cache:: SchemaCache ;
4
- use pg_treesitter_queries:: { queries, TreeSitterQueriesExecutor } ;
8
+ use pg_treesitter_queries:: {
9
+ queries:: { self , QueryResult } ,
10
+ TreeSitterQueriesExecutor ,
11
+ } ;
5
12
6
13
use crate :: CompletionParams ;
7
14
@@ -57,11 +64,11 @@ pub(crate) struct CompletionContext<'a> {
57
64
pub is_invocation : bool ,
58
65
pub wrapping_statement_range : Option < Range < usize > > ,
59
66
60
- pub ts_query_executor : Option < TreeSitterQueriesExecutor < ' a > > ,
67
+ pub mentioned_relations : HashMap < Option < String > , HashSet < String > > ,
61
68
}
62
69
63
70
impl < ' a > CompletionContext < ' a > {
64
- pub async fn new ( params : & ' a CompletionParams < ' a > ) -> Self {
71
+ pub fn new ( params : & ' a CompletionParams < ' a > ) -> Self {
65
72
let mut ctx = Self {
66
73
tree : params. tree ,
67
74
text : & params. text ,
@@ -73,26 +80,49 @@ impl<'a> CompletionContext<'a> {
73
80
wrapping_clause_type : None ,
74
81
wrapping_statement_range : None ,
75
82
is_invocation : false ,
76
- ts_query_executor : None ,
83
+ mentioned_relations : HashMap :: new ( ) ,
77
84
} ;
78
85
79
86
ctx. gather_tree_context ( ) ;
80
- ctx. dispatch_ts_queries ( ) . await ;
87
+ ctx. gather_info_from_ts_queries ( ) ;
81
88
82
89
ctx
83
90
}
84
91
85
- async fn dispatch_ts_queries ( & mut self ) {
92
+ fn gather_info_from_ts_queries ( & mut self ) {
86
93
let tree = match self . tree . as_ref ( ) {
87
94
None => return ,
88
95
Some ( t) => t,
89
96
} ;
90
97
91
- let mut executor = TreeSitterQueriesExecutor :: new ( tree. root_node ( ) , self . text ) ;
98
+ let stmt_range = self . wrapping_statement_range . as_ref ( ) ;
99
+ let sql = self . text ;
92
100
93
- executor. add_query_results :: < queries :: RelationMatch > ( ) . await ;
101
+ let mut executor = TreeSitterQueriesExecutor :: new ( tree . root_node ( ) , self . text ) ;
94
102
95
- self . ts_query_executor = Some ( executor) ;
103
+ executor. add_query_results :: < queries:: RelationMatch > ( ) ;
104
+
105
+ for relation_match in executor. get_iter ( stmt_range) {
106
+ match relation_match {
107
+ QueryResult :: Relation ( r) => {
108
+ let schema_name = r. get_schema ( sql) ;
109
+ let table_name = r. get_table ( sql) ;
110
+
111
+ let current = self . mentioned_relations . get_mut ( & schema_name) ;
112
+
113
+ match current {
114
+ Some ( c) => {
115
+ c. insert ( table_name) ;
116
+ }
117
+ None => {
118
+ let mut new = HashSet :: new ( ) ;
119
+ new. insert ( table_name) ;
120
+ self . mentioned_relations . insert ( schema_name, new) ;
121
+ }
122
+ } ;
123
+ }
124
+ } ;
125
+ }
96
126
}
97
127
98
128
pub fn get_ts_node_content ( & self , ts_node : tree_sitter:: Node < ' a > ) -> Option < & ' a str > {
@@ -203,8 +233,8 @@ mod tests {
203
233
parser. parse ( input, None ) . expect ( "Unable to parse tree" )
204
234
}
205
235
206
- #[ tokio :: test]
207
- async fn identifies_clauses ( ) {
236
+ #[ test]
237
+ fn identifies_clauses ( ) {
208
238
let test_cases = vec ! [
209
239
( format!( "Select {}* from users;" , CURSOR_POS ) , "select" ) ,
210
240
( format!( "Select * from u{};" , CURSOR_POS ) , "from" ) ,
@@ -244,14 +274,14 @@ mod tests {
244
274
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
245
275
} ;
246
276
247
- let ctx = CompletionContext :: new ( & params) . await ;
277
+ let ctx = CompletionContext :: new ( & params) ;
248
278
249
279
assert_eq ! ( ctx. wrapping_clause_type, expected_clause. try_into( ) . ok( ) ) ;
250
280
}
251
281
}
252
282
253
- #[ tokio :: test]
254
- async fn identifies_schema ( ) {
283
+ #[ test]
284
+ fn identifies_schema ( ) {
255
285
let test_cases = vec ! [
256
286
(
257
287
format!( "Select * from private.u{}" , CURSOR_POS ) ,
@@ -276,14 +306,14 @@ mod tests {
276
306
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
277
307
} ;
278
308
279
- let ctx = CompletionContext :: new ( & params) . await ;
309
+ let ctx = CompletionContext :: new ( & params) ;
280
310
281
311
assert_eq ! ( ctx. schema_name, expected_schema. map( |f| f. to_string( ) ) ) ;
282
312
}
283
313
}
284
314
285
- #[ tokio :: test]
286
- async fn identifies_invocation ( ) {
315
+ #[ test]
316
+ fn identifies_invocation ( ) {
287
317
let test_cases = vec ! [
288
318
( format!( "Select * from u{}sers" , CURSOR_POS ) , false ) ,
289
319
( format!( "Select * from u{}sers()" , CURSOR_POS ) , true ) ,
@@ -310,14 +340,14 @@ mod tests {
310
340
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
311
341
} ;
312
342
313
- let ctx = CompletionContext :: new ( & params) . await ;
343
+ let ctx = CompletionContext :: new ( & params) ;
314
344
315
345
assert_eq ! ( ctx. is_invocation, is_invocation) ;
316
346
}
317
347
}
318
348
319
- #[ tokio :: test]
320
- async fn does_not_fail_on_leading_whitespace ( ) {
349
+ #[ test]
350
+ fn does_not_fail_on_leading_whitespace ( ) {
321
351
let cases = vec ! [
322
352
format!( "{} select * from" , CURSOR_POS ) ,
323
353
format!( " {} select * from" , CURSOR_POS ) ,
@@ -335,7 +365,7 @@ mod tests {
335
365
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
336
366
} ;
337
367
338
- let ctx = CompletionContext :: new ( & params) . await ;
368
+ let ctx = CompletionContext :: new ( & params) ;
339
369
340
370
let node = ctx. ts_node . unwrap ( ) ;
341
371
@@ -348,8 +378,8 @@ mod tests {
348
378
}
349
379
}
350
380
351
- #[ tokio :: test]
352
- async fn does_not_fail_on_trailing_whitespace ( ) {
381
+ #[ test]
382
+ fn does_not_fail_on_trailing_whitespace ( ) {
353
383
let query = format ! ( "select * from {}" , CURSOR_POS ) ;
354
384
355
385
let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
@@ -363,7 +393,7 @@ mod tests {
363
393
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
364
394
} ;
365
395
366
- let ctx = CompletionContext :: new ( & params) . await ;
396
+ let ctx = CompletionContext :: new ( & params) ;
367
397
368
398
let node = ctx. ts_node . unwrap ( ) ;
369
399
@@ -374,8 +404,8 @@ mod tests {
374
404
) ;
375
405
}
376
406
377
- #[ tokio :: test]
378
- async fn does_not_fail_with_empty_statements ( ) {
407
+ #[ test]
408
+ fn does_not_fail_with_empty_statements ( ) {
379
409
let query = format ! ( "{}" , CURSOR_POS ) ;
380
410
381
411
let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
@@ -389,16 +419,16 @@ mod tests {
389
419
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
390
420
} ;
391
421
392
- let ctx = CompletionContext :: new ( & params) . await ;
422
+ let ctx = CompletionContext :: new ( & params) ;
393
423
394
424
let node = ctx. ts_node . unwrap ( ) ;
395
425
396
426
assert_eq ! ( ctx. get_ts_node_content( node) , Some ( "" ) ) ;
397
427
assert_eq ! ( ctx. wrapping_clause_type, None ) ;
398
428
}
399
429
400
- #[ tokio :: test]
401
- async fn does_not_fail_on_incomplete_keywords ( ) {
430
+ #[ test]
431
+ fn does_not_fail_on_incomplete_keywords ( ) {
402
432
// Instead of autocompleting "FROM", we'll assume that the user
403
433
// is selecting a certain column name, such as `frozen_account`.
404
434
let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
@@ -414,7 +444,7 @@ mod tests {
414
444
schema : & pg_schema_cache:: SchemaCache :: new ( ) ,
415
445
} ;
416
446
417
- let ctx = CompletionContext :: new ( & params) . await ;
447
+ let ctx = CompletionContext :: new ( & params) ;
418
448
419
449
let node = ctx. ts_node . unwrap ( ) ;
420
450
0 commit comments