Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Single step with branching #2274

Merged
merged 45 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
3710c19
Bisect range constraints.
chriseth Dec 19, 2024
200cfa9
fix mask.
chriseth Dec 20, 2024
00f53da
remove commented-out code.
chriseth Dec 20, 2024
cce4375
Single step processor.
chriseth Dec 20, 2024
c54c0dd
Negative test.
chriseth Dec 20, 2024
71666c0
branching test.
chriseth Dec 20, 2024
96b3660
loop
chriseth Dec 20, 2024
405bdc1
c
chriseth Dec 20, 2024
2bdbe81
compute branch var
chriseth Dec 20, 2024
140e7d4
Merge branch 'bisect' into single_step_with_branching
chriseth Dec 20, 2024
79dd351
almost branch.
chriseth Dec 20, 2024
e5b9846
Merge remote-tracking branch 'origin/main' into single_step_with_bran…
chriseth Dec 20, 2024
e41976a
Add "can process"
chriseth Dec 20, 2024
b346f62
Fix test.
chriseth Dec 20, 2024
e271ef7
Fix test again.
chriseth Dec 20, 2024
0f8ac79
fix test.
chriseth Dec 20, 2024
b58afc1
swap branches
chriseth Dec 20, 2024
a40a133
Merge remote-tracking branch 'origin/main' into single_step_simple
chriseth Jan 2, 2025
eb245ac
Apply suggestions from code review
chriseth Jan 2, 2025
1530cae
Merge branch 'single_step_simple' of ssh://github.com/powdr-labs/powd…
chriseth Jan 2, 2025
1588eb2
adjust tests.
chriseth Jan 2, 2025
cc70c19
Remove test function.
chriseth Jan 2, 2025
b13df2d
Merge branch 'remove_test_function' into single_step_simple
chriseth Jan 2, 2025
96d5506
Fix testing functions.
chriseth Jan 2, 2025
b6a2ad9
Merge remote-tracking branch 'origin/main' into single_step_with_bran…
chriseth Jan 2, 2025
b799c35
Merge branch 'single_step_simple' into single_step_with_branching
chriseth Jan 2, 2025
819d794
Fix merge
chriseth Jan 2, 2025
9604c79
Fix test.
chriseth Jan 2, 2025
250253c
Fix written vars.
chriseth Jan 2, 2025
93373a3
Fix test expectation.
chriseth Jan 2, 2025
ce159f6
Code generation for branches.
chriseth Jan 2, 2025
b45dd7f
clippy
chriseth Jan 2, 2025
4fc4a2e
Merge branch 'single_step_simple'
chriseth Jan 3, 2025
3517df3
Merge commit '50fd4580b' into single_step_with_branching
chriseth Jan 3, 2025
8fab6ae
Merge branch 'main' into single_step_with_branching
chriseth Jan 3, 2025
cf9ea9f
Fix formatting and tests.
chriseth Jan 3, 2025
352d63f
Update executor/src/witgen/jit/single_step_processor.rs
chriseth Jan 6, 2025
232b0d8
Extend comment
chriseth Jan 6, 2025
671546f
Merge branch 'single_step_with_branching' of ssh://github.com/powdr-l…
chriseth Jan 6, 2025
7d69841
Return both copies in on_branch.
chriseth Jan 6, 2025
df85d7f
Improve error message.
chriseth Jan 6, 2025
0242bfe
Apply suggestions from code review
chriseth Jan 6, 2025
0ab316e
Update executor/src/witgen/jit/compiler.rs
chriseth Jan 6, 2025
747ff3b
Test codegen.
chriseth Jan 6, 2025
c4cbc05
Merge branch 'single_step_with_branching' of ssh://github.com/powdr-l…
chriseth Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 154 additions & 19 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::{ffi::c_void, iter, mem, sync::Arc};
use std::{cmp::Ordering, ffi::c_void, iter, mem, sync::Arc};

use auto_enums::auto_enum;
use itertools::Itertools;
use libloading::Library;
use powdr_ast::indent;
use powdr_number::{FieldElement, KnownField};

use crate::witgen::{
Expand All @@ -16,7 +16,7 @@ use crate::witgen::{
};

use super::{
effect::{Assertion, Effect},
effect::{Assertion, BranchCondition, Effect},
symbolic_expression::{BinaryOperator, BitOperator, SymbolicExpression, UnaryOperator},
variable::Variable,
};
Expand Down Expand Up @@ -157,10 +157,11 @@ fn witgen_code<T: FieldElement>(
format!(" let {var_name} = {value};")
})
.format("\n");
let main_code = effects.iter().map(format_effect).format("\n");
let main_code = format_effects(effects);
let vars_known = effects
.iter()
.flat_map(written_vars_in_effect)
.map(|(var, _)| var)
.collect_vec();
let store_values = vars_known
.iter()
Expand Down Expand Up @@ -224,26 +225,51 @@ extern "C" fn witgen(
}

/// Returns an iterator over all variables written to in the effect.
#[auto_enum(Iterator)]
/// The flag indicates if the variable is the return value of a machine call and thus needs
/// to be declared mutable.
fn written_vars_in_effect<T: FieldElement>(
effect: &Effect<T, Variable>,
) -> impl Iterator<Item = &Variable> + '_ {
) -> Box<dyn Iterator<Item = (&Variable, bool)> + '_> {
match effect {
Effect::Assignment(var, _) => iter::once(var),
Effect::Assignment(var, _) => Box::new(iter::once((var, false))),
Effect::RangeConstraint(..) => unreachable!(),
Effect::Assertion(..) => iter::empty(),
Effect::MachineCall(_, arguments) => arguments.iter().flat_map(|e| match e {
MachineCallArgument::Unknown(v) => Some(v),
Effect::Assertion(..) => Box::new(iter::empty()),
Effect::MachineCall(_, arguments) => Box::new(arguments.iter().flat_map(|e| match e {
MachineCallArgument::Unknown(v) => Some((v, true)),
MachineCallArgument::Known(_) => None,
}),
})),
Effect::Branch(_, first, second) => Box::new(
first
.iter()
.chain(second)
.flat_map(|e| written_vars_in_effect(e)),
),
}
}

fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
pub fn format_effects<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
format_effects_inner(effects, true)
}

fn format_effects_inner<T: FieldElement>(
effects: &[Effect<T, Variable>],
is_top_level: bool,
) -> String {
indent(
effects
.iter()
.map(|effect| format_effect(effect, is_top_level))
.join("\n"),
1,
)
}

fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bool) -> String {
match effect {
Effect::Assignment(var, e) => {
format!(
" let {} = {};",
"{}{} = {};",
if is_top_level { "let " } else { "" },
variable_to_string(var),
format_expression(e)
)
Expand All @@ -256,7 +282,7 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
rhs,
expected_equal,
}) => format!(
" assert!({} {} {});",
"assert!({} {} {});",
format_expression(lhs),
if *expected_equal { "==" } else { "!=" },
format_expression(rhs)
Expand All @@ -268,7 +294,9 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
.map(|a| match a {
MachineCallArgument::Unknown(v) => {
let var_name = variable_to_string(v);
result_vars.push(var_name.clone());
if is_top_level {
result_vars.push(var_name.clone());
}
format!("LookupCell::Output(&mut {var_name})")
}
MachineCallArgument::Known(v) => {
Expand All @@ -279,11 +307,40 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
.to_string();
let var_decls = result_vars
.iter()
.map(|var_name| format!(" let mut {var_name} = FieldElement::default();"))
.format("\n");
.map(|var_name| format!("let mut {var_name} = FieldElement::default();\n"))
.format("");
format!(
"{var_decls}assert!(call_machine(mutable_state, {id}, MutSlice::from((&mut [{args}]).as_mut_slice())));"
)
}
Effect::Branch(condition, first, second) => {
let var_decls = if is_top_level {
// We need to declare all assigned variables at top level,
// so that they are available after the branches.
first
.iter()
.chain(second)
.flat_map(|e| written_vars_in_effect(e))
.sorted()
.dedup()
.map(|(v, needs_mut)| {
let v = variable_to_string(v);
if needs_mut {
format!("let mut {v} = FieldElement::default();\n")
} else {
format!("let {v};\n")
}
})
.format("")
.to_string()
} else {
"".to_string()
};
format!(
"{var_decls}
assert!(call_machine(mutable_state, {id}, MutSlice::from((&mut [{args}]).as_mut_slice())));"
"{var_decls}if {} {{\n{}\n}} else {{\n{}\n}}",
format_condition(condition),
format_effects_inner(first, false),
format_effects_inner(second, false)
)
}
}
Expand Down Expand Up @@ -319,6 +376,16 @@ fn format_expression<T: FieldElement>(e: &SymbolicExpression<T, Variable>) -> St
}
}

fn format_condition<T: FieldElement>(condition: &BranchCondition<T, Variable>) -> String {
let var = format!("IntType::from({})", variable_to_string(&condition.variable));
let (min, max) = condition.first_branch.range();
match min.cmp(&max) {
Ordering::Equal => format!("{var} == {min}",),
Ordering::Less => format!("{min} <= {var} && {var} <= {max}"),
Ordering::Greater => format!("{var} <= {min} || {var} >= {max}"),
}
}

/// Returns the name of a local (stack) variable for the given expression variable.
fn variable_to_string(v: &Variable) -> String {
match v {
Expand Down Expand Up @@ -423,6 +490,7 @@ mod tests {

use crate::witgen::jit::variable::Cell;
use crate::witgen::jit::variable::MachineCallReturnVariable;
use crate::witgen::range_constraints::RangeConstraint;

use super::*;

Expand Down Expand Up @@ -780,4 +848,71 @@ extern \"C\" fn witgen(
assert_eq!(data[1], GoldilocksField::from(18));
assert_eq!(data[2], GoldilocksField::from(0));
}

#[test]
fn branches() {
let x = param(0);
let y = param(1);
let mut x_val: GoldilocksField = 7.into();
let mut y_val: GoldilocksField = 9.into();
let effects = vec![Effect::Branch(
BranchCondition {
variable: x.clone(),
first_branch: RangeConstraint::from_range(7.into(), 20.into()),
second_branch: RangeConstraint::from_range(21.into(), 6.into()),
},
vec![assignment(&y, symbol(&x) + number(1))],
vec![assignment(&y, symbol(&x) + number(2))],
)];
let f = compile_effects(0, 1, &[x], &effects).unwrap();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
)];
)];
let witgen_code = witgen_code(&[x.clone()], &effects);
assert_eq!(
witgen_code,
"
#[no_mangle]
extern \"C\" fn witgen(
WitgenFunctionParams{
data,
known,
row_offset,
params,
mutable_state,
call_machine
}: WitgenFunctionParams<FieldElement>,
) {
let known = known_to_slice(known, data.len);
let data = data.to_mut_slice();
let params = params.to_mut_slice();
let p_0 = get_param(params, 0);
let p_1;
if 7 <= IntType::from(p_0) && IntType::from(p_0) <= 20 {
p_1 = (p_0 + FieldElement::from(1));} else {
p_1 = (p_0 + FieldElement::from(2));}
set_param(params, 1, p_1);
set_param(params, 1, p_1);
}
"
);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least for me, it helps to see a concrete example in the tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll instead add unit tests for the branch code itself. I don't really like these long auto-generated code tests because they need a lot of changes if we decide to do something in a different order (i.e. they don't really test the core that needs to be tested).

let mut data = vec![];
let mut known = vec![];

let mut params = vec![LookupCell::Input(&x_val), LookupCell::Output(&mut y_val)];
let params = WitgenFunctionParams {
data: data.as_mut_slice().into(),
known: known.as_mut_ptr(),
row_offset: 0,
params: params.as_mut_slice().into(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
};
(f.function)(params);
assert_eq!(y_val, GoldilocksField::from(8));

x_val = 2.into();
let mut params = vec![LookupCell::Input(&x_val), LookupCell::Output(&mut y_val)];
let params = WitgenFunctionParams {
data: data.as_mut_slice().into(),
known: known.as_mut_ptr(),
row_offset: 0,
params: params.as_mut_slice().into(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
};
(f.function)(params);
assert_eq!(y_val, GoldilocksField::from(4));
}

#[test]
fn branches_codegen() {
let x = param(0);
let y = param(1);
let branch_effect = Effect::Branch(
BranchCondition {
variable: x.clone(),
first_branch: RangeConstraint::from_range(7.into(), 20.into()),
second_branch: RangeConstraint::from_range(21.into(), 6.into()),
},
vec![assignment(&y, symbol(&x) + number(1))],
vec![assignment(&y, symbol(&x) + number(2))],
);
let expectation = " let p_1;
if 7 <= IntType::from(p_0) && IntType::from(p_0) <= 20 {
p_1 = (p_0 + FieldElement::from(1));
} else {
p_1 = (p_0 + FieldElement::from(2));
}";
assert_eq!(format_effects(&[branch_effect]), expectation);
}
}
34 changes: 33 additions & 1 deletion executor/src/witgen/jit/effect.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
use std::cmp::Ordering;

use itertools::Itertools;
use powdr_ast::indent;
use powdr_number::FieldElement;

use crate::witgen::range_constraints::RangeConstraint;

use super::{symbolic_expression::SymbolicExpression, variable::Variable};

/// The effect of solving a symbolic equation.
#[derive(Clone, PartialEq, Eq)]
pub enum Effect<T: FieldElement, V> {
/// Variable can be assigned a value.
Assignment(V, SymbolicExpression<T, V>),
/// We learnt a new range constraint on variable.
RangeConstraint(V, RangeConstraint<T>),
/// A run-time assertion. If this fails, we have conflicting constraints.
Assertion(Assertion<T, V>),
/// a call to a different machine.
/// A call to a different machine.
MachineCall(u64, Vec<MachineCallArgument<T, V>>),
/// A branch on a variable.
Branch(BranchCondition<T, V>, Vec<Effect<T, V>>, Vec<Effect<T, V>>),
}

/// A run-time assertion. If this fails, we have conflicting constraints.
#[derive(Clone, PartialEq, Eq)]
pub struct Assertion<T: FieldElement, V> {
pub lhs: SymbolicExpression<T, V>,
pub rhs: SymbolicExpression<T, V>,
Expand Down Expand Up @@ -52,11 +59,19 @@ impl<T: FieldElement, V> Assertion<T, V> {
}
}

#[derive(Clone, PartialEq, Eq)]
pub enum MachineCallArgument<T: FieldElement, V> {
Known(SymbolicExpression<T, V>),
Unknown(V),
}

#[derive(Clone, PartialEq, Eq)]
pub struct BranchCondition<T: FieldElement, V> {
pub variable: V,
pub first_branch: RangeConstraint<T>,
pub second_branch: RangeConstraint<T>,
}

/// Helper function to render a list of effects. Used for informational purposes only.
pub fn format_code<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
effects
Expand Down Expand Up @@ -87,6 +102,23 @@ pub fn format_code<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
Effect::RangeConstraint(..) => {
panic!("Range constraints should not be part of the code.")
}
Effect::Branch(condition, first, second) => {
let first = indent(format_code(first), 1);
let second = indent(format_code(second), 1);
let condition = format_condition(condition);

format!("if ({condition}) {{\n{first}\n}} else {{\n{second}\n}}")
}
})
.join("\n")
}

fn format_condition<T: FieldElement>(condition: &BranchCondition<T, Variable>) -> String {
let var = &condition.variable;
let (min, max) = condition.first_branch.range();
match min.cmp(&max) {
Ordering::Equal => format!("{var} == {min}"),
Ordering::Less => format!("{min} <= {var} && {var} <= {max}"),
Ordering::Greater => format!("{var} <= {min} || {var} >= {max}"),
}
}
6 changes: 6 additions & 0 deletions executor/src/witgen/jit/includes/field_generic_up_to_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ impl From<IntType> for FieldElement {
Self(i)
}
}
impl From<FieldElement> for IntType {
#[inline]
fn from(f: FieldElement) -> Self {
f.0
}
}
impl std::ops::Add for FieldElement {
type Output = Self;
#[inline]
Expand Down
8 changes: 8 additions & 0 deletions executor/src/witgen/jit/includes/field_goldilocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
struct GoldilocksField(u64);

type FieldElement = GoldilocksField;
type IntType = u64;

const EPSILON: u64 = (1 << 32) - 1;

Expand Down Expand Up @@ -292,6 +293,13 @@ impl From<u64> for GoldilocksField {
}
}

impl From<FieldElement> for IntType {
#[inline]
fn from(f: FieldElement) -> Self {
f.0
}
}

impl std::fmt::Display for GoldilocksField {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
Expand Down
Loading
Loading