Skip to content

Commit 69bf227

Browse files
committed
fix: minor stuff
1 parent 0b4a9fd commit 69bf227

File tree

4 files changed

+217
-53
lines changed

4 files changed

+217
-53
lines changed

crates/pg_statement_splitter/src/data.rs

Lines changed: 126 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ pub struct StatementDefinition {
7878
pub stmt: SyntaxKind,
7979
pub tokens: Vec<SyntaxDefinition>,
8080
pub prohibited_following_statements: Vec<SyntaxKind>,
81+
pub prohibited_tokens: Vec<SyntaxKind>,
8182
}
8283

8384
impl StatementDefinition {
@@ -86,9 +87,15 @@ impl StatementDefinition {
8687
stmt,
8788
tokens: b.build(),
8889
prohibited_following_statements: Vec::new(),
90+
prohibited_tokens: Vec::new(),
8991
}
9092
}
9193

94+
fn with_prohibited_tokens(mut self, prohibited: Vec<SyntaxKind>) -> Self {
95+
self.prohibited_tokens = prohibited;
96+
self
97+
}
98+
9299
fn with_prohibited_following_statements(mut self, prohibited: Vec<SyntaxKind>) -> Self {
93100
self.prohibited_following_statements = prohibited;
94101
self
@@ -223,7 +230,11 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
223230
.optional_if_exists_group()
224231
.optional_token(SyntaxKind::Only)
225232
.optional_schema_name_group()
226-
.required_token(SyntaxKind::Ident)
233+
.one_of(vec![
234+
SyntaxKind::Ident,
235+
SyntaxKind::VersionP,
236+
SyntaxKind::Simple,
237+
])
227238
.any_token(),
228239
));
229240

@@ -273,13 +284,16 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
273284
.required_token(SyntaxKind::Ascii41),
274285
));
275286

276-
m.push(StatementDefinition::new(
277-
SyntaxKind::AlterDefaultPrivilegesStmt,
278-
SyntaxBuilder::new()
279-
.required_token(SyntaxKind::Alter)
280-
.required_token(SyntaxKind::Default)
281-
.required_token(SyntaxKind::Privileges),
282-
));
287+
m.push(
288+
StatementDefinition::new(
289+
SyntaxKind::AlterDefaultPrivilegesStmt,
290+
SyntaxBuilder::new()
291+
.required_token(SyntaxKind::Alter)
292+
.required_token(SyntaxKind::Default)
293+
.required_token(SyntaxKind::Privileges),
294+
)
295+
.with_prohibited_following_statements(vec![SyntaxKind::GrantStmt]),
296+
);
283297

284298
m.push(StatementDefinition::new(
285299
SyntaxKind::ClusterStmt,
@@ -387,6 +401,17 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
387401
.required_token(SyntaxKind::Ident),
388402
));
389403

404+
m.push(StatementDefinition::new(
405+
SyntaxKind::DropStmt,
406+
SyntaxBuilder::new()
407+
.required_token(SyntaxKind::Drop)
408+
.required_token(SyntaxKind::Materialized)
409+
.required_token(SyntaxKind::View)
410+
.optional_if_exists_group()
411+
.optional_schema_name_group()
412+
.required_token(SyntaxKind::Ident),
413+
));
414+
390415
m.push(StatementDefinition::new(
391416
SyntaxKind::DropStmt,
392417
SyntaxBuilder::new()
@@ -822,6 +847,11 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
822847
SyntaxBuilder::new().required_token(SyntaxKind::BeginP),
823848
));
824849

850+
m.push(StatementDefinition::new(
851+
SyntaxKind::TransactionStmt,
852+
SyntaxBuilder::new().required_token(SyntaxKind::EndP),
853+
));
854+
825855
m.push(StatementDefinition::new(
826856
SyntaxKind::TransactionStmt,
827857
SyntaxBuilder::new()
@@ -942,7 +972,11 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
942972
.required_token(SyntaxKind::Table)
943973
.optional_if_not_exists_group()
944974
.optional_schema_name_group()
945-
.required_token(SyntaxKind::Ident)
975+
.one_of(vec![
976+
SyntaxKind::Ident,
977+
SyntaxKind::VersionP,
978+
SyntaxKind::Simple,
979+
])
946980
.any_tokens(None)
947981
.required_token(SyntaxKind::As)
948982
.any_token(),
@@ -973,7 +1007,19 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
9731007
m.push(
9741008
StatementDefinition::new(
9751009
SyntaxKind::ExplainStmt,
976-
SyntaxBuilder::new().required_token(SyntaxKind::Explain),
1010+
SyntaxBuilder::new()
1011+
.required_token(SyntaxKind::Explain)
1012+
.one_of(vec![
1013+
SyntaxKind::Analyze,
1014+
SyntaxKind::Ascii40,
1015+
SyntaxKind::Select,
1016+
SyntaxKind::Insert,
1017+
SyntaxKind::Update,
1018+
SyntaxKind::DeleteP,
1019+
SyntaxKind::Merge,
1020+
SyntaxKind::Execute,
1021+
SyntaxKind::Create,
1022+
]),
9771023
)
9781024
.with_prohibited_following_statements(vec![
9791025
SyntaxKind::VacuumStmt,
@@ -983,6 +1029,7 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
9831029
SyntaxKind::UpdateStmt,
9841030
SyntaxKind::MergeStmt,
9851031
SyntaxKind::ExecuteStmt,
1032+
SyntaxKind::CreateTableAsStmt,
9861033
]),
9871034
);
9881035

@@ -1105,6 +1152,18 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
11051152
.required_token(SyntaxKind::Ident),
11061153
));
11071154

1155+
m.push(
1156+
StatementDefinition::new(
1157+
SyntaxKind::AlterRoleSetStmt,
1158+
SyntaxBuilder::new()
1159+
.required_token(SyntaxKind::Alter)
1160+
.required_token(SyntaxKind::Role)
1161+
.required_token(SyntaxKind::Ident)
1162+
.required_token(SyntaxKind::Set),
1163+
)
1164+
.with_prohibited_following_statements(vec![SyntaxKind::VariableSetStmt]),
1165+
);
1166+
11081167
m.push(StatementDefinition::new(
11091168
SyntaxKind::DropRoleStmt,
11101169
SyntaxBuilder::new()
@@ -1160,12 +1219,23 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
11601219
SyntaxBuilder::new().required_token(SyntaxKind::Checkpoint),
11611220
));
11621221

1163-
m.push(StatementDefinition::new(
1164-
SyntaxKind::CreateSchemaStmt,
1165-
SyntaxBuilder::new()
1166-
.required_token(SyntaxKind::Create)
1167-
.required_token(SyntaxKind::Schema),
1168-
));
1222+
// CREATE TABLE, CREATE VIEW, CREATE INDEX, CREATE SEQUENCE, CREATE TRIGGER and GRANT
1223+
m.push(
1224+
StatementDefinition::new(
1225+
SyntaxKind::CreateSchemaStmt,
1226+
SyntaxBuilder::new()
1227+
.required_token(SyntaxKind::Create)
1228+
.required_token(SyntaxKind::Schema),
1229+
)
1230+
.with_prohibited_following_statements(vec![
1231+
SyntaxKind::CreateTableAsStmt,
1232+
SyntaxKind::CreateStmt,
1233+
SyntaxKind::IndexStmt,
1234+
SyntaxKind::CreateSeqStmt,
1235+
SyntaxKind::CreateTrigStmt,
1236+
SyntaxKind::GrantStmt,
1237+
]),
1238+
);
11691239

11701240
m.push(StatementDefinition::new(
11711241
SyntaxKind::AlterDatabaseStmt,
@@ -1233,18 +1303,21 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
12331303
.required_token(SyntaxKind::Ident),
12341304
));
12351305

1236-
m.push(StatementDefinition::new(
1237-
SyntaxKind::AlterOpFamilyStmt,
1238-
SyntaxBuilder::new()
1239-
.required_token(SyntaxKind::Alter)
1240-
.required_token(SyntaxKind::Operator)
1241-
.required_token(SyntaxKind::Family)
1242-
.optional_schema_name_group()
1243-
.required_token(SyntaxKind::Ident)
1244-
.required_token(SyntaxKind::Using)
1245-
.required_token(SyntaxKind::Ident)
1246-
.one_of(vec![SyntaxKind::Drop, SyntaxKind::AddP, SyntaxKind::Rename]),
1247-
));
1306+
m.push(
1307+
StatementDefinition::new(
1308+
SyntaxKind::AlterOpFamilyStmt,
1309+
SyntaxBuilder::new()
1310+
.required_token(SyntaxKind::Alter)
1311+
.required_token(SyntaxKind::Operator)
1312+
.required_token(SyntaxKind::Family)
1313+
.optional_schema_name_group()
1314+
.required_token(SyntaxKind::Ident)
1315+
.required_token(SyntaxKind::Using)
1316+
.required_token(SyntaxKind::Ident)
1317+
.one_of(vec![SyntaxKind::Drop, SyntaxKind::AddP, SyntaxKind::Rename]),
1318+
)
1319+
.with_prohibited_tokens(vec![SyntaxKind::Rename]),
1320+
);
12481321

12491322
m.push(
12501323
StatementDefinition::new(
@@ -1256,9 +1329,21 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
12561329
.required_token(SyntaxKind::As)
12571330
.any_token(),
12581331
)
1259-
.with_prohibited_following_statements(vec![SyntaxKind::SelectStmt]),
1332+
.with_prohibited_following_statements(vec![
1333+
SyntaxKind::SelectStmt,
1334+
SyntaxKind::InsertStmt,
1335+
SyntaxKind::UpdateStmt,
1336+
SyntaxKind::DeleteStmt,
1337+
]),
12601338
);
12611339

1340+
m.push(StatementDefinition::new(
1341+
SyntaxKind::ClosePortalStmt,
1342+
SyntaxBuilder::new()
1343+
.required_token(SyntaxKind::Close)
1344+
.one_of(vec![SyntaxKind::Ident, SyntaxKind::All]),
1345+
));
1346+
12621347
m.push(StatementDefinition::new(
12631348
SyntaxKind::DeallocateStmt,
12641349
SyntaxBuilder::new()
@@ -1331,15 +1416,18 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
13311416
.required_token(SyntaxKind::Ident),
13321417
));
13331418

1334-
m.push(StatementDefinition::new(
1335-
SyntaxKind::AlterFdwStmt,
1336-
SyntaxBuilder::new()
1337-
.required_token(SyntaxKind::Alter)
1338-
.required_token(SyntaxKind::Foreign)
1339-
.required_token(SyntaxKind::DataP)
1340-
.required_token(SyntaxKind::Wrapper)
1341-
.required_token(SyntaxKind::Ident),
1342-
));
1419+
m.push(
1420+
StatementDefinition::new(
1421+
SyntaxKind::AlterFdwStmt,
1422+
SyntaxBuilder::new()
1423+
.required_token(SyntaxKind::Alter)
1424+
.required_token(SyntaxKind::Foreign)
1425+
.required_token(SyntaxKind::DataP)
1426+
.required_token(SyntaxKind::Wrapper)
1427+
.required_token(SyntaxKind::Ident),
1428+
)
1429+
.with_prohibited_tokens(vec![SyntaxKind::Rename]),
1430+
);
13431431

13441432
m.push(StatementDefinition::new(
13451433
SyntaxKind::CreateForeignServerStmt,

crates/pg_statement_splitter/src/statement_splitter.rs

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,31 @@ impl<'a> StatementSplitter<'a> {
3737
}
3838
}
3939

40-
fn track_nesting(&mut self) {
40+
fn end_nesting(&mut self) {
4141
match self.parser.nth(0, false).kind {
42-
SyntaxKind::Ascii40 => {
43-
// "("
44-
self.sub_stmt_depth += 1;
45-
}
4642
SyntaxKind::Ascii41 => {
4743
// ")"
4844
self.sub_stmt_depth -= 1;
4945
}
46+
SyntaxKind::EndP => {
47+
self.is_within_atomic_block = false;
48+
}
49+
_ => {}
50+
};
51+
}
52+
53+
fn start_nesting(&mut self) {
54+
match self.parser.nth(0, false).kind {
55+
SyntaxKind::Ascii40 => {
56+
// "("
57+
self.sub_stmt_depth += 1;
58+
}
5059
SyntaxKind::Atomic => {
5160
if self.parser.lookbehind(2, true, None).map(|t| t.kind) == Some(SyntaxKind::BeginP)
5261
{
5362
self.is_within_atomic_block = true;
5463
}
5564
}
56-
SyntaxKind::EndP => {
57-
self.is_within_atomic_block = false;
58-
}
5965
_ => {}
6066
};
6167
}
@@ -177,19 +183,19 @@ impl<'a> StatementSplitter<'a> {
177183
.min_by_key(|stmt| stmt.started_at)
178184
.map(|stmt| stmt.started_at)
179185
{
180-
println!(
181-
"earliest complete stmt started at: {:?}",
182-
earliest_complete_stmt_started_at
183-
);
184186
let earliest_complete_stmt = self
185187
.tracked_statements
186188
.iter()
187189
.filter(|s| {
188190
s.started_at == earliest_complete_stmt_started_at && s.could_be_complete()
189191
})
190-
.max_by_key(|stmt| stmt.max_pos())
192+
.max_by_key(|stmt| {
193+
println!("stmt: {:?} max pos: {:?}", stmt.def.stmt, stmt.max_pos());
194+
stmt.max_pos()
195+
})
191196
.unwrap();
192197

198+
println!("earliest complete stmt: {:?}", earliest_complete_stmt);
193199
assert_eq!(
194200
1,
195201
self.tracked_statements
@@ -304,7 +310,7 @@ impl<'a> StatementSplitter<'a> {
304310
.collect::<Vec<_>>()
305311
);
306312

307-
self.track_nesting();
313+
self.start_nesting();
308314

309315
let removed_items_min_started_at = self.advance_tracker();
310316

@@ -328,6 +334,8 @@ impl<'a> StatementSplitter<'a> {
328334
self.close_stmt_with_semicolon();
329335
}
330336

337+
self.end_nesting();
338+
331339
// # This is where the actual parsing happens
332340

333341
// 1. Find the latest complete statement
@@ -1360,6 +1368,47 @@ DROP LANGUAGE IF EXISTS test_language_exists;
13601368
assert_eq!(SyntaxKind::DropStmt, result[2].kind);
13611369
}
13621370

1371+
#[test]
1372+
fn alter_mat_view() {
1373+
let input = "
1374+
ALTER MATERIALIZED VIEW mvtest_tvm SET SCHEMA mvtest_mvschema;
1375+
";
1376+
let result = StatementSplitter::new(input).run();
1377+
1378+
assert_eq!(result.len(), 1);
1379+
assert_eq!(SyntaxKind::AlterObjectSchemaStmt, result[0].kind);
1380+
}
1381+
1382+
#[test]
1383+
fn create_tbl_as_2() {
1384+
let input = "
1385+
create table simple as
1386+
select generate_series(1, 20000) AS id, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa';
1387+
";
1388+
let result = StatementSplitter::new(input).run();
1389+
1390+
assert_eq!(result.len(), 1);
1391+
assert_eq!(SyntaxKind::CreateTableAsStmt, result[0].kind);
1392+
}
1393+
1394+
#[test]
1395+
fn create_tbl_as() {
1396+
let input = "
1397+
CREATE TABLE tab_settings_flags AS SELECT name, category,
1398+
'EXPLAIN' = ANY(flags) AS explain,
1399+
'NO_RESET_ALL' = ANY(flags) AS no_reset_all,
1400+
'NO_SHOW_ALL' = ANY(flags) AS no_show_all,
1401+
'NOT_IN_SAMPLE' = ANY(flags) AS not_in_sample,
1402+
'RUNTIME_COMPUTED' = ANY(flags) AS runtime_computed
1403+
FROM pg_show_all_settings() AS psas,
1404+
pg_settings_get_flags(psas.name) AS flags;
1405+
";
1406+
let result = StatementSplitter::new(input).run();
1407+
1408+
assert_eq!(result.len(), 1);
1409+
assert_eq!(SyntaxKind::CreateTableAsStmt, result[0].kind);
1410+
}
1411+
13631412
#[allow(clippy::must_use)]
13641413
fn debug(input: &str) {
13651414
for s in input.split(';').filter_map(|s| {

0 commit comments

Comments
 (0)