Skip to content

Commit dc4c5cd

Browse files
chrisethgeorgwiese
andauthored
Single step with branching (#2274)
Infers single-step update code including branching. TODO: - [ ] perform constant propagation after branching (using the evaluator in `only_concrete_known()`-setting) - [ ] extend the "can process fully"-interface to allow answers like "yes, I am authorized to process, but it will never succeed given these range constraints". This way we can remove conflicting combinations of e.g. instruction flags. Another way would be to directly return range constraints on variables with the "can process" call, that way we could save branching. Maybe it's also fine to just implement this for the fixed machine. In any case, we should double-check that we do not create another machine call for an identity we already solved. --------- Co-authored-by: Georg Wiese <georgwiese@gmail.com>
1 parent 5ceeda4 commit dc4c5cd

7 files changed

+451
-85
lines changed

executor/src/witgen/jit/compiler.rs

+154-19
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
use std::{ffi::c_void, iter, mem, sync::Arc};
1+
use std::{cmp::Ordering, ffi::c_void, iter, mem, sync::Arc};
22

3-
use auto_enums::auto_enum;
43
use itertools::Itertools;
54
use libloading::Library;
5+
use powdr_ast::indent;
66
use powdr_number::{FieldElement, KnownField};
77

88
use crate::witgen::{
@@ -16,7 +16,7 @@ use crate::witgen::{
1616
};
1717

1818
use super::{
19-
effect::{Assertion, Effect},
19+
effect::{Assertion, BranchCondition, Effect},
2020
symbolic_expression::{BinaryOperator, BitOperator, SymbolicExpression, UnaryOperator},
2121
variable::Variable,
2222
};
@@ -157,10 +157,11 @@ fn witgen_code<T: FieldElement>(
157157
format!(" let {var_name} = {value};")
158158
})
159159
.format("\n");
160-
let main_code = effects.iter().map(format_effect).format("\n");
160+
let main_code = format_effects(effects);
161161
let vars_known = effects
162162
.iter()
163163
.flat_map(written_vars_in_effect)
164+
.map(|(var, _)| var)
164165
.collect_vec();
165166
let store_values = vars_known
166167
.iter()
@@ -224,26 +225,51 @@ extern "C" fn witgen(
224225
}
225226

226227
/// Returns an iterator over all variables written to in the effect.
227-
#[auto_enum(Iterator)]
228+
/// The flag indicates if the variable is the return value of a machine call and thus needs
229+
/// to be declared mutable.
228230
fn written_vars_in_effect<T: FieldElement>(
229231
effect: &Effect<T, Variable>,
230-
) -> impl Iterator<Item = &Variable> + '_ {
232+
) -> Box<dyn Iterator<Item = (&Variable, bool)> + '_> {
231233
match effect {
232-
Effect::Assignment(var, _) => iter::once(var),
234+
Effect::Assignment(var, _) => Box::new(iter::once((var, false))),
233235
Effect::RangeConstraint(..) => unreachable!(),
234-
Effect::Assertion(..) => iter::empty(),
235-
Effect::MachineCall(_, arguments) => arguments.iter().flat_map(|e| match e {
236-
MachineCallArgument::Unknown(v) => Some(v),
236+
Effect::Assertion(..) => Box::new(iter::empty()),
237+
Effect::MachineCall(_, arguments) => Box::new(arguments.iter().flat_map(|e| match e {
238+
MachineCallArgument::Unknown(v) => Some((v, true)),
237239
MachineCallArgument::Known(_) => None,
238-
}),
240+
})),
241+
Effect::Branch(_, first, second) => Box::new(
242+
first
243+
.iter()
244+
.chain(second)
245+
.flat_map(|e| written_vars_in_effect(e)),
246+
),
239247
}
240248
}
241249

242-
fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
250+
pub fn format_effects<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
251+
format_effects_inner(effects, true)
252+
}
253+
254+
fn format_effects_inner<T: FieldElement>(
255+
effects: &[Effect<T, Variable>],
256+
is_top_level: bool,
257+
) -> String {
258+
indent(
259+
effects
260+
.iter()
261+
.map(|effect| format_effect(effect, is_top_level))
262+
.join("\n"),
263+
1,
264+
)
265+
}
266+
267+
fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bool) -> String {
243268
match effect {
244269
Effect::Assignment(var, e) => {
245270
format!(
246-
" let {} = {};",
271+
"{}{} = {};",
272+
if is_top_level { "let " } else { "" },
247273
variable_to_string(var),
248274
format_expression(e)
249275
)
@@ -256,7 +282,7 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
256282
rhs,
257283
expected_equal,
258284
}) => format!(
259-
" assert!({} {} {});",
285+
"assert!({} {} {});",
260286
format_expression(lhs),
261287
if *expected_equal { "==" } else { "!=" },
262288
format_expression(rhs)
@@ -268,7 +294,9 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
268294
.map(|a| match a {
269295
MachineCallArgument::Unknown(v) => {
270296
let var_name = variable_to_string(v);
271-
result_vars.push(var_name.clone());
297+
if is_top_level {
298+
result_vars.push(var_name.clone());
299+
}
272300
format!("LookupCell::Output(&mut {var_name})")
273301
}
274302
MachineCallArgument::Known(v) => {
@@ -279,11 +307,40 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
279307
.to_string();
280308
let var_decls = result_vars
281309
.iter()
282-
.map(|var_name| format!(" let mut {var_name} = FieldElement::default();"))
283-
.format("\n");
310+
.map(|var_name| format!("let mut {var_name} = FieldElement::default();\n"))
311+
.format("");
312+
format!(
313+
"{var_decls}assert!(call_machine(mutable_state, {id}, MutSlice::from((&mut [{args}]).as_mut_slice())));"
314+
)
315+
}
316+
Effect::Branch(condition, first, second) => {
317+
let var_decls = if is_top_level {
318+
// We need to declare all assigned variables at top level,
319+
// so that they are available after the branches.
320+
first
321+
.iter()
322+
.chain(second)
323+
.flat_map(|e| written_vars_in_effect(e))
324+
.sorted()
325+
.dedup()
326+
.map(|(v, needs_mut)| {
327+
let v = variable_to_string(v);
328+
if needs_mut {
329+
format!("let mut {v} = FieldElement::default();\n")
330+
} else {
331+
format!("let {v};\n")
332+
}
333+
})
334+
.format("")
335+
.to_string()
336+
} else {
337+
"".to_string()
338+
};
284339
format!(
285-
"{var_decls}
286-
assert!(call_machine(mutable_state, {id}, MutSlice::from((&mut [{args}]).as_mut_slice())));"
340+
"{var_decls}if {} {{\n{}\n}} else {{\n{}\n}}",
341+
format_condition(condition),
342+
format_effects_inner(first, false),
343+
format_effects_inner(second, false)
287344
)
288345
}
289346
}
@@ -319,6 +376,16 @@ fn format_expression<T: FieldElement>(e: &SymbolicExpression<T, Variable>) -> St
319376
}
320377
}
321378

379+
fn format_condition<T: FieldElement>(condition: &BranchCondition<T, Variable>) -> String {
380+
let var = format!("IntType::from({})", variable_to_string(&condition.variable));
381+
let (min, max) = condition.first_branch.range();
382+
match min.cmp(&max) {
383+
Ordering::Equal => format!("{var} == {min}",),
384+
Ordering::Less => format!("{min} <= {var} && {var} <= {max}"),
385+
Ordering::Greater => format!("{var} <= {min} || {var} >= {max}"),
386+
}
387+
}
388+
322389
/// Returns the name of a local (stack) variable for the given expression variable.
323390
fn variable_to_string(v: &Variable) -> String {
324391
match v {
@@ -423,6 +490,7 @@ mod tests {
423490

424491
use crate::witgen::jit::variable::Cell;
425492
use crate::witgen::jit::variable::MachineCallReturnVariable;
493+
use crate::witgen::range_constraints::RangeConstraint;
426494

427495
use super::*;
428496

@@ -780,4 +848,71 @@ extern \"C\" fn witgen(
780848
assert_eq!(data[1], GoldilocksField::from(18));
781849
assert_eq!(data[2], GoldilocksField::from(0));
782850
}
851+
852+
#[test]
853+
fn branches() {
854+
let x = param(0);
855+
let y = param(1);
856+
let mut x_val: GoldilocksField = 7.into();
857+
let mut y_val: GoldilocksField = 9.into();
858+
let effects = vec![Effect::Branch(
859+
BranchCondition {
860+
variable: x.clone(),
861+
first_branch: RangeConstraint::from_range(7.into(), 20.into()),
862+
second_branch: RangeConstraint::from_range(21.into(), 6.into()),
863+
},
864+
vec![assignment(&y, symbol(&x) + number(1))],
865+
vec![assignment(&y, symbol(&x) + number(2))],
866+
)];
867+
let f = compile_effects(0, 1, &[x], &effects).unwrap();
868+
let mut data = vec![];
869+
let mut known = vec![];
870+
871+
let mut params = vec![LookupCell::Input(&x_val), LookupCell::Output(&mut y_val)];
872+
let params = WitgenFunctionParams {
873+
data: data.as_mut_slice().into(),
874+
known: known.as_mut_ptr(),
875+
row_offset: 0,
876+
params: params.as_mut_slice().into(),
877+
mutable_state: std::ptr::null(),
878+
call_machine: no_call_machine,
879+
};
880+
(f.function)(params);
881+
assert_eq!(y_val, GoldilocksField::from(8));
882+
883+
x_val = 2.into();
884+
let mut params = vec![LookupCell::Input(&x_val), LookupCell::Output(&mut y_val)];
885+
let params = WitgenFunctionParams {
886+
data: data.as_mut_slice().into(),
887+
known: known.as_mut_ptr(),
888+
row_offset: 0,
889+
params: params.as_mut_slice().into(),
890+
mutable_state: std::ptr::null(),
891+
call_machine: no_call_machine,
892+
};
893+
(f.function)(params);
894+
assert_eq!(y_val, GoldilocksField::from(4));
895+
}
896+
897+
#[test]
898+
fn branches_codegen() {
899+
let x = param(0);
900+
let y = param(1);
901+
let branch_effect = Effect::Branch(
902+
BranchCondition {
903+
variable: x.clone(),
904+
first_branch: RangeConstraint::from_range(7.into(), 20.into()),
905+
second_branch: RangeConstraint::from_range(21.into(), 6.into()),
906+
},
907+
vec![assignment(&y, symbol(&x) + number(1))],
908+
vec![assignment(&y, symbol(&x) + number(2))],
909+
);
910+
let expectation = " let p_1;
911+
if 7 <= IntType::from(p_0) && IntType::from(p_0) <= 20 {
912+
p_1 = (p_0 + FieldElement::from(1));
913+
} else {
914+
p_1 = (p_0 + FieldElement::from(2));
915+
}";
916+
assert_eq!(format_effects(&[branch_effect]), expectation);
917+
}
783918
}

executor/src/witgen/jit/effect.rs

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,30 @@
1+
use std::cmp::Ordering;
2+
13
use itertools::Itertools;
4+
use powdr_ast::indent;
25
use powdr_number::FieldElement;
36

47
use crate::witgen::range_constraints::RangeConstraint;
58

69
use super::{symbolic_expression::SymbolicExpression, variable::Variable};
710

811
/// The effect of solving a symbolic equation.
12+
#[derive(Clone, PartialEq, Eq)]
913
pub enum Effect<T: FieldElement, V> {
1014
/// Variable can be assigned a value.
1115
Assignment(V, SymbolicExpression<T, V>),
1216
/// We learnt a new range constraint on variable.
1317
RangeConstraint(V, RangeConstraint<T>),
1418
/// A run-time assertion. If this fails, we have conflicting constraints.
1519
Assertion(Assertion<T, V>),
16-
/// a call to a different machine.
20+
/// A call to a different machine.
1721
MachineCall(u64, Vec<MachineCallArgument<T, V>>),
22+
/// A branch on a variable.
23+
Branch(BranchCondition<T, V>, Vec<Effect<T, V>>, Vec<Effect<T, V>>),
1824
}
1925

2026
/// A run-time assertion. If this fails, we have conflicting constraints.
27+
#[derive(Clone, PartialEq, Eq)]
2128
pub struct Assertion<T: FieldElement, V> {
2229
pub lhs: SymbolicExpression<T, V>,
2330
pub rhs: SymbolicExpression<T, V>,
@@ -52,11 +59,19 @@ impl<T: FieldElement, V> Assertion<T, V> {
5259
}
5360
}
5461

62+
#[derive(Clone, PartialEq, Eq)]
5563
pub enum MachineCallArgument<T: FieldElement, V> {
5664
Known(SymbolicExpression<T, V>),
5765
Unknown(V),
5866
}
5967

68+
#[derive(Clone, PartialEq, Eq)]
69+
pub struct BranchCondition<T: FieldElement, V> {
70+
pub variable: V,
71+
pub first_branch: RangeConstraint<T>,
72+
pub second_branch: RangeConstraint<T>,
73+
}
74+
6075
/// Helper function to render a list of effects. Used for informational purposes only.
6176
pub fn format_code<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
6277
effects
@@ -87,6 +102,23 @@ pub fn format_code<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
87102
Effect::RangeConstraint(..) => {
88103
panic!("Range constraints should not be part of the code.")
89104
}
105+
Effect::Branch(condition, first, second) => {
106+
let first = indent(format_code(first), 1);
107+
let second = indent(format_code(second), 1);
108+
let condition = format_condition(condition);
109+
110+
format!("if ({condition}) {{\n{first}\n}} else {{\n{second}\n}}")
111+
}
90112
})
91113
.join("\n")
92114
}
115+
116+
fn format_condition<T: FieldElement>(condition: &BranchCondition<T, Variable>) -> String {
117+
let var = &condition.variable;
118+
let (min, max) = condition.first_branch.range();
119+
match min.cmp(&max) {
120+
Ordering::Equal => format!("{var} == {min}"),
121+
Ordering::Less => format!("{min} <= {var} && {var} <= {max}"),
122+
Ordering::Greater => format!("{var} <= {min} || {var} >= {max}"),
123+
}
124+
}

executor/src/witgen/jit/includes/field_generic_up_to_64.rs

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ impl From<IntType> for FieldElement {
1919
Self(i)
2020
}
2121
}
22+
impl From<FieldElement> for IntType {
23+
#[inline]
24+
fn from(f: FieldElement) -> Self {
25+
f.0
26+
}
27+
}
2228
impl std::ops::Add for FieldElement {
2329
type Output = Self;
2430
#[inline]

executor/src/witgen/jit/includes/field_goldilocks.rs

+8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
struct GoldilocksField(u64);
44

55
type FieldElement = GoldilocksField;
6+
type IntType = u64;
67

78
const EPSILON: u64 = (1 << 32) - 1;
89

@@ -292,6 +293,13 @@ impl From<u64> for GoldilocksField {
292293
}
293294
}
294295

296+
impl From<FieldElement> for IntType {
297+
#[inline]
298+
fn from(f: FieldElement) -> Self {
299+
f.0
300+
}
301+
}
302+
295303
impl std::fmt::Display for GoldilocksField {
296304
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297305
write!(f, "{}", self.0)

0 commit comments

Comments
 (0)