Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Single step with branching #2274

Merged
merged 45 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
3710c19
Bisect range constraints.
chriseth Dec 19, 2024
200cfa9
fix mask.
chriseth Dec 20, 2024
00f53da
remove commented-out code.
chriseth Dec 20, 2024
cce4375
Single step processor.
chriseth Dec 20, 2024
c54c0dd
Negative test.
chriseth Dec 20, 2024
71666c0
branching test.
chriseth Dec 20, 2024
96b3660
loop
chriseth Dec 20, 2024
405bdc1
c
chriseth Dec 20, 2024
2bdbe81
compute branch var
chriseth Dec 20, 2024
140e7d4
Merge branch 'bisect' into single_step_with_branching
chriseth Dec 20, 2024
79dd351
almost branch.
chriseth Dec 20, 2024
e5b9846
Merge remote-tracking branch 'origin/main' into single_step_with_bran…
chriseth Dec 20, 2024
e41976a
Add "can process"
chriseth Dec 20, 2024
b346f62
Fix test.
chriseth Dec 20, 2024
e271ef7
Fix test again.
chriseth Dec 20, 2024
0f8ac79
fix test.
chriseth Dec 20, 2024
b58afc1
swap branches
chriseth Dec 20, 2024
a40a133
Merge remote-tracking branch 'origin/main' into single_step_simple
chriseth Jan 2, 2025
eb245ac
Apply suggestions from code review
chriseth Jan 2, 2025
1530cae
Merge branch 'single_step_simple' of ssh://github.com/powdr-labs/powd…
chriseth Jan 2, 2025
1588eb2
adjust tests.
chriseth Jan 2, 2025
cc70c19
Remove test function.
chriseth Jan 2, 2025
b13df2d
Merge branch 'remove_test_function' into single_step_simple
chriseth Jan 2, 2025
96d5506
Fix testing functions.
chriseth Jan 2, 2025
b6a2ad9
Merge remote-tracking branch 'origin/main' into single_step_with_bran…
chriseth Jan 2, 2025
b799c35
Merge branch 'single_step_simple' into single_step_with_branching
chriseth Jan 2, 2025
819d794
Fix merge
chriseth Jan 2, 2025
9604c79
Fix test.
chriseth Jan 2, 2025
250253c
Fix written vars.
chriseth Jan 2, 2025
93373a3
Fix test expectation.
chriseth Jan 2, 2025
ce159f6
Code generation for branches.
chriseth Jan 2, 2025
b45dd7f
clippy
chriseth Jan 2, 2025
4fc4a2e
Merge branch 'single_step_simple'
chriseth Jan 3, 2025
3517df3
Merge commit '50fd4580b' into single_step_with_branching
chriseth Jan 3, 2025
8fab6ae
Merge branch 'main' into single_step_with_branching
chriseth Jan 3, 2025
cf9ea9f
Fix formatting and tests.
chriseth Jan 3, 2025
352d63f
Update executor/src/witgen/jit/single_step_processor.rs
chriseth Jan 6, 2025
232b0d8
Extend comment
chriseth Jan 6, 2025
671546f
Merge branch 'single_step_with_branching' of ssh://github.com/powdr-l…
chriseth Jan 6, 2025
7d69841
Return both copies in on_branch.
chriseth Jan 6, 2025
df85d7f
Improve error message.
chriseth Jan 6, 2025
0242bfe
Apply suggestions from code review
chriseth Jan 6, 2025
0ab316e
Update executor/src/witgen/jit/compiler.rs
chriseth Jan 6, 2025
747ff3b
Test codegen.
chriseth Jan 6, 2025
c4cbc05
Merge branch 'single_step_with_branching' of ssh://github.com/powdr-l…
chriseth Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ extern "C" fn witgen(
}

/// Returns an iterator over all variables written to in the effect.
/// The flag indicates if the variable is used in a lookup and thus needs
/// The flag indicates if the variable is the return value of a machine call and thus needs
/// to be declared mutable.
fn written_vars_in_effect<T: FieldElement>(
effect: &Effect<T, Variable>,
Expand Down Expand Up @@ -337,7 +337,7 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
"".to_string()
};
format!(
"{var_decls}if {} {{\n{}}} else {{\n{}}}",
"{var_decls}if {} {{\n{}\n}} else {{\n{}\n}}",
format_condition(condition),
format_effects_inner(first, false),
format_effects_inner(second, false)
Expand Down Expand Up @@ -893,4 +893,26 @@ extern \"C\" fn witgen(
(f.function)(params);
assert_eq!(y_val, GoldilocksField::from(4));
}

#[test]
fn branches_codegen() {
let x = param(0);
let y = param(1);
let branch_effect = Effect::Branch(
BranchCondition {
variable: x.clone(),
first_branch: RangeConstraint::from_range(7.into(), 20.into()),
second_branch: RangeConstraint::from_range(21.into(), 6.into()),
},
vec![assignment(&y, symbol(&x) + number(1))],
vec![assignment(&y, symbol(&x) + number(2))],
);
let expectation = " let p_1;
if 7 <= IntType::from(p_0) && IntType::from(p_0) <= 20 {
p_1 = (p_0 + FieldElement::from(1));
} else {
p_1 = (p_0 + FieldElement::from(2));
}";
assert_eq!(format_effects(&[branch_effect]), expectation);
}
}
55 changes: 40 additions & 15 deletions executor/src/witgen/jit/single_step_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::witgen::{machines::MachineParts, FixedData};
use super::{
effect::Effect,
variable::{Cell, Variable},
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, WitgenInference},
};

/// A processor for generating JIT code that computes the next row from the previous row.
Expand Down Expand Up @@ -45,43 +45,67 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
// Check that we could derive all witness values in the next row.
let unknown_witnesses = self
.unknown_witness_cols_on_next_row(&witgen)
// Sort to get deterministic code.
.sorted()
.collect_vec();

let missing_identities = self.machine_parts.identities.len() - complete.len();
let code = if unknown_witnesses.is_empty() && missing_identities == 0 {
witgen.code()
} else {
let Some((most_constrained_var, _)) = witgen
let Some(most_constrained_var) = witgen
.known_variables()
.iter()
.filter_map(|var| witgen.range_constraint(var).map(|rc| (var, rc)))
.filter(|(_, rc)| rc.try_to_single_value().is_none())
.sorted()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this has the effect that when in two range widths are equal we pick the column with the smaller ID. Is this so that the generated code is deterministic? I think this deserves a comment then :)

.min_by_key(|(_, rc)| rc.range_width())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have some limit on the width? In practice, this will bisect until it's a concrete value, right? Which would take forever for very large ranges?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say we'll tune it later and see where we get with this.

.map(|(var, _)| var.clone())
else {
let incomplete_identities = self
.machine_parts
.identities
.iter()
.filter(|id| !complete.contains(&id.id()));
let column_errors = if unknown_witnesses.is_empty() {
"".to_string()
} else {
format!(
"\nThe following columns are still missing: {}",
unknown_witnesses
.iter()
.map(|wit| self.fixed_data.column_name(wit))
.format(", ")
)
};
let identity_errors = if missing_identities == 0 {
"".to_string()
} else {
format!(
"\nThe following identities have not been fully processed:\n{}",
incomplete_identities
.map(|id| format!(" {id}"))
.join("\n")
)
};
return Err(format!(
"Unable to derive algorithm to compute values for witness columns in the next row and\n\
unable to branch on a variable. The following columns are still missing:\n{}\nThe following identities have not been fully processed:\n{}",
unknown_witnesses.iter().map(|wit| self.fixed_data.column_name(wit)).format(", "),
incomplete_identities.map(|id| format!(" {id}")).join("\n")
unable to branch on a variable.{column_errors}{identity_errors}",
));
};

let (common_code, condition, other_branch) =
witgen.branch_on(&most_constrained_var.clone());
let BranchResult {
common_code,
condition,
branches: [first_branch, second_branch],
} = witgen.branch_on(&most_constrained_var.clone());

// TODO Tuning: If this fails (or also if it does not generate progress right away),
// we could also choose a different variable to branch on.
let left_branch_code =
self.generate_code_for_branch(can_process.clone(), witgen, complete.clone())?;
self.generate_code_for_branch(can_process.clone(), first_branch, complete.clone())?;
let right_branch_code =
self.generate_code_for_branch(can_process, other_branch, complete)?;
self.generate_code_for_branch(can_process, second_branch, complete)?;
if left_branch_code == right_branch_code {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which cases would this happen?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It happens when the chosen variable is not helpful. This is the case already in the test here.

common_code.into_iter().chain(left_branch_code).collect()
} else {
Expand All @@ -99,6 +123,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
}

fn initialize_witgen(&self) -> WitgenInference<'a, T, NoEval> {
// All witness columns in row 0 are known.
let known_variables = self.machine_parts.witnesses.iter().map(|id| {
Variable::Cell(Cell {
column_name: self.fixed_data.column_name(id).to_string(),
Expand All @@ -119,7 +144,9 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
while progress {
progress = false;

// TODO propagate known.
// TODO At this point, we should call a function on `witgen`
// to propagate known concrete values across the identities
// to other known (but not concrete) variables.

for id in &self.machine_parts.identities {
if complete.contains(&id.id()) {
Expand Down Expand Up @@ -229,9 +256,7 @@ mod test {
assert_eq!(
err.to_string(),
"Unable to derive algorithm to compute values for witness columns in the next row and\n\
unable to branch on a variable. The following columns are still missing:\n\
M::Y\n\
The following identities have not been fully processed:\n"
unable to branch on a variable.\nThe following columns are still missing: M::Y"
);
}

Expand All @@ -246,13 +271,13 @@ mod test {
let pc: col;

col fixed LINE = [0, 1] + [2]*;
col fixed INSTR_ADD = [0, 1] + [0]*;
col fixed INSTR_ADD = [0, 1] + [0]*;
col fixed INSTR_MUL = [1, 0] + [1]*;

pc' = pc + 1;
[ pc, instr_add, instr_mul ] in [ LINE, INSTR_ADD, INSTR_MUL ];

instr_add * (A' - (A + B)) + instr_mul * (A' - A * B) + (1 - instr_add - instr_mul) * (A' - A) = 0;
instr_add * (A' - (A + B)) + instr_mul * (A' - A * B) + (1 - instr_add - instr_mul) * (A' - A) = 0;
B' = B;
";
let code = generate_single_step(input, "Main").unwrap();
Expand Down
33 changes: 22 additions & 11 deletions executor/src/witgen/jit/witgen_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct ProcessSummary {
/// This component can generate code that solves identities.
/// It needs a driver that tells it which identities to process on which rows.
#[derive(Clone)]
pub struct WitgenInference<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> {
pub struct WitgenInference<'a, T: FieldElement, FixedEval> {
fixed_data: &'a FixedData<'a, T>,
fixed_evaluator: FixedEval,
derived_range_constraints: HashMap<Variable, RangeConstraint<T>>,
Expand All @@ -62,6 +62,16 @@ impl<T: Display> Display for Value<T> {
}
}

/// Return type of the `branch_on` method.
pub struct BranchResult<'a, T: FieldElement, FixedEval> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

/// The code common to both branches.
pub common_code: Vec<Effect<T, Variable>>,
/// The condition of the branch.
pub condition: BranchCondition<T, Variable>,
/// The two branches.
pub branches: [WitgenInference<'a, T, FixedEval>; 2],
}

impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, FixedEval> {
pub fn new(
fixed_data: &'a FixedData<'a, T>,
Expand Down Expand Up @@ -101,10 +111,11 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
}
}

pub fn branch_on(
&mut self,
variable: &Variable,
) -> (Vec<Effect<T, Variable>>, BranchCondition<T, Variable>, Self) {
/// Splits the current inference into two copies - one where the provided variable
/// is in the "second half" of its range constraint and one where it is in the
/// "first half" of its range constraint (determined by calling the `bisect` method).
/// Returns the common code, the branch condition and the two branches.
pub fn branch_on(mut self, variable: &Variable) -> BranchResult<'a, T, FixedEval> {
// The variable needs to be known, we need to have a range constraint but
// it cannot be a single value.
assert!(self.known_variables.contains(variable));
Expand All @@ -118,21 +129,21 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F

let (low_condition, high_condition) = rc.bisect();

let code = std::mem::take(&mut self.code);
let common_code = std::mem::take(&mut self.code);
let mut low_branch = self.clone();

self.add_range_constraint(variable.clone(), high_condition.clone());
low_branch.add_range_constraint(variable.clone(), low_condition.clone());

(
code,
BranchCondition {
BranchResult {
common_code,
condition: BranchCondition {
variable: variable.clone(),
first_branch: high_condition,
second_branch: low_condition,
},
low_branch,
)
branches: [self, low_branch],
}
}

/// Process an identity on a certain row.
Expand Down
Loading