1
- use std:: { ffi:: c_void, iter, mem, sync:: Arc } ;
1
+ use std:: { cmp :: Ordering , ffi:: c_void, iter, mem, sync:: Arc } ;
2
2
3
- use auto_enums:: auto_enum;
4
3
use itertools:: Itertools ;
5
4
use libloading:: Library ;
5
+ use powdr_ast:: indent;
6
6
use powdr_number:: { FieldElement , KnownField } ;
7
7
8
8
use crate :: witgen:: {
@@ -16,7 +16,7 @@ use crate::witgen::{
16
16
} ;
17
17
18
18
use super :: {
19
- effect:: { Assertion , Effect } ,
19
+ effect:: { Assertion , BranchCondition , Effect } ,
20
20
symbolic_expression:: { BinaryOperator , BitOperator , SymbolicExpression , UnaryOperator } ,
21
21
variable:: Variable ,
22
22
} ;
@@ -157,10 +157,11 @@ fn witgen_code<T: FieldElement>(
157
157
format ! ( " let {var_name} = {value};" )
158
158
} )
159
159
. format ( "\n " ) ;
160
- let main_code = effects . iter ( ) . map ( format_effect ) . format ( " \n " ) ;
160
+ let main_code = format_effects ( effects ) ;
161
161
let vars_known = effects
162
162
. iter ( )
163
163
. flat_map ( written_vars_in_effect)
164
+ . map ( |( var, _) | var)
164
165
. collect_vec ( ) ;
165
166
let store_values = vars_known
166
167
. iter ( )
@@ -224,26 +225,51 @@ extern "C" fn witgen(
224
225
}
225
226
226
227
/// Returns an iterator over all variables written to in the effect.
227
- #[ auto_enum( Iterator ) ]
228
+ /// The flag indicates if the variable is the return value of a machine call and thus needs
229
+ /// to be declared mutable.
228
230
fn written_vars_in_effect < T : FieldElement > (
229
231
effect : & Effect < T , Variable > ,
230
- ) -> impl Iterator < Item = & Variable > + ' _ {
232
+ ) -> Box < dyn Iterator < Item = ( & Variable , bool ) > + ' _ > {
231
233
match effect {
232
- Effect :: Assignment ( var, _) => iter:: once ( var) ,
234
+ Effect :: Assignment ( var, _) => Box :: new ( iter:: once ( ( var, false ) ) ) ,
233
235
Effect :: RangeConstraint ( ..) => unreachable ! ( ) ,
234
- Effect :: Assertion ( ..) => iter:: empty ( ) ,
235
- Effect :: MachineCall ( _, arguments) => arguments. iter ( ) . flat_map ( |e| match e {
236
- MachineCallArgument :: Unknown ( v) => Some ( v ) ,
236
+ Effect :: Assertion ( ..) => Box :: new ( iter:: empty ( ) ) ,
237
+ Effect :: MachineCall ( _, arguments) => Box :: new ( arguments. iter ( ) . flat_map ( |e| match e {
238
+ MachineCallArgument :: Unknown ( v) => Some ( ( v , true ) ) ,
237
239
MachineCallArgument :: Known ( _) => None ,
238
- } ) ,
240
+ } ) ) ,
241
+ Effect :: Branch ( _, first, second) => Box :: new (
242
+ first
243
+ . iter ( )
244
+ . chain ( second)
245
+ . flat_map ( |e| written_vars_in_effect ( e) ) ,
246
+ ) ,
239
247
}
240
248
}
241
249
242
- fn format_effect < T : FieldElement > ( effect : & Effect < T , Variable > ) -> String {
250
+ pub fn format_effects < T : FieldElement > ( effects : & [ Effect < T , Variable > ] ) -> String {
251
+ format_effects_inner ( effects, true )
252
+ }
253
+
254
+ fn format_effects_inner < T : FieldElement > (
255
+ effects : & [ Effect < T , Variable > ] ,
256
+ is_top_level : bool ,
257
+ ) -> String {
258
+ indent (
259
+ effects
260
+ . iter ( )
261
+ . map ( |effect| format_effect ( effect, is_top_level) )
262
+ . join ( "\n " ) ,
263
+ 1 ,
264
+ )
265
+ }
266
+
267
+ fn format_effect < T : FieldElement > ( effect : & Effect < T , Variable > , is_top_level : bool ) -> String {
243
268
match effect {
244
269
Effect :: Assignment ( var, e) => {
245
270
format ! (
246
- " let {} = {};" ,
271
+ "{}{} = {};" ,
272
+ if is_top_level { "let " } else { "" } ,
247
273
variable_to_string( var) ,
248
274
format_expression( e)
249
275
)
@@ -256,7 +282,7 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
256
282
rhs,
257
283
expected_equal,
258
284
} ) => format ! (
259
- " assert!({} {} {});" ,
285
+ "assert!({} {} {});" ,
260
286
format_expression( lhs) ,
261
287
if * expected_equal { "==" } else { "!=" } ,
262
288
format_expression( rhs)
@@ -268,7 +294,9 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
268
294
. map ( |a| match a {
269
295
MachineCallArgument :: Unknown ( v) => {
270
296
let var_name = variable_to_string ( v) ;
271
- result_vars. push ( var_name. clone ( ) ) ;
297
+ if is_top_level {
298
+ result_vars. push ( var_name. clone ( ) ) ;
299
+ }
272
300
format ! ( "LookupCell::Output(&mut {var_name})" )
273
301
}
274
302
MachineCallArgument :: Known ( v) => {
@@ -279,11 +307,40 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
279
307
. to_string ( ) ;
280
308
let var_decls = result_vars
281
309
. iter ( )
282
- . map ( |var_name| format ! ( " let mut {var_name} = FieldElement::default();" ) )
283
- . format ( "\n " ) ;
310
+ . map ( |var_name| format ! ( "let mut {var_name} = FieldElement::default();\n " ) )
311
+ . format ( "" ) ;
312
+ format ! (
313
+ "{var_decls}assert!(call_machine(mutable_state, {id}, MutSlice::from((&mut [{args}]).as_mut_slice())));"
314
+ )
315
+ }
316
+ Effect :: Branch ( condition, first, second) => {
317
+ let var_decls = if is_top_level {
318
+ // We need to declare all assigned variables at top level,
319
+ // so that they are available after the branches.
320
+ first
321
+ . iter ( )
322
+ . chain ( second)
323
+ . flat_map ( |e| written_vars_in_effect ( e) )
324
+ . sorted ( )
325
+ . dedup ( )
326
+ . map ( |( v, needs_mut) | {
327
+ let v = variable_to_string ( v) ;
328
+ if needs_mut {
329
+ format ! ( "let mut {v} = FieldElement::default();\n " )
330
+ } else {
331
+ format ! ( "let {v};\n " )
332
+ }
333
+ } )
334
+ . format ( "" )
335
+ . to_string ( )
336
+ } else {
337
+ "" . to_string ( )
338
+ } ;
284
339
format ! (
285
- "{var_decls}
286
- assert!(call_machine(mutable_state, {id}, MutSlice::from((&mut [{args}]).as_mut_slice())));"
340
+ "{var_decls}if {} {{\n {}\n }} else {{\n {}\n }}" ,
341
+ format_condition( condition) ,
342
+ format_effects_inner( first, false ) ,
343
+ format_effects_inner( second, false )
287
344
)
288
345
}
289
346
}
@@ -319,6 +376,16 @@ fn format_expression<T: FieldElement>(e: &SymbolicExpression<T, Variable>) -> St
319
376
}
320
377
}
321
378
379
+ fn format_condition < T : FieldElement > ( condition : & BranchCondition < T , Variable > ) -> String {
380
+ let var = format ! ( "IntType::from({})" , variable_to_string( & condition. variable) ) ;
381
+ let ( min, max) = condition. first_branch . range ( ) ;
382
+ match min. cmp ( & max) {
383
+ Ordering :: Equal => format ! ( "{var} == {min}" , ) ,
384
+ Ordering :: Less => format ! ( "{min} <= {var} && {var} <= {max}" ) ,
385
+ Ordering :: Greater => format ! ( "{var} <= {min} || {var} >= {max}" ) ,
386
+ }
387
+ }
388
+
322
389
/// Returns the name of a local (stack) variable for the given expression variable.
323
390
fn variable_to_string ( v : & Variable ) -> String {
324
391
match v {
@@ -423,6 +490,7 @@ mod tests {
423
490
424
491
use crate :: witgen:: jit:: variable:: Cell ;
425
492
use crate :: witgen:: jit:: variable:: MachineCallReturnVariable ;
493
+ use crate :: witgen:: range_constraints:: RangeConstraint ;
426
494
427
495
use super :: * ;
428
496
@@ -780,4 +848,71 @@ extern \"C\" fn witgen(
780
848
assert_eq ! ( data[ 1 ] , GoldilocksField :: from( 18 ) ) ;
781
849
assert_eq ! ( data[ 2 ] , GoldilocksField :: from( 0 ) ) ;
782
850
}
851
+
852
+ #[ test]
853
+ fn branches ( ) {
854
+ let x = param ( 0 ) ;
855
+ let y = param ( 1 ) ;
856
+ let mut x_val: GoldilocksField = 7 . into ( ) ;
857
+ let mut y_val: GoldilocksField = 9 . into ( ) ;
858
+ let effects = vec ! [ Effect :: Branch (
859
+ BranchCondition {
860
+ variable: x. clone( ) ,
861
+ first_branch: RangeConstraint :: from_range( 7 . into( ) , 20 . into( ) ) ,
862
+ second_branch: RangeConstraint :: from_range( 21 . into( ) , 6 . into( ) ) ,
863
+ } ,
864
+ vec![ assignment( & y, symbol( & x) + number( 1 ) ) ] ,
865
+ vec![ assignment( & y, symbol( & x) + number( 2 ) ) ] ,
866
+ ) ] ;
867
+ let f = compile_effects ( 0 , 1 , & [ x] , & effects) . unwrap ( ) ;
868
+ let mut data = vec ! [ ] ;
869
+ let mut known = vec ! [ ] ;
870
+
871
+ let mut params = vec ! [ LookupCell :: Input ( & x_val) , LookupCell :: Output ( & mut y_val) ] ;
872
+ let params = WitgenFunctionParams {
873
+ data : data. as_mut_slice ( ) . into ( ) ,
874
+ known : known. as_mut_ptr ( ) ,
875
+ row_offset : 0 ,
876
+ params : params. as_mut_slice ( ) . into ( ) ,
877
+ mutable_state : std:: ptr:: null ( ) ,
878
+ call_machine : no_call_machine,
879
+ } ;
880
+ ( f. function ) ( params) ;
881
+ assert_eq ! ( y_val, GoldilocksField :: from( 8 ) ) ;
882
+
883
+ x_val = 2 . into ( ) ;
884
+ let mut params = vec ! [ LookupCell :: Input ( & x_val) , LookupCell :: Output ( & mut y_val) ] ;
885
+ let params = WitgenFunctionParams {
886
+ data : data. as_mut_slice ( ) . into ( ) ,
887
+ known : known. as_mut_ptr ( ) ,
888
+ row_offset : 0 ,
889
+ params : params. as_mut_slice ( ) . into ( ) ,
890
+ mutable_state : std:: ptr:: null ( ) ,
891
+ call_machine : no_call_machine,
892
+ } ;
893
+ ( f. function ) ( params) ;
894
+ assert_eq ! ( y_val, GoldilocksField :: from( 4 ) ) ;
895
+ }
896
+
897
+ #[ test]
898
+ fn branches_codegen ( ) {
899
+ let x = param ( 0 ) ;
900
+ let y = param ( 1 ) ;
901
+ let branch_effect = Effect :: Branch (
902
+ BranchCondition {
903
+ variable : x. clone ( ) ,
904
+ first_branch : RangeConstraint :: from_range ( 7 . into ( ) , 20 . into ( ) ) ,
905
+ second_branch : RangeConstraint :: from_range ( 21 . into ( ) , 6 . into ( ) ) ,
906
+ } ,
907
+ vec ! [ assignment( & y, symbol( & x) + number( 1 ) ) ] ,
908
+ vec ! [ assignment( & y, symbol( & x) + number( 2 ) ) ] ,
909
+ ) ;
910
+ let expectation = " let p_1;
911
+ if 7 <= IntType::from(p_0) && IntType::from(p_0) <= 20 {
912
+ p_1 = (p_0 + FieldElement::from(1));
913
+ } else {
914
+ p_1 = (p_0 + FieldElement::from(2));
915
+ }" ;
916
+ assert_eq ! ( format_effects( & [ branch_effect] ) , expectation) ;
917
+ }
783
918
}
0 commit comments