
Commit 3eea69f

fix: rewrite statement splitter wip
1 parent eeb64f5 commit 3eea69f

6 files changed, +776 -230 lines
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+use pg_lexer::SyntaxKind;
+use std::{collections::HashMap, sync::LazyLock};
+
+#[derive(Debug)]
+pub enum SyntaxDefinition {
+    RequiredToken(SyntaxKind),
+    OptionalToken(SyntaxKind),
+    AnyTokens,
+    AnyToken,
+    OneOf(Vec<SyntaxKind>),
+}
+
+#[derive(Debug)]
+pub struct StatementDefinition {
+    pub stmt: SyntaxKind,
+    pub tokens: Vec<SyntaxDefinition>,
+}
+
+pub static STATEMENT_BRIDGE_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefinition>>> =
+    LazyLock::new(|| {
+        let mut m: Vec<StatementDefinition> = Vec::new();
+
+        m.push(StatementDefinition {
+            stmt: SyntaxKind::SelectStmt,
+            tokens: vec![
+                SyntaxDefinition::RequiredToken(SyntaxKind::Union),
+                SyntaxDefinition::OptionalToken(SyntaxKind::All),
+            ],
+        });
+
+        m.push(StatementDefinition {
+            stmt: SyntaxKind::SelectStmt,
+            tokens: vec![
+                SyntaxDefinition::RequiredToken(SyntaxKind::Intersect),
+                SyntaxDefinition::OptionalToken(SyntaxKind::All),
+            ],
+        });
+
+        m.push(StatementDefinition {
+            stmt: SyntaxKind::SelectStmt,
+            tokens: vec![
+                SyntaxDefinition::RequiredToken(SyntaxKind::Except),
+                SyntaxDefinition::OptionalToken(SyntaxKind::All),
+            ],
+        });
+
+        let mut stmt_starts: HashMap<SyntaxKind, Vec<StatementDefinition>> = HashMap::new();
+
+        for stmt in m {
+            let first_token = stmt.tokens.get(0).unwrap();
+            if let SyntaxDefinition::RequiredToken(kind) = first_token {
+                stmt_starts.entry(*kind).or_insert(Vec::new()).push(stmt);
+            } else {
+                panic!("Expected RequiredToken as first token in bridge definition");
+            }
+        }
+
+        stmt_starts
+    });
+
+pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefinition>>> =
+    LazyLock::new(|| {
+        let mut m: Vec<StatementDefinition> = Vec::new();
+
+        m.push(StatementDefinition {
+            stmt: SyntaxKind::CreateTrigStmt,
+            tokens: vec![
+                SyntaxDefinition::RequiredToken(SyntaxKind::Create),
+                SyntaxDefinition::OptionalToken(SyntaxKind::Or),
+                SyntaxDefinition::OptionalToken(SyntaxKind::Replace),
+                SyntaxDefinition::OptionalToken(SyntaxKind::Constraint),
+                SyntaxDefinition::RequiredToken(SyntaxKind::Trigger),
+                SyntaxDefinition::RequiredToken(SyntaxKind::Ident),
+                SyntaxDefinition::AnyTokens,
+                SyntaxDefinition::RequiredToken(SyntaxKind::On),
+                SyntaxDefinition::RequiredToken(SyntaxKind::Ident),
+                SyntaxDefinition::AnyTokens,
+                SyntaxDefinition::RequiredToken(SyntaxKind::Execute),
+                SyntaxDefinition::OneOf(vec![SyntaxKind::Function, SyntaxKind::Procedure]),
+                SyntaxDefinition::RequiredToken(SyntaxKind::Ident),
+            ],
+        });
+
+        m.push(StatementDefinition {
+            stmt: SyntaxKind::SelectStmt,
+            tokens: vec![SyntaxDefinition::RequiredToken(SyntaxKind::Select)],
+        });
+
+        m.push(StatementDefinition {
+            stmt: SyntaxKind::ExecuteStmt,
+            tokens: vec![
+                SyntaxDefinition::RequiredToken(SyntaxKind::Execute),
+                SyntaxDefinition::RequiredToken(SyntaxKind::Ident),
+            ],
+        });
+
+        let mut stmt_starts: HashMap<SyntaxKind, Vec<StatementDefinition>> = HashMap::new();
+
+        for stmt in m {
+            let first_token = stmt.tokens.get(0).unwrap();
+            if let SyntaxDefinition::RequiredToken(kind) = first_token {
+                stmt_starts.entry(*kind).or_insert(Vec::new()).push(stmt);
+            } else {
+                panic!("Expected RequiredToken as first token in statement definition");
+            }
+        }
+
+        stmt_starts
+    });

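Both maps above are keyed by each definition's first required token, so a splitter can cheaply look up the candidate definitions for the token it is standing on and then walk the remaining SyntaxDefinition entries. The sketch below is not the statement_tracker added in this commit (its source is not shown here); it is a minimal illustration of how such a walk could work, assuming the SyntaxDefinition and StatementDefinition types above are in scope, with the greedy AnyTokens handling that the lib.rs doc comment hints at. The name could_match and its signature are invented for this example.

// Hypothetical helper, not part of this commit: checks whether a slice of
// non-whitespace token kinds could satisfy a StatementDefinition.
fn could_match(def: &StatementDefinition, tokens: &[SyntaxKind]) -> bool {
    let mut pos = 0; // position in `tokens`
    let mut i = 0; // position in `def.tokens`
    while i < def.tokens.len() {
        match &def.tokens[i] {
            SyntaxDefinition::RequiredToken(kind) => {
                // a required token must appear exactly here
                if tokens.get(pos) != Some(kind) {
                    return false;
                }
                pos += 1;
            }
            SyntaxDefinition::OptionalToken(kind) => {
                // consume the token if present, otherwise fall through
                if tokens.get(pos) == Some(kind) {
                    pos += 1;
                }
            }
            SyntaxDefinition::OneOf(kinds) => match tokens.get(pos) {
                Some(k) if kinds.contains(k) => pos += 1,
                _ => return false,
            },
            SyntaxDefinition::AnyToken => match tokens.get(pos) {
                Some(_) => pos += 1,
                None => return false,
            },
            SyntaxDefinition::AnyTokens => {
                // skip ahead until the token required by the *next* definition
                // shows up; this relies on the convention from the lib.rs doc
                // comment that Any* is followed by at least one required token
                if let Some(SyntaxDefinition::RequiredToken(next)) = def.tokens.get(i + 1) {
                    while let Some(k) = tokens.get(pos) {
                        if k == next {
                            break;
                        }
                        pos += 1;
                    }
                }
            }
        }
        i += 1;
    }
    true
}

The bridge map presumably serves the related case where UNION, INTERSECT, or EXCEPT joins two otherwise complete SELECTs into one SelectStmt rather than starting a new statement.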
crates/pg_statement_splitter/src/lib.rs

Lines changed: 126 additions & 110 deletions
@@ -9,8 +9,11 @@
 /// We should expand the definition map to include an `Any*`, which must be followed by at least
 /// one required token and allows the parser to search for the end tokens of the statement. This
 /// will hopefully be enough to reduce collisions to zero.
+mod data;
 mod is_at_stmt_start;
 mod parser;
+mod statement_splitter;
+mod statement_tracker;
 mod syntax_error;
 
 use is_at_stmt_start::{is_at_stmt_start, TokenStatement, STATEMENT_START_TOKEN_MAPS};
@@ -19,119 +22,132 @@ use parser::{Parse, Parser};
 
 use pg_lexer::{lex, SyntaxKind};
 
-pub fn split(sql: &str) -> Parse {
-    let mut parser = Parser::new(lex(sql));
-
-    while !parser.eof() {
-        match is_at_stmt_start(&mut parser) {
-            Some(stmt) => {
-                parser.start_stmt();
-
-                // advance over all start tokens of the statement
-                for i in 0..STATEMENT_START_TOKEN_MAPS.len() {
-                    parser.eat_whitespace();
-                    let token = parser.nth(0, false);
-                    if let Some(result) = STATEMENT_START_TOKEN_MAPS[i].get(&token.kind) {
-                        let is_in_results = result
-                            .iter()
-                            .find(|x| match x {
-                                TokenStatement::EoS(y) | TokenStatement::Any(y) => y == &stmt,
-                            })
-                            .is_some();
-                        if i == 0 && !is_in_results {
-                            panic!("Expected statement start");
-                        } else if is_in_results {
-                            parser.expect(token.kind);
-                        } else {
-                            break;
-                        }
-                    }
-                }
-
-                // move until the end of the statement, or until the next statement start
-                let mut is_sub_stmt = 0;
-                let mut is_sub_trx = 0;
-                let mut ignore_next_non_whitespace = false;
-                while !parser.at(SyntaxKind::Ascii59) && !parser.eof() {
-                    match parser.nth(0, false).kind {
-                        SyntaxKind::All => {
-                            // ALL is never a statement start, but needs to be skipped when combining queries
-                            // (e.g. UNION ALL)
-                            parser.advance();
-                        }
-                        SyntaxKind::BeginP => {
-                            // BEGIN, consume until END
-                            is_sub_trx += 1;
-                            parser.advance();
-                        }
-                        SyntaxKind::EndP => {
-                            is_sub_trx -= 1;
-                            parser.advance();
-                        }
-                        // opening brackets "(", consume until closing bracket ")"
-                        SyntaxKind::Ascii40 => {
-                            is_sub_stmt += 1;
-                            parser.advance();
-                        }
-                        SyntaxKind::Ascii41 => {
-                            is_sub_stmt -= 1;
-                            parser.advance();
-                        }
-                        SyntaxKind::As
-                        | SyntaxKind::Union
-                        | SyntaxKind::Intersect
-                        | SyntaxKind::Except => {
-                            // ignore the next non-whitespace token
-                            ignore_next_non_whitespace = true;
-                            parser.advance();
-                        }
-                        _ => {
-                            // if another stmt FIRST is encountered, break
-                            // ignore if parsing sub stmt
-                            if ignore_next_non_whitespace == false
-                                && is_sub_stmt == 0
-                                && is_sub_trx == 0
-                                && is_at_stmt_start(&mut parser).is_some()
-                            {
-                                break;
-                            } else {
-                                if ignore_next_non_whitespace == true && !parser.at_whitespace() {
-                                    ignore_next_non_whitespace = false;
-                                }
-                                parser.advance();
-                            }
-                        }
-                    }
-                }
-
-                parser.expect(SyntaxKind::Ascii59);
-
-                parser.close_stmt();
-            }
-            None => {
-                parser.advance();
-            }
-        }
-    }
-
-    parser.finish()
-}
+// pub fn split(sql: &str) -> Parse {
+//     let mut parser = Parser::new(lex(sql));
+//
+//     while !parser.eof() {
+//         if parser.at_whitespace() {
+//             parser.advance();
+//             continue;
+//         }
+//         // check all current active statements if the token matches
+//         // check if there is a new statement starting at the current token
+//     }
+// }
+//
+// pub fn split(sql: &str) -> Parse {
+//     let mut parser = Parser::new(lex(sql));
+//
+//     while !parser.eof() {
+//         match is_at_stmt_start(&mut parser) {
+//             Some(stmt) => {
+//                 parser.start_stmt();
+//
+//                 // advance over all start tokens of the statement
+//                 for i in 0..STATEMENT_START_TOKEN_MAPS.len() {
+//                     parser.eat_whitespace();
+//                     let token = parser.nth(0, false);
+//                     if let Some(result) = STATEMENT_START_TOKEN_MAPS[i].get(&token.kind) {
+//                         let is_in_results = result
+//                             .iter()
+//                             .find(|x| match x {
+//                                 TokenStatement::EoS(y) | TokenStatement::Any(y) => y == &stmt,
+//                             })
+//                             .is_some();
+//                         if i == 0 && !is_in_results {
+//                             panic!("Expected statement start");
+//                         } else if is_in_results {
+//                             parser.expect(token.kind);
+//                         } else {
+//                             break;
+//                         }
+//                     }
+//                 }
+//
+//                 // move until the end of the statement, or until the next statement start
+//                 let mut is_sub_stmt = 0;
+//                 let mut is_sub_trx = 0;
+//                 let mut ignore_next_non_whitespace = false;
+//                 while !parser.at(SyntaxKind::Ascii59) && !parser.eof() {
+//                     match parser.nth(0, false).kind {
+//                         SyntaxKind::All => {
+//                             // ALL is never a statement start, but needs to be skipped when combining queries
+//                             // (e.g. UNION ALL)
+//                             parser.advance();
+//                         }
+//                         SyntaxKind::BeginP => {
+//                             // BEGIN, consume until END
+//                             is_sub_trx += 1;
+//                             parser.advance();
+//                         }
+//                         SyntaxKind::EndP => {
+//                             is_sub_trx -= 1;
+//                             parser.advance();
+//                         }
+//                         // opening brackets "(", consume until closing bracket ")"
+//                         SyntaxKind::Ascii40 => {
+//                             is_sub_stmt += 1;
+//                             parser.advance();
+//                         }
+//                         SyntaxKind::Ascii41 => {
+//                             is_sub_stmt -= 1;
+//                             parser.advance();
+//                         }
+//                         SyntaxKind::As
+//                         | SyntaxKind::Union
+//                         | SyntaxKind::Intersect
+//                         | SyntaxKind::Except => {
+//                             // ignore the next non-whitespace token
+//                             ignore_next_non_whitespace = true;
+//                             parser.advance();
+//                         }
+//                         _ => {
+//                             // if another stmt FIRST is encountered, break
+//                             // ignore if parsing sub stmt
+//                             if ignore_next_non_whitespace == false
+//                                 && is_sub_stmt == 0
+//                                 && is_sub_trx == 0
+//                                 && is_at_stmt_start(&mut parser).is_some()
+//                             {
+//                                 break;
+//                             } else {
+//                                 if ignore_next_non_whitespace == true && !parser.at_whitespace() {
+//                                     ignore_next_non_whitespace = false;
+//                                 }
+//                                 parser.advance();
+//                             }
+//                         }
+//                     }
+//                 }
+//
+//                 parser.expect(SyntaxKind::Ascii59);
+//
+//                 parser.close_stmt();
+//             }
+//             None => {
+//                 parser.advance();
+//             }
+//         }
+//     }
+//
+//     parser.finish()
+// }
 
 #[cfg(test)]
 mod tests {
     use super::*;
 
-    #[test]
-    fn test_splitter() {
-        let input = "select 1 from contact;\nselect 1;\nalter table test drop column id;";
-
-        let res = split(input);
-        assert_eq!(res.ranges.len(), 3);
-        assert_eq!("select 1 from contact;", input[res.ranges[0]].to_string());
-        assert_eq!("select 1;", input[res.ranges[1]].to_string());
-        assert_eq!(
-            "alter table test drop column id;",
-            input[res.ranges[2]].to_string()
-        );
-    }
+    // #[test]
+    // fn test_splitter() {
+    //     let input = "select 1 from contact;\nselect 1;\nalter table test drop column id;";
+    //
+    //     let res = split(input);
+    //     assert_eq!(res.ranges.len(), 3);
+    //     assert_eq!("select 1 from contact;", input[res.ranges[0]].to_string());
+    //     assert_eq!("select 1;", input[res.ranges[1]].to_string());
+    //     assert_eq!(
+    //         "alter table test drop column id;",
+    //         input[res.ranges[2]].to_string()
+    //     );
    // }
 }

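The commented-out stub above spells out the intended loop: skip whitespace, advance every currently active statement candidate, and check whether a new statement starts at the current token. The fragment below is a much-simplified, hypothetical rendering of that idea — statement_starts is an invented name, not the statement_splitter or statement_tracker module this commit introduces — operating on an already-lexed, whitespace-free list of token kinds and assuming STATEMENT_DEFINITIONS and pg_lexer::SyntaxKind are in scope.

// Invented illustration: report the token indices at which some known
// statement definition begins. A real splitter would keep one tracker per
// candidate definition and also consult STATEMENT_BRIDGE_DEFINITIONS so that
// UNION / INTERSECT / EXCEPT do not open a second statement.
fn statement_starts(kinds: &[SyntaxKind]) -> Vec<usize> {
    let mut starts = Vec::new();
    let mut in_statement = false;

    for (idx, kind) in kinds.iter().enumerate() {
        // check if there is a new statement starting at the current token
        if !in_statement && STATEMENT_DEFINITIONS.contains_key(kind) {
            starts.push(idx);
            in_statement = true;
        }

        // a semicolon closes the current statement
        if *kind == SyntaxKind::Ascii59 {
            in_statement = false;
        }
    }

    starts
}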
crates/pg_statement_splitter/src/parser.rs

Lines changed: 2 additions & 12 deletions
@@ -67,17 +67,6 @@ impl Parser {
         }
     }
 
-    pub fn start_stmt(&mut self) {
-        assert!(self.current_stmt_start.is_none());
-        self.current_stmt_start = Some(self.pos);
-    }
-
-    pub fn close_stmt(&mut self) {
-        assert!(self.current_stmt_start.is_some());
-        self.ranges
-            .push((self.current_stmt_start.take().unwrap(), self.pos));
-    }
-
     /// collects an SyntaxError with an `error` message at `pos`
     pub fn error_at_pos(&mut self, error: String, pos: usize) {
         self.errors.push(SyntaxError::new_at_offset(
@@ -93,7 +82,8 @@ impl Parser {
     /// applies token and advances
     pub fn advance(&mut self) {
         assert!(!self.eof());
-        if self.nth(0, false).kind == SyntaxKind::Whitespace {
+        let token = self.nth(0, false);
+        if token.kind == SyntaxKind::Whitespace {
             if self.whitespace_token_buffer.is_none() {
                 self.whitespace_token_buffer = Some(self.pos);
             }
