Skip to content

Commit 4bb99be

Browse files
georgwiese and chriseth authored
JIT for block machines with non-rectangular shapes (#2275)
Depends on #2279. This PR implements JIT code generation for block machines with irregular block shapes, such as `std::machines::large_field::binary::Binary`. This is achieved as follows: - Instead of solving rows `0..block_size`, we run the solver for rows `-1..(block_size + 1)`. This way, the solver is able to generate code that writes to the last row of the previous block or the first row of the next block. - At the end, we check whether the generated code is actually consistent: For example, if the code writes to the last row of the previous block, it can't have a known value in the same cell of the current block (unless it's known to be the same). Note that the generated code is still not used in practice, because we don't call the JIT with the right amount of context. I started fixing this in #2281, but it is still WIP. --------- Co-authored-by: chriseth <chris@ethereum.org>
1 parent 91202ee commit 4bb99be

File tree

6 files changed

+286
-46
lines changed

6 files changed

+286
-46
lines changed

executor/src/witgen/jit/block_machine_processor.rs

+204-25
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
use std::collections::HashSet;
1+
use std::collections::{BTreeSet, HashSet};
22

33
use bit_vec::BitVec;
44
use itertools::Itertools;
5-
use powdr_ast::analyzed::{AlgebraicReference, Identity, SelectedExpressions};
5+
use powdr_ast::analyzed::{
6+
AlgebraicReference, Identity, PolyID, PolynomialType, SelectedExpressions,
7+
};
68
use powdr_number::FieldElement;
79

810
use crate::witgen::{jit::effect::format_code, machines::MachineParts, FixedData};
911

1012
use super::{
1113
effect::Effect,
12-
variable::Variable,
13-
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
14+
variable::{Cell, Variable},
15+
witgen_inference::{CanProcessCall, FixedEvaluator, Value, WitgenInference},
1416
};
1517

1618
/// A processor for generating JIT code for a block machine.
@@ -85,6 +87,11 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
8587
}
8688
}
8789

90+
/// Returns the row offsets that the solver and the post-solve checks iterate over.
///
/// The range is `-1..block_size + 1` with an exclusive end, i.e. it covers the row
/// before the block, the `block_size` rows `0..block_size`, and the row `block_size`
/// right after them.
fn row_range(&self) -> std::ops::Range<i32> {
91+
// We iterate over all rows of the block +/- one row, so that we can also solve for non-rectangular blocks.
92+
// `as` binds tighter than `..`, so only the upper bound `block_size + 1` is cast to `i32`.
-1..(self.block_size + 1) as i32
93+
}
94+
8895
/// Repeatedly processes all identities on all rows, until no progress is made.
8996
/// Fails iff there are incomplete machine calls in the latch row.
9097
fn solve_block<CanProcess: CanProcessCall<T> + Clone>(
@@ -97,11 +104,10 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
97104
for iteration in 0.. {
98105
let mut progress = false;
99106

100-
// TODO: This algorithm is assuming a rectangular block shape.
101-
for row in 0..self.block_size {
107+
for row in self.row_range() {
102108
for id in &self.machine_parts.identities {
103109
if !complete.contains(&(id.id(), row)) {
104-
let result = witgen.process_identity(can_process.clone(), id, row as i32);
110+
let result = witgen.process_identity(can_process.clone(), id, row);
105111
if result.complete {
106112
complete.insert((id.id(), row));
107113
}
@@ -125,22 +131,121 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
125131
}
126132
}
127133

128-
// If any machine call could not be completed, that's bad because machine calls typically have side effects.
129-
// So, the underlying lookup / permutation / bus argument likely does not hold.
130-
// TODO: This assumes a rectangular block shape.
131-
let has_incomplete_machine_calls = (0..self.block_size)
132-
.flat_map(|row| {
133-
self.machine_parts
134-
.identities
135-
.iter()
136-
.filter(|id| is_machine_call(id))
137-
.map(move |id| (id, row))
134+
// TODO: Fail hard (or return a different error), as this should never
135+
// happen for valid block machines. Currently fails in:
136+
// powdr-pipeline::powdr_std arith256_memory_large_test
137+
self.check_block_shape(witgen)?;
138+
self.check_incomplete_machine_calls(&complete)?;
139+
140+
Ok(())
141+
}
142+
143+
/// After solving, the known values should be such that we can stack different blocks.
144+
fn check_block_shape(&self, witgen: &mut WitgenInference<'a, T, &Self>) -> Result<(), String> {
145+
let known_columns = witgen
146+
.known_variables()
147+
.iter()
148+
.filter_map(|var| match var {
149+
Variable::Cell(cell) => Some(cell.id),
150+
_ => None,
138151
})
139-
.any(|(identity, row)| !complete.contains(&(identity.id(), row)));
152+
.collect::<BTreeSet<_>>();
153+
154+
let can_stack = known_columns.iter().all(|column_id| {
155+
// Increase the range by 1, because in row <block_size>,
156+
// we might have processed an identity with next references.
157+
let row_range = self.row_range();
158+
let values = (row_range.start..(row_range.end + 1))
159+
.map(|row| {
160+
witgen.value(&Variable::Cell(Cell {
161+
id: *column_id,
162+
row_offset: row,
163+
// Dummy value, the column name is ignored in the implementation
164+
// of Cell::eq, etc.
165+
column_name: "".to_string(),
166+
}))
167+
})
168+
.collect::<Vec<_>>();
169+
170+
// Two values that refer to the same row (modulo block size) are compatible if:
171+
// - One of them is unknown, or
172+
// - Both are concrete and equal
173+
let is_compatible = |v1: Value<T>, v2: Value<T>| match (v1, v2) {
174+
(Value::Unknown, _) | (_, Value::Unknown) => true,
175+
(Value::Concrete(a), Value::Concrete(b)) => a == b,
176+
_ => false,
177+
};
178+
// A column is stackable if all rows equal to each other modulo
179+
// the block size are compatible.
180+
let stackable = (0..(values.len() - self.block_size))
181+
.all(|i| is_compatible(values[i], values[i + self.block_size]));
140182

141-
match has_incomplete_machine_calls {
142-
true => Err("Incomplete machine calls".to_string()),
143-
false => Ok(()),
183+
if !stackable {
184+
let column_name = self.fixed_data.column_name(&PolyID {
185+
id: *column_id,
186+
ptype: PolynomialType::Committed,
187+
});
188+
let block_list = values.iter().skip(1).take(self.block_size).join(", ");
189+
let column_str = format!(
190+
"... {} | {} | {} ...",
191+
values[0],
192+
block_list,
193+
values[self.block_size + 1]
194+
);
195+
log::debug!("Column {column_name} is not stackable:\n{column_str}");
196+
}
197+
198+
stackable
199+
});
200+
201+
match can_stack {
202+
true => Ok(()),
203+
false => Err("Block machine shape does not allow stacking".to_string()),
204+
}
205+
}
206+
207+
/// If any machine call could not be completed, that's bad because machine calls typically have side effects.
208+
/// So, the underlying lookup / permutation / bus argument likely does not hold.
209+
/// This function checks that all machine calls are complete, at least for a window of <block_size> rows.
///
/// `complete` holds the `(identity id, row)` pairs that were marked as complete.
///
/// A machine call is accepted only if it is complete on at least `block_size` rows of
/// `row_range()` AND all of its complete rows form a single run without gaps. Note that this
/// is stricter than "some window of `block_size` rows is complete": an incomplete row between
/// two complete rows always causes a rejection.
///
/// Returns `Err` with a message listing every `(identity, row)` pair of the rejected calls
/// that is not in `complete`, and `Ok(())` if no call is rejected.
210+
fn check_incomplete_machine_calls(&self, complete: &HashSet<(u64, i32)>) -> Result<(), String> {
211+
// Only identities that are machine calls are subject to this check.
let machine_calls = self
212+
.machine_parts
213+
.identities
214+
.iter()
215+
.filter(|id| is_machine_call(id));
216+
217+
let incomplete_machine_calls = machine_calls
218+
.flat_map(|call| {
219+
// Rows on which this call is complete, in ascending order and without duplicates
// (they are produced by filtering a range).
let complete_rows = self
220+
.row_range()
221+
.filter(|row| complete.contains(&(call.id(), *row)))
222+
.collect::<Vec<_>>();
223+
// Because we process rows -1..block_size+1, it is fine to have two incomplete machine calls,
224+
// as long as <block_size> consecutive rows are complete.
225+
if complete_rows.len() >= self.block_size {
226+
// `complete_rows` is non-empty here whenever `block_size >= 1`, so the unwrap
// is not expected to fail in that case.
let (min, max) = complete_rows.iter().minmax().into_option().unwrap();
227+
// The rows are distinct, so they are gap-free exactly if the span between
// the smallest and largest row equals the number of rows minus one.
let is_consecutive = max - min == complete_rows.len() as i32 - 1;
228+
if is_consecutive {
229+
// This call is fine: contribute nothing to the error list.
return vec![];
230+
}
231+
}
232+
// The call is rejected: report every row on which it is not complete.
self.row_range()
233+
.filter(|row| !complete.contains(&(call.id(), *row)))
234+
.map(|row| (call, row))
235+
.collect::<Vec<_>>()
236+
})
237+
.collect::<Vec<_>>();
238+
239+
if !incomplete_machine_calls.is_empty() {
240+
// One line per offending (identity, row) pair.
Err(format!(
241+
"Incomplete machine calls:\n {}",
242+
incomplete_machine_calls
243+
.iter()
244+
.map(|(identity, row)| format!("{identity} (row {row})"))
245+
.join("\n ")
246+
))
247+
} else {
248+
Ok(())
144249
}
145250
}
146251
}
@@ -160,7 +265,22 @@ impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {
160265
fn evaluate(&self, var: &AlgebraicReference, row_offset: i32) -> Option<T> {
161266
assert!(var.is_fixed());
162267
let values = self.fixed_data.fixed_cols[&var.poly_id].values_max_size();
163-
let row = (row_offset + var.next as i32 + values.len() as i32) as usize % values.len();
268+
269+
// By assumption of the block machine, all fixed columns are cyclic with a period of <block_size>.
270+
// An exception might be the first and last row.
271+
assert!(row_offset >= -1);
272+
assert!(self.block_size >= 1);
273+
// The current row is guaranteed to be at least 1.
274+
let current_row = (2 * self.block_size as i32 + row_offset) as usize;
275+
let row = current_row + var.next as usize;
276+
277+
assert!(values.len() >= self.block_size * 4);
278+
279+
// Fixed columns are assumed to be cyclic, except in the first and last row.
280+
// The code above should ensure that we never access the first or last row.
281+
assert!(row > 0);
282+
assert!(row < values.len() - 1);
283+
164284
Some(values[row])
165285
}
166286
}
@@ -265,11 +385,70 @@ params[2] = Add::c[0];"
265385
}
266386

267387
#[test]
268-
// TODO: Currently fails, because the machine has a non-rectangular block shape.
269-
#[should_panic = "Unable to derive algorithm to compute output value \\\"main_binary::C\\\""]
388+
// Checks that code generation is rejected for a machine whose blocks cannot be stacked.
// The test passes if the `unwrap()` below panics with a message containing
// "Block machine shape does not allow stacking".
#[should_panic = "Block machine shape does not allow stacking"]
389+
fn not_stackable() {
390+
// PIL source: `Main` connects its column `a` to the `NotStackable` machine, whose only
// constraint is `a = a'`, i.e. `a` has the same value on a row and on the next row.
// NOTE(review): presumably the known input then propagates to `a` on every processed row,
// including the rows just outside the block, which is what makes the stacking check fail
// — confirm against the solver.
let input = "
391+
namespace Main(256);
392+
col witness a, b, c;
393+
[a] is NotStackable.sel $ [NotStackable.a];
394+
namespace NotStackable(256);
395+
col witness sel, a;
396+
a = a';
397+
";
398+
// NOTE(review): the trailing arguments `1, 0` are presumably the number of inputs and
// outputs of the call — confirm against `generate_for_block_machine`.
generate_for_block_machine(input, "NotStackable", 1, 0).unwrap();
399+
}
400+
401+
#[test]
270402
fn binary() {
271403
let input = read_to_string("../test_data/pil/binary.pil").unwrap();
272-
generate_for_block_machine(&input, "main_binary", 3, 1).unwrap();
404+
let code = generate_for_block_machine(&input, "main_binary", 3, 1).unwrap();
405+
assert_eq!(
406+
format_code(&code),
407+
"main_binary::sel[0][3] = 1;
408+
main_binary::operation_id[3] = params[0];
409+
main_binary::A[3] = params[1];
410+
main_binary::B[3] = params[2];
411+
main_binary::operation_id[2] = main_binary::operation_id[3];
412+
main_binary::A_byte[2] = ((main_binary::A[3] & 4278190080) // 16777216);
413+
main_binary::A[2] = (main_binary::A[3] & 16777215);
414+
assert (main_binary::A[3] & 18446744069414584320) == 0;
415+
main_binary::B_byte[2] = ((main_binary::B[3] & 4278190080) // 16777216);
416+
main_binary::B[2] = (main_binary::B[3] & 16777215);
417+
assert (main_binary::B[3] & 18446744069414584320) == 0;
418+
main_binary::operation_id_next[2] = main_binary::operation_id[3];
419+
machine_call(9, [Known(main_binary::operation_id_next[2]), Known(main_binary::A_byte[2]), Known(main_binary::B_byte[2]), Unknown(ret(9, 2, 3))]);
420+
main_binary::C_byte[2] = ret(9, 2, 3);
421+
main_binary::operation_id[1] = main_binary::operation_id[2];
422+
main_binary::A_byte[1] = ((main_binary::A[2] & 16711680) // 65536);
423+
main_binary::A[1] = (main_binary::A[2] & 65535);
424+
assert (main_binary::A[2] & 18446744073692774400) == 0;
425+
main_binary::B_byte[1] = ((main_binary::B[2] & 16711680) // 65536);
426+
main_binary::B[1] = (main_binary::B[2] & 65535);
427+
assert (main_binary::B[2] & 18446744073692774400) == 0;
428+
main_binary::operation_id_next[1] = main_binary::operation_id[2];
429+
machine_call(9, [Known(main_binary::operation_id_next[1]), Known(main_binary::A_byte[1]), Known(main_binary::B_byte[1]), Unknown(ret(9, 1, 3))]);
430+
main_binary::C_byte[1] = ret(9, 1, 3);
431+
main_binary::operation_id[0] = main_binary::operation_id[1];
432+
main_binary::A_byte[0] = ((main_binary::A[1] & 65280) // 256);
433+
main_binary::A[0] = (main_binary::A[1] & 255);
434+
assert (main_binary::A[1] & 18446744073709486080) == 0;
435+
main_binary::B_byte[0] = ((main_binary::B[1] & 65280) // 256);
436+
main_binary::B[0] = (main_binary::B[1] & 255);
437+
assert (main_binary::B[1] & 18446744073709486080) == 0;
438+
main_binary::operation_id_next[0] = main_binary::operation_id[1];
439+
machine_call(9, [Known(main_binary::operation_id_next[0]), Known(main_binary::A_byte[0]), Known(main_binary::B_byte[0]), Unknown(ret(9, 0, 3))]);
440+
main_binary::C_byte[0] = ret(9, 0, 3);
441+
main_binary::A_byte[-1] = main_binary::A[0];
442+
main_binary::B_byte[-1] = main_binary::B[0];
443+
main_binary::operation_id_next[-1] = main_binary::operation_id[0];
444+
machine_call(9, [Known(main_binary::operation_id_next[-1]), Known(main_binary::A_byte[-1]), Known(main_binary::B_byte[-1]), Unknown(ret(9, -1, 3))]);
445+
main_binary::C_byte[-1] = ret(9, -1, 3);
446+
main_binary::C[0] = main_binary::C_byte[-1];
447+
main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256));
448+
main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536));
449+
main_binary::C[3] = (main_binary::C[2] + (main_binary::C_byte[2] * 16777216));
450+
params[3] = main_binary::C[3];"
451+
)
273452
}
274453

275454
#[test]

executor/src/witgen/jit/function_cache.rs

+23
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use powdr_number::{FieldElement, KnownField};
55

66
use crate::witgen::{
77
data_structures::finalizable_data::{ColumnLayout, CompactDataRef},
8+
jit::effect::Effect,
89
machines::{LookupCell, MachineParts},
910
EvalError, FixedData, MutableState, QueryCallback,
1011
};
@@ -28,6 +29,7 @@ pub struct FunctionCache<'a, T: FieldElement> {
2829
/// but failed.
2930
witgen_functions: HashMap<CacheKey, Option<WitgenFunction<T>>>,
3031
column_layout: ColumnLayout,
32+
block_size: usize,
3133
}
3234

3335
impl<'a, T: FieldElement> FunctionCache<'a, T> {
@@ -45,6 +47,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
4547
processor,
4648
column_layout: metadata,
4749
witgen_functions: HashMap::new(),
50+
block_size,
4851
}
4952
}
5053

@@ -89,9 +92,29 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
8992
cache_key: &CacheKey,
9093
) -> Option<WitgenFunction<T>> {
9194
log::trace!("Compiling JIT function for {:?}", cache_key);
95+
9296
self.processor
9397
.generate_code(mutable_state, cache_key.identity_id, &cache_key.known_args)
9498
.ok()
99+
.and_then(|code| {
100+
// TODO: Remove this once BlockMachine passes the right amount of context for machines with
101+
// non-rectangular block shapes.
102+
let is_rectangular = code
103+
.iter()
104+
.filter_map(|effect| match effect {
105+
Effect::Assignment(v, _) => Some(v),
106+
_ => None,
107+
})
108+
.filter_map(|assigned_variable| match assigned_variable {
109+
Variable::Cell(cell) => Some(cell.row_offset),
110+
_ => None,
111+
})
112+
.all(|row_offset| row_offset >= 0 && row_offset < self.block_size as i32);
113+
if !is_rectangular {
114+
log::debug!("Filtering out code for non-rectangular block shape");
115+
}
116+
is_rectangular.then_some(code)
117+
})
95118
.map(|code| {
96119
log::trace!("Generated code ({} steps)", code.len());
97120
let known_inputs = cache_key

0 commit comments

Comments
 (0)