Skip to content

Commit 9c2ba7c

Browse files
committed
fix: minor fixes
1 parent 399c0d2 commit 9c2ba7c

File tree

4 files changed

+162
-121
lines changed

4 files changed

+162
-121
lines changed

crates/pg_statement_splitter/src/data.rs

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,6 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
422422
SyntaxBuilder::new()
423423
.required_token(SyntaxKind::Drop)
424424
.one_of(vec![SyntaxKind::Rule, SyntaxKind::Trigger])
425-
.required_token(SyntaxKind::Trigger)
426425
.optional_if_exists_group()
427426
.required_token(SyntaxKind::Ident)
428427
.required_token(SyntaxKind::On)
@@ -501,7 +500,7 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
501500
.required_token(SyntaxKind::Operator)
502501
.optional_if_exists_group()
503502
.optional_schema_name_group()
504-
.one_of(vec![SyntaxKind::Ident, SyntaxKind::Operator])
503+
.one_of(vec![SyntaxKind::Ident, SyntaxKind::Op])
505504
.required_token(SyntaxKind::Ascii40)
506505
.any_tokens(None)
507506
.required_token(SyntaxKind::Ascii41),
@@ -658,18 +657,21 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
658657
.required_token(SyntaxKind::Ident),
659658
));
660659

661-
m.push(StatementDefinition::new(
662-
SyntaxKind::CreateFunctionStmt,
663-
SyntaxBuilder::new()
664-
.required_token(SyntaxKind::Create)
665-
.optional_token(SyntaxKind::Or)
666-
.optional_token(SyntaxKind::Replace)
667-
.one_of(vec![SyntaxKind::Function, SyntaxKind::Procedure])
668-
.any_tokens(None)
669-
.required_token(SyntaxKind::Ascii40)
670-
.any_tokens(None)
671-
.required_token(SyntaxKind::Ascii41),
672-
));
660+
m.push(
661+
StatementDefinition::new(
662+
SyntaxKind::CreateFunctionStmt,
663+
SyntaxBuilder::new()
664+
.required_token(SyntaxKind::Create)
665+
.optional_token(SyntaxKind::Or)
666+
.optional_token(SyntaxKind::Replace)
667+
.one_of(vec![SyntaxKind::Function, SyntaxKind::Procedure])
668+
.any_tokens(None)
669+
.required_token(SyntaxKind::Ascii40)
670+
.any_tokens(None)
671+
.required_token(SyntaxKind::Ascii41),
672+
)
673+
.with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]),
674+
);
673675

674676
m.push(StatementDefinition::new(
675677
SyntaxKind::AlterFunctionStmt,
@@ -768,12 +770,15 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
768770
.any_token(),
769771
));
770772

771-
m.push(StatementDefinition::new(
772-
SyntaxKind::TransactionStmt,
773-
SyntaxBuilder::new()
774-
.required_token(SyntaxKind::Savepoint)
775-
.required_token(SyntaxKind::Ident),
776-
));
773+
m.push(
774+
StatementDefinition::new(
775+
SyntaxKind::TransactionStmt,
776+
SyntaxBuilder::new()
777+
.required_token(SyntaxKind::Savepoint)
778+
.required_token(SyntaxKind::Ident),
779+
)
780+
.with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]),
781+
);
777782

778783
m.push(StatementDefinition::new(
779784
SyntaxKind::TransactionStmt,
@@ -792,21 +797,26 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
792797
SyntaxBuilder::new().required_token(SyntaxKind::Commit),
793798
));
794799

795-
m.push(StatementDefinition::new(
796-
SyntaxKind::TransactionStmt,
797-
SyntaxBuilder::new()
798-
.required_token(SyntaxKind::Rollback)
799-
.any_tokens(None)
800-
.required_token(SyntaxKind::To)
801-
.optional_token(SyntaxKind::Savepoint)
802-
.required_token(SyntaxKind::Ident),
803-
));
800+
m.push(
801+
StatementDefinition::new(
802+
SyntaxKind::TransactionStmt,
803+
SyntaxBuilder::new()
804+
.required_token(SyntaxKind::Rollback)
805+
.any_tokens(None)
806+
.required_token(SyntaxKind::To)
807+
.optional_token(SyntaxKind::Savepoint)
808+
.required_token(SyntaxKind::Ident),
809+
)
810+
.with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]),
811+
);
804812

805-
m.push(StatementDefinition::new(
806-
SyntaxKind::TransactionStmt,
807-
// FIXME: conflicts with ROLLBACK TO SAVEPOINT?
808-
SyntaxBuilder::new().required_token(SyntaxKind::Rollback),
809-
));
813+
m.push(
814+
StatementDefinition::new(
815+
SyntaxKind::TransactionStmt,
816+
SyntaxBuilder::new().required_token(SyntaxKind::Rollback),
817+
)
818+
.with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]),
819+
);
810820

811821
m.push(
812822
StatementDefinition::new(
@@ -926,6 +936,7 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
926936
SyntaxBuilder::new().required_token(SyntaxKind::Explain),
927937
)
928938
.with_prohibited_following_statements(vec![
939+
SyntaxKind::VacuumStmt,
929940
SyntaxKind::SelectStmt,
930941
SyntaxKind::InsertStmt,
931942
SyntaxKind::DeleteStmt,
@@ -1175,13 +1186,7 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
11751186
.required_token(SyntaxKind::Ident)
11761187
.required_token(SyntaxKind::Using)
11771188
.required_token(SyntaxKind::Ident)
1178-
.one_of(vec![
1179-
SyntaxKind::Drop,
1180-
SyntaxKind::AddP,
1181-
SyntaxKind::Rename,
1182-
SyntaxKind::Owner,
1183-
SyntaxKind::Set,
1184-
]),
1189+
.one_of(vec![SyntaxKind::Drop, SyntaxKind::AddP, SyntaxKind::Rename]),
11851190
));
11861191

11871192
m.push(

crates/pg_statement_splitter/src/parser.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,20 @@ impl Parser {
130130
/// lookbehind method.
131131
///
132132
/// if `ignore_whitespace` is true, it will skip all whitespace tokens
133-
pub fn lookbehind(&self, lookbehind: usize, ignore_whitespace: bool) -> Option<&Token> {
133+
pub fn lookbehind(
134+
&self,
135+
lookbehind: usize,
136+
ignore_whitespace: bool,
137+
start_before: Option<usize>,
138+
) -> Option<&Token> {
134139
if ignore_whitespace {
135140
let mut idx = 0;
136141
let mut non_whitespace_token_ctr = 0;
137142
loop {
138143
if idx > self.pos {
139144
return None;
140145
}
141-
match self.tokens.get(self.pos - idx) {
146+
match self.tokens.get(self.pos - start_before.unwrap_or(0) - idx) {
142147
Some(token) => {
143148
if !WHITESPACE_TOKENS.contains(&token.kind) {
144149
non_whitespace_token_ctr += 1;
@@ -149,7 +154,7 @@ impl Parser {
149154
idx += 1;
150155
}
151156
None => {
152-
if (self.pos - idx) > 0 {
157+
if (self.pos - idx - start_before.unwrap_or(0)) > 0 {
153158
idx += 1;
154159
} else {
155160
return None;
@@ -158,7 +163,8 @@ impl Parser {
158163
}
159164
}
160165
} else {
161-
self.tokens.get(self.pos - lookbehind)
166+
self.tokens
167+
.get(self.pos - lookbehind - start_before.unwrap_or(0))
162168
}
163169
}
164170

crates/pg_statement_splitter/src/statement_splitter.rs

Lines changed: 70 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ impl<'a> StatementSplitter<'a> {
4141
match self.parser.nth(0, false).kind {
4242
SyntaxKind::Ascii40 => {
4343
// "("
44-
self.sub_trx_depth += 1;
44+
self.sub_stmt_depth += 1;
4545
}
4646
SyntaxKind::Ascii41 => {
4747
// ")"
48-
self.sub_trx_depth -= 1;
48+
self.sub_stmt_depth -= 1;
4949
}
5050
SyntaxKind::Atomic => {
51-
if self.parser.lookbehind(2, true).map(|t| t.kind) == Some(SyntaxKind::BeginP) {
51+
if self.parser.lookbehind(2, true, None).map(|t| t.kind) == Some(SyntaxKind::BeginP)
52+
{
5253
self.is_within_atomic_block = true;
5354
}
5455
}
@@ -88,7 +89,7 @@ impl<'a> StatementSplitter<'a> {
8889
kind: SyntaxKind::Any,
8990
range: TextRange::new(
9091
self.token_range(started_at.unwrap()).start(),
91-
self.parser.lookbehind(2, true).unwrap().span.end(),
92+
self.parser.lookbehind(2, true, None).unwrap().span.end(),
9293
),
9394
});
9495
}
@@ -115,23 +116,23 @@ impl<'a> StatementSplitter<'a> {
115116
let new_stmts = STATEMENT_DEFINITIONS.get(&self.parser.nth(0, false).kind);
116117

117118
if let Some(new_stmts) = new_stmts {
118-
self.tracked_statements.append(
119-
&mut new_stmts
120-
.iter()
121-
.filter_map(|stmt| {
122-
if self.active_bridges.iter().any(|b| b.def.stmt == stmt.stmt) {
123-
None
124-
} else if self.tracked_statements.iter().any(|s| {
125-
s.could_be_complete()
126-
&& s.def.prohibited_following_statements.contains(&stmt.stmt)
127-
}) {
128-
None
129-
} else {
130-
Some(Tracker::new_at(stmt, self.parser.pos))
131-
}
132-
})
133-
.collect(),
134-
);
119+
let to_add = &mut new_stmts
120+
.iter()
121+
.filter_map(|stmt| {
122+
if self.active_bridges.iter().any(|b| b.def.stmt == stmt.stmt) {
123+
None
124+
} else if self
125+
.tracked_statements
126+
.iter_mut()
127+
.any(|s| !s.can_start_stmt_after(&stmt.stmt))
128+
{
129+
None
130+
} else {
131+
Some(Tracker::new_at(stmt, self.parser.pos))
132+
}
133+
})
134+
.collect();
135+
self.tracked_statements.append(to_add);
135136
}
136137
}
137138

@@ -283,7 +284,11 @@ impl<'a> StatementSplitter<'a> {
283284

284285
pub fn run(mut self) -> Vec<StatementPosition> {
285286
while !self.parser.eof() {
286-
println!("{:?}", self.parser.nth(0, false).kind);
287+
println!(
288+
"#{:?}: {:?}",
289+
self.parser.pos,
290+
self.parser.nth(0, false).kind
291+
);
287292
println!(
288293
"tracked stmts before {:?}",
289294
self.tracked_statements
@@ -342,13 +347,13 @@ impl<'a> StatementSplitter<'a> {
342347

343348
// the end position is the end() of the last non-whitespace token before the start
344349
// of the latest complete statement
345-
let latest_non_whitespace_token = self
346-
.parser
347-
.lookbehind(self.parser.pos - latest_completed_stmt_started_at + 1, true);
350+
let latest_non_whitespace_token = self.parser.lookbehind(
351+
2,
352+
true,
353+
Some(self.parser.pos - latest_completed_stmt_started_at),
354+
);
348355
let end_pos = latest_non_whitespace_token.unwrap().span.end();
349356

350-
println!("adding stmt: {:?}", stmt_kind);
351-
352357
self.ranges.push(StatementPosition {
353358
kind: stmt_kind,
354359
range: TextRange::new(start_pos, end_pos),
@@ -365,13 +370,6 @@ impl<'a> StatementSplitter<'a> {
365370

366371
// we reached eof; add any remaining statements
367372

368-
println!(
369-
"tracked stmts after eof {:?}",
370-
self.tracked_statements
371-
.iter()
372-
.map(|s| s.def.stmt)
373-
.collect::<Vec<_>>()
374-
);
375373
// get the earliest statement that is complete
376374
if let Some(earliest_complete_stmt_started_at) =
377375
self.find_earliest_complete_statement_start_pos()
@@ -387,7 +385,7 @@ impl<'a> StatementSplitter<'a> {
387385

388386
let start_pos = self.token_range(earliest_complete_stmt_started_at).start();
389387

390-
let end_token = self.parser.lookbehind(1, true).unwrap();
388+
let end_token = self.parser.lookbehind(1, true, None).unwrap();
391389
let end_pos = end_token.span.end();
392390

393391
println!("adding stmt at end: {:?}", earliest_complete_stmt.def.stmt);
@@ -406,7 +404,7 @@ impl<'a> StatementSplitter<'a> {
406404
let start_pos = self.token_range(earliest_stmt_started_at).start();
407405

408406
// end position is last non-whitespace token before or at the current position
409-
let end_pos = self.parser.lookbehind(1, true).unwrap().span.end();
407+
let end_pos = self.parser.lookbehind(1, true, None).unwrap().span.end();
410408

411409
println!("adding any stmt at end");
412410
self.ranges.push(StatementPosition {
@@ -425,6 +423,37 @@ mod tests {
425423

426424
use crate::statement_splitter::StatementSplitter;
427425

426+
#[test]
427+
fn test_simple_select() {
428+
let input = "
429+
select id, name, test1231234123, unknown from co;
430+
431+
select 14433313331333
432+
433+
alter table test drop column id;
434+
435+
select lower('test');
436+
";
437+
438+
let result = StatementSplitter::new(input).run();
439+
440+
assert_eq!(result.len(), 4);
441+
assert_eq!(
442+
"select id, name, test1231234123, unknown from co;",
443+
input[result[0].range].to_string()
444+
);
445+
assert_eq!(SyntaxKind::SelectStmt, result[0].kind);
446+
assert_eq!("select 14433313331333", input[result[1].range].to_string());
447+
assert_eq!(SyntaxKind::SelectStmt, result[1].kind);
448+
assert_eq!(SyntaxKind::AlterTableStmt, result[2].kind);
449+
assert_eq!(
450+
"alter table test drop column id;",
451+
input[result[2].range].to_string()
452+
);
453+
assert_eq!(SyntaxKind::SelectStmt, result[3].kind);
454+
assert_eq!("select lower('test');", input[result[3].range].to_string());
455+
}
456+
428457
#[test]
429458
fn test_create_or_replace() {
430459
let input = "CREATE OR REPLACE TRIGGER check_update
@@ -586,19 +615,19 @@ mod tests {
586615

587616
#[test]
588617
fn test_explain_analyze() {
589-
let input = "explain analyze select 1 from contact\nselect 1\nselect 4";
618+
let input = "explain analyze select 1 from contact;\nselect 1;\nselect 4;";
590619

591620
let result = StatementSplitter::new(input).run();
592621

593622
assert_eq!(result.len(), 3);
594623
assert_eq!(
595-
"explain analyze select 1 from contact",
624+
"explain analyze select 1 from contact;",
596625
input[result[0].range].to_string()
597626
);
598627
assert_eq!(SyntaxKind::ExplainStmt, result[0].kind);
599-
assert_eq!("select 1", input[result[1].range].to_string());
628+
assert_eq!("select 1;", input[result[1].range].to_string());
600629
assert_eq!(SyntaxKind::SelectStmt, result[1].kind);
601-
assert_eq!("select 4", input[result[2].range].to_string());
630+
assert_eq!("select 4;", input[result[2].range].to_string());
602631
assert_eq!(SyntaxKind::SelectStmt, result[2].kind);
603632
}
604633

@@ -694,10 +723,6 @@ DROP ROLE IF EXISTS regress_alter_generic_user1;";
694723

695724
let result = StatementSplitter::new(input).run();
696725

697-
for r in &result {
698-
println!("{:?} {:?}", r.kind, input[r.range].to_string());
699-
}
700-
701726
assert_eq!(result.len(), 2);
702727
assert_eq!("create", input[result[0].range].to_string());
703728
assert_eq!(SyntaxKind::Any, result[0].kind);
@@ -1008,7 +1033,7 @@ ALTER OPERATOR FAMILY alt_nsp6.alt_opf6 USING btree ADD OPERATOR 1 < (int4, int2
10081033
fn test_alter_op_family_2() {
10091034
let input = "
10101035
CREATE OPERATOR FAMILY alt_opf4 USING btree;
1011-
ALTER OPERATOR FAMILY schema.alt_opf4 USING btree ADD
1036+
ALTER OPERATOR FAMILY test.alt_opf4 USING btree ADD
10121037
-- int4 vs int2
10131038
OPERATOR 1 < (int4, int2) ,
10141039
OPERATOR 2 <= (int4, int2) ,

0 commit comments

Comments
 (0)