@@ -41,14 +41,15 @@ impl<'a> StatementSplitter<'a> {
41
41
match self . parser . nth ( 0 , false ) . kind {
42
42
SyntaxKind :: Ascii40 => {
43
43
// "("
44
- self . sub_trx_depth += 1 ;
44
+ self . sub_stmt_depth += 1 ;
45
45
}
46
46
SyntaxKind :: Ascii41 => {
47
47
// ")"
48
- self . sub_trx_depth -= 1 ;
48
+ self . sub_stmt_depth -= 1 ;
49
49
}
50
50
SyntaxKind :: Atomic => {
51
- if self . parser . lookbehind ( 2 , true ) . map ( |t| t. kind ) == Some ( SyntaxKind :: BeginP ) {
51
+ if self . parser . lookbehind ( 2 , true , None ) . map ( |t| t. kind ) == Some ( SyntaxKind :: BeginP )
52
+ {
52
53
self . is_within_atomic_block = true ;
53
54
}
54
55
}
@@ -88,7 +89,7 @@ impl<'a> StatementSplitter<'a> {
88
89
kind : SyntaxKind :: Any ,
89
90
range : TextRange :: new (
90
91
self . token_range ( started_at. unwrap ( ) ) . start ( ) ,
91
- self . parser . lookbehind ( 2 , true ) . unwrap ( ) . span . end ( ) ,
92
+ self . parser . lookbehind ( 2 , true , None ) . unwrap ( ) . span . end ( ) ,
92
93
) ,
93
94
} ) ;
94
95
}
@@ -115,23 +116,23 @@ impl<'a> StatementSplitter<'a> {
115
116
let new_stmts = STATEMENT_DEFINITIONS . get ( & self . parser . nth ( 0 , false ) . kind ) ;
116
117
117
118
if let Some ( new_stmts) = new_stmts {
118
- self . tracked_statements . append (
119
- & mut new_stmts
120
- . iter ( )
121
- . filter_map ( |stmt| {
122
- if self . active_bridges . iter ( ) . any ( |b| b . def . stmt == stmt . stmt ) {
123
- None
124
- } else if self . tracked_statements . iter ( ) . any ( |s| {
125
- s . could_be_complete ( )
126
- && s . def . prohibited_following_statements . contains ( & stmt. stmt )
127
- } ) {
128
- None
129
- } else {
130
- Some ( Tracker :: new_at ( stmt, self . parser . pos ) )
131
- }
132
- } )
133
- . collect ( ) ,
134
- ) ;
119
+ let to_add = & mut new_stmts
120
+ . iter ( )
121
+ . filter_map ( |stmt| {
122
+ if self . active_bridges . iter ( ) . any ( |b| b . def . stmt == stmt . stmt ) {
123
+ None
124
+ } else if self
125
+ . tracked_statements
126
+ . iter_mut ( )
127
+ . any ( |s| !s . can_start_stmt_after ( & stmt. stmt ) )
128
+ {
129
+ None
130
+ } else {
131
+ Some ( Tracker :: new_at ( stmt, self . parser . pos ) )
132
+ }
133
+ } )
134
+ . collect ( ) ;
135
+ self . tracked_statements . append ( to_add ) ;
135
136
}
136
137
}
137
138
@@ -283,7 +284,11 @@ impl<'a> StatementSplitter<'a> {
283
284
284
285
pub fn run ( mut self ) -> Vec < StatementPosition > {
285
286
while !self . parser . eof ( ) {
286
- println ! ( "{:?}" , self . parser. nth( 0 , false ) . kind) ;
287
+ println ! (
288
+ "#{:?}: {:?}" ,
289
+ self . parser. pos,
290
+ self . parser. nth( 0 , false ) . kind
291
+ ) ;
287
292
println ! (
288
293
"tracked stmts before {:?}" ,
289
294
self . tracked_statements
@@ -342,13 +347,13 @@ impl<'a> StatementSplitter<'a> {
342
347
343
348
// the end position is the end() of the last non-whitespace token before the start
344
349
// of the latest complete statement
345
- let latest_non_whitespace_token = self
346
- . parser
347
- . lookbehind ( self . parser . pos - latest_completed_stmt_started_at + 1 , true ) ;
350
+ let latest_non_whitespace_token = self . parser . lookbehind (
351
+ 2 ,
352
+ true ,
353
+ Some ( self . parser . pos - latest_completed_stmt_started_at) ,
354
+ ) ;
348
355
let end_pos = latest_non_whitespace_token. unwrap ( ) . span . end ( ) ;
349
356
350
- println ! ( "adding stmt: {:?}" , stmt_kind) ;
351
-
352
357
self . ranges . push ( StatementPosition {
353
358
kind : stmt_kind,
354
359
range : TextRange :: new ( start_pos, end_pos) ,
@@ -365,13 +370,6 @@ impl<'a> StatementSplitter<'a> {
365
370
366
371
// we reached eof; add any remaining statements
367
372
368
- println ! (
369
- "tracked stmts after eof {:?}" ,
370
- self . tracked_statements
371
- . iter( )
372
- . map( |s| s. def. stmt)
373
- . collect:: <Vec <_>>( )
374
- ) ;
375
373
// get the earliest statement that is complete
376
374
if let Some ( earliest_complete_stmt_started_at) =
377
375
self . find_earliest_complete_statement_start_pos ( )
@@ -387,7 +385,7 @@ impl<'a> StatementSplitter<'a> {
387
385
388
386
let start_pos = self . token_range ( earliest_complete_stmt_started_at) . start ( ) ;
389
387
390
- let end_token = self . parser . lookbehind ( 1 , true ) . unwrap ( ) ;
388
+ let end_token = self . parser . lookbehind ( 1 , true , None ) . unwrap ( ) ;
391
389
let end_pos = end_token. span . end ( ) ;
392
390
393
391
println ! ( "adding stmt at end: {:?}" , earliest_complete_stmt. def. stmt) ;
@@ -406,7 +404,7 @@ impl<'a> StatementSplitter<'a> {
406
404
let start_pos = self . token_range ( earliest_stmt_started_at) . start ( ) ;
407
405
408
406
// end position is last non-whitespace token before or at the current position
409
- let end_pos = self . parser . lookbehind ( 1 , true ) . unwrap ( ) . span . end ( ) ;
407
+ let end_pos = self . parser . lookbehind ( 1 , true , None ) . unwrap ( ) . span . end ( ) ;
410
408
411
409
println ! ( "adding any stmt at end" ) ;
412
410
self . ranges . push ( StatementPosition {
@@ -425,6 +423,37 @@ mod tests {
425
423
426
424
use crate :: statement_splitter:: StatementSplitter ;
427
425
426
+ #[ test]
427
+ fn test_simple_select ( ) {
428
+ let input = "
429
+ select id, name, test1231234123, unknown from co;
430
+
431
+ select 14433313331333
432
+
433
+ alter table test drop column id;
434
+
435
+ select lower('test');
436
+ " ;
437
+
438
+ let result = StatementSplitter :: new ( input) . run ( ) ;
439
+
440
+ assert_eq ! ( result. len( ) , 4 ) ;
441
+ assert_eq ! (
442
+ "select id, name, test1231234123, unknown from co;" ,
443
+ input[ result[ 0 ] . range] . to_string( )
444
+ ) ;
445
+ assert_eq ! ( SyntaxKind :: SelectStmt , result[ 0 ] . kind) ;
446
+ assert_eq ! ( "select 14433313331333" , input[ result[ 1 ] . range] . to_string( ) ) ;
447
+ assert_eq ! ( SyntaxKind :: SelectStmt , result[ 1 ] . kind) ;
448
+ assert_eq ! ( SyntaxKind :: AlterTableStmt , result[ 2 ] . kind) ;
449
+ assert_eq ! (
450
+ "alter table test drop column id;" ,
451
+ input[ result[ 2 ] . range] . to_string( )
452
+ ) ;
453
+ assert_eq ! ( SyntaxKind :: SelectStmt , result[ 3 ] . kind) ;
454
+ assert_eq ! ( "select lower('test');" , input[ result[ 3 ] . range] . to_string( ) ) ;
455
+ }
456
+
428
457
#[ test]
429
458
fn test_create_or_replace ( ) {
430
459
let input = "CREATE OR REPLACE TRIGGER check_update
@@ -586,19 +615,19 @@ mod tests {
586
615
587
616
#[ test]
588
617
fn test_explain_analyze ( ) {
589
- let input = "explain analyze select 1 from contact\n select 1\n select 4" ;
618
+ let input = "explain analyze select 1 from contact; \n select 1; \n select 4; " ;
590
619
591
620
let result = StatementSplitter :: new ( input) . run ( ) ;
592
621
593
622
assert_eq ! ( result. len( ) , 3 ) ;
594
623
assert_eq ! (
595
- "explain analyze select 1 from contact" ,
624
+ "explain analyze select 1 from contact; " ,
596
625
input[ result[ 0 ] . range] . to_string( )
597
626
) ;
598
627
assert_eq ! ( SyntaxKind :: ExplainStmt , result[ 0 ] . kind) ;
599
- assert_eq ! ( "select 1" , input[ result[ 1 ] . range] . to_string( ) ) ;
628
+ assert_eq ! ( "select 1; " , input[ result[ 1 ] . range] . to_string( ) ) ;
600
629
assert_eq ! ( SyntaxKind :: SelectStmt , result[ 1 ] . kind) ;
601
- assert_eq ! ( "select 4" , input[ result[ 2 ] . range] . to_string( ) ) ;
630
+ assert_eq ! ( "select 4; " , input[ result[ 2 ] . range] . to_string( ) ) ;
602
631
assert_eq ! ( SyntaxKind :: SelectStmt , result[ 2 ] . kind) ;
603
632
}
604
633
@@ -694,10 +723,6 @@ DROP ROLE IF EXISTS regress_alter_generic_user1;";
694
723
695
724
let result = StatementSplitter :: new ( input) . run ( ) ;
696
725
697
- for r in & result {
698
- println ! ( "{:?} {:?}" , r. kind, input[ r. range] . to_string( ) ) ;
699
- }
700
-
701
726
assert_eq ! ( result. len( ) , 2 ) ;
702
727
assert_eq ! ( "create" , input[ result[ 0 ] . range] . to_string( ) ) ;
703
728
assert_eq ! ( SyntaxKind :: Any , result[ 0 ] . kind) ;
@@ -1008,7 +1033,7 @@ ALTER OPERATOR FAMILY alt_nsp6.alt_opf6 USING btree ADD OPERATOR 1 < (int4, int2
1008
1033
fn test_alter_op_family_2 ( ) {
1009
1034
let input = "
1010
1035
CREATE OPERATOR FAMILY alt_opf4 USING btree;
1011
- ALTER OPERATOR FAMILY schema .alt_opf4 USING btree ADD
1036
+ ALTER OPERATOR FAMILY test .alt_opf4 USING btree ADD
1012
1037
-- int4 vs int2
1013
1038
OPERATOR 1 < (int4, int2) ,
1014
1039
OPERATOR 2 <= (int4, int2) ,
0 commit comments