Skip to content

Commit e6a7f86

Browse files
authored
Move block shape check (#2471)
Fixes point 2 of #2327
1 parent ee650ef commit e6a7f86

File tree

2 files changed

+119
-136
lines changed

2 files changed

+119
-136
lines changed

executor/src/witgen/jit/block_machine_processor.rs

+100-17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::collections::HashSet;
1+
use std::collections::{BTreeMap, BTreeSet, HashSet};
22

33
use bit_vec::BitVec;
44
use itertools::Itertools;
@@ -15,9 +15,10 @@ use crate::witgen::{
1515
};
1616

1717
use super::{
18+
effect::Effect,
1819
processor::ProcessorResult,
1920
prover_function_heuristics::ProverFunction,
20-
variable::{Cell, Variable},
21+
variable::{Cell, MachineCallVariable, Variable},
2122
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
2223
};
2324

@@ -117,29 +118,20 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
117118
.identities
118119
.iter()
119120
.any(|id| id.contains_next_ref(&intermediate_definitions));
120-
let start_row = if !have_next_ref {
121+
let (start_row, end_row) = if !have_next_ref {
121122
// No identity contains a next reference - we do not need to consider row -1,
122123
// and the block has to be rectangular-shaped.
123-
0
124+
(0, self.block_size as i32 - 1)
124125
} else {
125126
// A machine that might have a non-rectangular shape.
126127
// We iterate over all rows of the block +/- one row.
127-
-1
128+
(-1, self.block_size as i32)
128129
};
129-
let identities = (start_row..self.block_size as i32).flat_map(|row| {
130+
let identities = (start_row..=end_row).flat_map(|row| {
130131
self.machine_parts
131132
.identities
132133
.iter()
133-
.filter_map(|id| {
134-
// Filter out identities with next references on the last row.
135-
if row as usize == self.block_size - 1
136-
&& id.contains_next_ref(&intermediate_definitions)
137-
{
138-
None
139-
} else {
140-
Some((*id, row))
141-
}
142-
})
134+
.map(|id| (*id, row))
143135
.collect_vec()
144136
});
145137

@@ -160,7 +152,6 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
160152
requested_known,
161153
BLOCK_MACHINE_MAX_BRANCH_DEPTH,
162154
)
163-
.with_block_shape_check()
164155
.with_block_size(self.block_size)
165156
.with_requested_range_constraints((0..known_args.len()).map(Variable::Param))
166157
.generate_code(can_process, witgen)
@@ -183,8 +174,100 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
183174
.format("\n ");
184175
format!("Code generation failed: {shortened_error}\nRun with RUST_LOG=trace to see the code generated so far.")
185176
})?;
177+
self.check_block_shape(&result.code)?;
186178
Ok((result, prover_functions))
187179
}
180+
181+
/// Verifies that each column and each bus send is stackable in the block.
182+
/// This means that if we have a cell write or a bus send in row `i`, we cannot
183+
/// have another one in row `i + block_size`.
///
/// Returns `Ok(())` when no such pair of rows exists. Otherwise returns `Err` with a
/// message naming the first conflicting column (or bus-send identity), the block size
/// and the two clashing row offsets.
184+
fn check_block_shape(&self, code: &[Effect<T, Variable>]) -> Result<(), String> {
// First pass: witness-column writes, grouped by column ID.
185+
for (column_id, row_offsets) in written_rows_per_column(code) {
186+
for offset in &row_offsets {
// Conflict if the same column is also written exactly one block size further down.
187+
if row_offsets.contains(&(*offset + self.block_size as i32)) {
188+
return Err(format!(
189+
"Column {} is not stackable in a {}-row block, conflict in rows {} and {}.",
190+
self.fixed_data.column_name(&PolyID {
191+
id: column_id,
192+
ptype: PolynomialType::Committed
193+
}),
194+
self.block_size,
195+
offset,
196+
offset + self.block_size as i32
197+
));
198+
}
199+
}
200+
}
// Second pass: the same distance check for machine calls, grouped by identity ID.
201+
for (identity_id, row_offsets) in completed_rows_for_bus_send(code) {
// NOTE(review): `completed_rows_for_bus_send` is declared to return
// `BTreeSet<i32>` values already, so this re-collect looks redundant.
202+
let row_offsets: BTreeSet<_> = row_offsets.into_iter().collect();
203+
for offset in &row_offsets {
204+
if row_offsets.contains(&(*offset + self.block_size as i32)) {
205+
return Err(format!(
206+
"Bus send for identity {} is not stackable in a {}-row block, conflict in rows {} and {}.",
207+
identity_id,
208+
self.block_size,
209+
offset,
210+
offset + self.block_size as i32
211+
));
212+
}
213+
}
214+
}
215+
Ok(())
216+
}
217+
}
218+
219+
/// Returns, for each column ID, the collection of row offsets that have a cell write.
220+
/// Combines writes from branches.
///
/// The map key is the `id` of a `Variable::WitnessCell` and the value is the set of
/// `row_offset`s seen for that ID. Written variables of any other kind are ignored.
/// NOTE(review): branch handling relies on `Effect::written_vars` also reporting
/// writes nested inside branches; that method is not visible here, so confirm.
221+
fn written_rows_per_column<T: FieldElement>(
222+
code: &[Effect<T, Variable>],
223+
) -> BTreeMap<u64, BTreeSet<i32>> {
224+
code.iter()
225+
.flat_map(|e| e.written_vars())
// Keep only witness cells, reduced to (column id, row offset).
226+
.filter_map(|(v, _)| match v {
227+
Variable::WitnessCell(cell) => Some((cell.id, cell.row_offset)),
228+
_ => None,
229+
})
// Group the row offsets by column id; the sets drop duplicate rows.
230+
.fold(BTreeMap::new(), |mut map, (id, row)| {
231+
map.entry(id).or_default().insert(row);
232+
map
233+
})
234+
}
235+
236+
/// Returns, for each bus send ID, the collection of row offsets that have a machine call.
237+
/// Combines calls from branches.
///
/// Built by collecting the `(identity id, row offset)` pairs that `machine_calls`
/// yields for every effect in `code` and grouping the row offsets by identity id.
238+
fn completed_rows_for_bus_send<T: FieldElement>(
239+
code: &[Effect<T, Variable>],
240+
) -> BTreeMap<u64, BTreeSet<i32>> {
241+
code.iter()
// `machine_calls` descends into `Effect::Branch`, so nested calls are included.
242+
.flat_map(machine_calls)
// Group the row offsets by identity id; the sets drop duplicate rows.
243+
.fold(BTreeMap::new(), |mut map, (id, row)| {
244+
map.entry(id).or_default().insert(row);
245+
map
246+
})
247+
}
248+
249+
/// Returns all machine calls (bus identity and row offset) found in the effect.
250+
/// Recurses into branches.
///
/// A `MachineCall` yields exactly one pair, taken from its first argument. A `Branch`
/// yields the calls of both of its effect lists. Every other effect yields nothing.
///
/// # Panics
/// Panics if the first argument of a machine call is not a
/// `Variable::MachineCallParam`, or if that parameter's `identity_id` differs from
/// the ID stored in the call. `arguments[0]` is an unchecked index, so an empty
/// argument list presumably panics as well (confirm against the type of `arguments`).
251+
fn machine_calls<T: FieldElement>(
252+
e: &Effect<T, Variable>,
253+
) -> Box<dyn Iterator<Item = (u64, i32)> + '_> {
254+
match e {
// Only the first argument is inspected to obtain the call's row offset.
255+
Effect::MachineCall(id, _, arguments) => match &arguments[0] {
256+
Variable::MachineCallParam(MachineCallVariable {
257+
identity_id,
258+
row_offset,
259+
..
260+
}) => {
// Sanity check: the parameter must belong to the call it is an argument of.
261+
assert_eq!(*id, *identity_id);
262+
Box::new(std::iter::once((*identity_id, *row_offset)))
263+
}
264+
_ => panic!("Expected machine call variable."),
265+
},
// Both arms of a branch contribute their calls.
266+
Effect::Branch(_, first, second) => {
267+
Box::new(first.iter().chain(second.iter()).flat_map(machine_calls))
268+
}
269+
_ => Box::new(std::iter::empty()),
270+
}
188271
}
189272

190273
impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {

executor/src/witgen/jit/processor.rs

+19-119
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
#![allow(dead_code)]
2-
use std::{
3-
collections::BTreeSet,
4-
fmt::{self, Display, Formatter, Write},
5-
};
2+
use std::fmt::{self, Display, Formatter, Write};
63

74
use itertools::Itertools;
8-
use powdr_ast::analyzed::{
9-
AlgebraicExpression as Expression, PolyID, PolynomialIdentity, PolynomialType,
10-
};
5+
use powdr_ast::analyzed::PolynomialIdentity;
116
use powdr_number::FieldElement;
127

138
use crate::witgen::{
@@ -21,8 +16,8 @@ use super::{
2116
affine_symbolic_expression,
2217
effect::{format_code, Effect},
2318
identity_queue::{IdentityQueue, QueueItem},
24-
variable::{Cell, MachineCallVariable, Variable},
25-
witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference},
19+
variable::{MachineCallVariable, Variable},
20+
witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, WitgenInference},
2621
};
2722

2823
/// A generic processor for generating JIT code.
@@ -36,8 +31,6 @@ pub struct Processor<'a, T: FieldElement, FixedEval> {
3631
initial_queue: Vec<QueueItem<'a, T>>,
3732
/// The size of a block.
3833
block_size: usize,
39-
/// If the processor should check for correctly stackable block shapes.
40-
check_block_shape: bool,
4134
/// List of variables we want to be known at the end. One of them not being known
4235
/// is a failure.
4336
requested_known_vars: Vec<Variable>,
@@ -71,7 +64,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
7164
identities,
7265
initial_queue,
7366
block_size: 1,
74-
check_block_shape: false,
7567
requested_known_vars: requested_known_vars.into_iter().collect(),
7668
requested_range_constraints: vec![],
7769
max_branch_depth,
@@ -93,13 +85,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
9385
self
9486
}
9587

96-
/// Activates the check to see if the code for two subsequently generated
97-
/// blocks conflicts.
98-
pub fn with_block_shape_check(mut self) -> Self {
99-
self.check_block_shape = true;
100-
self
101-
}
102-
10388
pub fn generate_code(
10489
self,
10590
can_process: impl CanProcessCall<T>,
@@ -143,16 +128,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
143128
));
144129
}
145130

146-
if self.check_block_shape {
147-
// Check that the "spill" into the previous block is compatible
148-
// with the "missing pieces" in the next block.
149-
// If this is not the case, this is a hard error
150-
// (i.e. cannot be fixed by runtime witgen) and thus we panic inside.
151-
// We could do this only at the end of each branch, but it's a bit
152-
// more convenient to do it here.
153-
self.check_block_shape(&witgen);
154-
}
155-
156131
// Check that we could derive all requested variables.
157132
let missing_variables = self
158133
.requested_known_vars
@@ -347,27 +322,23 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
347322
.unique()
348323
.flat_map(|&call| {
349324
let rows = self.rows_for_identity(call);
350-
let complete_rows = rows
325+
if rows
351326
.iter()
352327
.filter(|&&row| witgen.is_complete_call(call, row))
353-
.collect::<Vec<_>>();
354-
// We might process more rows than `self.block_size`, so we check
355-
// that the complete calls are on consecutive rows.
356-
if complete_rows.len() >= self.block_size {
357-
let (min, max) = complete_rows.iter().minmax().into_option().unwrap();
358-
// TODO instead of checking for consecutive rows, we could also check
359-
// that they "fit" the next block.
360-
// TODO actually I think that we should not allow more than block size
361-
// completed calls.
362-
let is_consecutive = *max - *min == complete_rows.len() as i32 - 1;
363-
if is_consecutive {
364-
return vec![];
365-
}
328+
.count()
329+
>= self.block_size
330+
{
331+
// We might process more rows than `self.block_size`, so we check
332+
// that we have the required number of calls.
333+
// The block shape check done by block_machine_processor will do a more
334+
// thorough check later on.
335+
vec![]
336+
} else {
337+
rows.iter()
338+
.filter(|&row| !witgen.is_complete_call(call, *row))
339+
.map(|row| (call, *row))
340+
.collect_vec()
366341
}
367-
rows.iter()
368-
.filter(|&row| !witgen.is_complete_call(call, *row))
369-
.map(|row| (call, *row))
370-
.collect::<Vec<_>>()
371342
})
372343
.collect::<Vec<_>>()
373344
}
@@ -386,77 +357,6 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
386357
.collect()
387358
}
388359

389-
/// After solving, the known cells should be such that we can stack different blocks.
390-
/// If this is not the case, this function panics.
391-
/// TODO the same is actually true for machine calls.
392-
fn check_block_shape(&self, witgen: &WitgenInference<'a, T, FixedEval>) {
393-
let known_columns: BTreeSet<_> = witgen
394-
.known_variables()
395-
.iter()
396-
.filter_map(|var| match var {
397-
Variable::WitnessCell(cell) => Some(cell.id),
398-
_ => None,
399-
})
400-
.collect();
401-
for column_id in known_columns {
402-
let known_rows = witgen
403-
.known_variables()
404-
.iter()
405-
.filter_map(|var| match var {
406-
Variable::WitnessCell(cell) if cell.id == column_id => Some(cell.row_offset),
407-
_ => None,
408-
})
409-
.collect::<BTreeSet<_>>();
410-
411-
// Two values that refer to the same row (modulo block size) are compatible if:
412-
// - One of them is unknown, or
413-
// - Both are concrete and equal
414-
let is_compatible = |v1: Value<T>, v2: Value<T>| match (v1, v2) {
415-
(Value::Unknown, _) | (_, Value::Unknown) => true,
416-
(Value::Concrete(a), Value::Concrete(b)) => a == b,
417-
_ => false,
418-
};
419-
let cell_var = |row_offset| {
420-
Variable::WitnessCell(Cell {
421-
// Column name does not matter.
422-
column_name: "".to_string(),
423-
id: column_id,
424-
row_offset,
425-
})
426-
};
427-
428-
// A column is stackable if all rows equal to each other modulo
429-
// the block size are compatible.
430-
for row in &known_rows {
431-
let this_val = witgen.value(&cell_var(*row));
432-
let next_block_val = witgen.value(&cell_var(row + self.block_size as i32));
433-
if !is_compatible(this_val, next_block_val) {
434-
let column_name = self.fixed_data.column_name(&PolyID {
435-
id: column_id,
436-
ptype: PolynomialType::Committed,
437-
});
438-
let row_vals = known_rows
439-
.iter()
440-
.map(|&r| format!(" row {r}: {}\n", witgen.value(&cell_var(r))))
441-
.format("");
442-
log::debug!(
443-
"Code generated so far:\n{}\n\
444-
Column {column_name} is not stackable in a {}-row block, \
445-
conflict in rows {row} and {}.\n{row_vals}",
446-
format_code(witgen.code()),
447-
self.block_size,
448-
row + self.block_size as i32
449-
);
450-
panic!(
451-
"Column {column_name} is not stackable in a {}-row block, conflict in rows {row} and {}.\n{row_vals}",
452-
self.block_size,
453-
row + self.block_size as i32
454-
);
455-
}
456-
}
457-
}
458-
}
459-
460360
/// If the only missing sends all only have a single argument, try to set those arguments
461361
/// to zero.
462362
fn try_fix_simple_sends(
@@ -494,7 +394,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
494394
};
495395
assert!(!witgen.is_known(param));
496396
match modified_witgen.process_equation_on_row(
497-
&Expression::Number(T::from(0)),
397+
&T::from(0).into(),
498398
Some(param.clone()),
499399
0.into(),
500400
row,

0 commit comments

Comments
 (0)