-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Single step with branching #2274
Changes from 9 commits
3710c19
200cfa9
00f53da
cce4375
c54c0dd
71666c0
96b3660
405bdc1
2bdbe81
140e7d4
79dd351
e5b9846
e41976a
b346f62
e271ef7
0f8ac79
b58afc1
a40a133
eb245ac
1530cae
1588eb2
cc70c19
b13df2d
96d5506
b6a2ad9
b799c35
819d794
9604c79
250253c
93373a3
ce159f6
b45dd7f
4fc4a2e
3517df3
8fab6ae
cf9ea9f
352d63f
232b0d8
671546f
7d69841
df85d7f
0242bfe
0ab316e
747ff3b
c4cbc05
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ use crate::witgen::{machines::MachineParts, FixedData}; | |
use super::{ | ||
effect::Effect, | ||
variable::{Cell, Variable}, | ||
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference}, | ||
witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, WitgenInference}, | ||
}; | ||
|
||
/// A processor for generating JIT code that computes the next row from the previous row. | ||
|
@@ -45,43 +45,67 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> { | |
// Check that we could derive all witness values in the next row. | ||
let unknown_witnesses = self | ||
.unknown_witness_cols_on_next_row(&witgen) | ||
// Sort to get deterministic code. | ||
.sorted() | ||
.collect_vec(); | ||
|
||
let missing_identities = self.machine_parts.identities.len() - complete.len(); | ||
let code = if unknown_witnesses.is_empty() && missing_identities == 0 { | ||
witgen.code() | ||
} else { | ||
let Some((most_constrained_var, _)) = witgen | ||
let Some(most_constrained_var) = witgen | ||
.known_variables() | ||
.iter() | ||
.filter_map(|var| witgen.range_constraint(var).map(|rc| (var, rc))) | ||
.filter(|(_, rc)| rc.try_to_single_value().is_none()) | ||
.sorted() | ||
.min_by_key(|(_, rc)| rc.range_width()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we have some limit on the width? In practice, this will bisect until it's a concrete value, right? Which would take forever for very large ranges? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would say we'll tune it later and see where we get with this. |
||
.map(|(var, _)| var.clone()) | ||
else { | ||
let incomplete_identities = self | ||
.machine_parts | ||
.identities | ||
.iter() | ||
.filter(|id| !complete.contains(&id.id())); | ||
let column_errors = if unknown_witnesses.is_empty() { | ||
"".to_string() | ||
} else { | ||
format!( | ||
"\nThe following columns are still missing: {}", | ||
unknown_witnesses | ||
.iter() | ||
.map(|wit| self.fixed_data.column_name(wit)) | ||
.format(", ") | ||
) | ||
}; | ||
let identity_errors = if missing_identities == 0 { | ||
"".to_string() | ||
} else { | ||
format!( | ||
"\nThe following identities have not been fully processed:\n{}", | ||
incomplete_identities | ||
.map(|id| format!(" {id}")) | ||
.join("\n") | ||
) | ||
}; | ||
return Err(format!( | ||
"Unable to derive algorithm to compute values for witness columns in the next row and\n\ | ||
unable to branch on a variable. The following columns are still missing:\n{}\nThe following identities have not been fully processed:\n{}", | ||
unknown_witnesses.iter().map(|wit| self.fixed_data.column_name(wit)).format(", "), | ||
incomplete_identities.map(|id| format!(" {id}")).join("\n") | ||
unable to branch on a variable.{column_errors}{identity_errors}", | ||
)); | ||
}; | ||
|
||
let (common_code, condition, other_branch) = | ||
witgen.branch_on(&most_constrained_var.clone()); | ||
let BranchResult { | ||
common_code, | ||
condition, | ||
branches: [first_branch, second_branch], | ||
} = witgen.branch_on(&most_constrained_var.clone()); | ||
|
||
// TODO Tuning: If this fails (or also if it does not generate progress right away), | ||
// we could also choose a different variable to branch on. | ||
let left_branch_code = | ||
self.generate_code_for_branch(can_process.clone(), witgen, complete.clone())?; | ||
self.generate_code_for_branch(can_process.clone(), first_branch, complete.clone())?; | ||
let right_branch_code = | ||
self.generate_code_for_branch(can_process, other_branch, complete)?; | ||
self.generate_code_for_branch(can_process, second_branch, complete)?; | ||
if left_branch_code == right_branch_code { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In which cases would this happen? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It happens when the chosen variable is not helpful. This is the case already in the test here. |
||
common_code.into_iter().chain(left_branch_code).collect() | ||
} else { | ||
|
@@ -99,6 +123,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> { | |
} | ||
|
||
fn initialize_witgen(&self) -> WitgenInference<'a, T, NoEval> { | ||
chriseth marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// All witness columns in row 0 are known. | ||
let known_variables = self.machine_parts.witnesses.iter().map(|id| { | ||
Variable::Cell(Cell { | ||
column_name: self.fixed_data.column_name(id).to_string(), | ||
|
@@ -119,7 +144,9 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> { | |
while progress { | ||
progress = false; | ||
|
||
// TODO propagate known. | ||
// TODO At this point, we should call a function on `witgen` | ||
// to propagate known concrete values across the identities | ||
// to other known (but not concrete) variables. | ||
|
||
for id in &self.machine_parts.identities { | ||
if complete.contains(&id.id()) { | ||
|
@@ -229,9 +256,7 @@ mod test { | |
assert_eq!( | ||
err.to_string(), | ||
"Unable to derive algorithm to compute values for witness columns in the next row and\n\ | ||
unable to branch on a variable. The following columns are still missing:\n\ | ||
M::Y\n\ | ||
The following identities have not been fully processed:\n" | ||
unable to branch on a variable.\nThe following columns are still missing: M::Y" | ||
); | ||
} | ||
|
||
|
@@ -246,13 +271,13 @@ mod test { | |
let pc: col; | ||
|
||
col fixed LINE = [0, 1] + [2]*; | ||
col fixed INSTR_ADD = [0, 1] + [0]*; | ||
col fixed INSTR_ADD = [0, 1] + [0]*; | ||
col fixed INSTR_MUL = [1, 0] + [1]*; | ||
|
||
pc' = pc + 1; | ||
[ pc, instr_add, instr_mul ] in [ LINE, INSTR_ADD, INSTR_MUL ]; | ||
|
||
instr_add * (A' - (A + B)) + instr_mul * (A' - A * B) + (1 - instr_add - instr_mul) * (A' - A) = 0; | ||
instr_add * (A' - (A + B)) + instr_mul * (A' - A * B) + (1 - instr_add - instr_mul) * (A' - A) = 0; | ||
B' = B; | ||
"; | ||
let code = generate_single_step(input, "Main").unwrap(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,7 +35,7 @@ pub struct ProcessSummary { | |
/// This component can generate code that solves identities. | ||
/// It needs a driver that tells it which identities to process on which rows. | ||
#[derive(Clone)] | ||
pub struct WitgenInference<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> { | ||
pub struct WitgenInference<'a, T: FieldElement, FixedEval> { | ||
fixed_data: &'a FixedData<'a, T>, | ||
fixed_evaluator: FixedEval, | ||
derived_range_constraints: HashMap<Variable, RangeConstraint<T>>, | ||
|
@@ -62,6 +62,16 @@ impl<T: Display> Display for Value<T> { | |
} | ||
} | ||
|
||
/// Return type of the `branch_on` method. | ||
pub struct BranchResult<'a, T: FieldElement, FixedEval> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❤️ |
||
/// The code common to both branches. | ||
pub common_code: Vec<Effect<T, Variable>>, | ||
/// The condition of the branch. | ||
pub condition: BranchCondition<T, Variable>, | ||
/// The two branches. | ||
pub branches: [WitgenInference<'a, T, FixedEval>; 2], | ||
} | ||
|
||
impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, FixedEval> { | ||
pub fn new( | ||
fixed_data: &'a FixedData<'a, T>, | ||
|
@@ -101,10 +111,11 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F | |
} | ||
} | ||
|
||
pub fn branch_on( | ||
&mut self, | ||
variable: &Variable, | ||
) -> (Vec<Effect<T, Variable>>, BranchCondition<T, Variable>, Self) { | ||
/// Splits the current inference into two copies - one where the provided variable | ||
/// is in the "second half" of its range constraint and one where it is in the | ||
/// "first half" of its range constraint (determined by calling the `bisect` method). | ||
/// Returns the common code, the branch condition and the two branches. | ||
pub fn branch_on(mut self, variable: &Variable) -> BranchResult<'a, T, FixedEval> { | ||
// The variable needs to be known, we need to have a range constraint but | ||
// it cannot be a single value. | ||
assert!(self.known_variables.contains(variable)); | ||
|
@@ -118,21 +129,21 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F | |
|
||
let (low_condition, high_condition) = rc.bisect(); | ||
|
||
let code = std::mem::take(&mut self.code); | ||
let common_code = std::mem::take(&mut self.code); | ||
let mut low_branch = self.clone(); | ||
|
||
self.add_range_constraint(variable.clone(), high_condition.clone()); | ||
low_branch.add_range_constraint(variable.clone(), low_condition.clone()); | ||
|
||
( | ||
code, | ||
BranchCondition { | ||
BranchResult { | ||
common_code, | ||
condition: BranchCondition { | ||
variable: variable.clone(), | ||
first_branch: high_condition, | ||
second_branch: low_condition, | ||
}, | ||
low_branch, | ||
) | ||
branches: [self, low_branch], | ||
} | ||
} | ||
|
||
/// Process an identity on a certain row. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this has the effect that when in two range widths are equal we pick the column with the smaller ID. Is this so that the generated code is deterministic? I think this deserves a comment then :)