@@ -432,29 +432,29 @@ def all_steps(initial_state, keys_sampling, keys_adaptation):
432
432
step_size = jnp .zeros ((num_steps ,))
433
433
434
434
def step_while (a ):
435
- x , i , _ = a
435
+ x , i , _ , EEVPD , EEVPD_wanted , L , entropy , equi_diag , equi_full , observables , r_avg , r_max , step_size = a
436
436
437
437
auxilliary_input = (xs [0 ][i ], xs [1 ][i ], xs [2 ][i ])
438
438
439
439
output , (info , pos ) = step (x , auxilliary_input )
440
- EEVPD .at [i ].set (info .get ("EEVPD" ))
441
- EEVPD_wanted .at [i ].set (info .get ("EEVPD_wanted" ))
442
- L .at [i ].set (info .get ("L" ))
443
- entropy .at [i ].set (info .get ("entropy" ))
444
- equi_diag .at [i ].set (info .get ("equi_diag" ))
445
- equi_full .at [i ].set (info .get ("equi_full" ))
446
- observables .at [i ].set (info .get ("observables" ))
447
- r_avg .at [i ].set (info .get ("r_avg" ))
448
- r_max .at [i ].set (info .get ("r_max" ))
449
- step_size .at [i ].set (info .get ("step_size" ))
450
-
451
- return (output , i + 1 , info .get ("while_cond" ))
440
+ new_EEVPD = EEVPD .at [i ].set (info .get ("EEVPD" ))
441
+ new_EEVPD_wanted = EEVPD_wanted .at [i ].set (info .get ("EEVPD_wanted" ))
442
+ new_L = L .at [i ].set (info .get ("L" ))
443
+ new_entropy = entropy .at [i ].set (info .get ("entropy" ))
444
+ new_equi_diag = equi_diag .at [i ].set (info .get ("equi_diag" ))
445
+ new_equi_full = equi_full .at [i ].set (info .get ("equi_full" ))
446
+ new_observables = observables .at [i ].set (info .get ("observables" ))
447
+ new_r_avg = r_avg .at [i ].set (info .get ("r_avg" ))
448
+ new_r_max = r_max .at [i ].set (info .get ("r_max" ))
449
+ new_step_size = step_size .at [i ].set (info .get ("step_size" ))
450
+
451
+ return (output , i + 1 , info .get ("while_cond" ), new_EEVPD , new_EEVPD_wanted , new_L , new_entropy , new_equi_diag , new_equi_full , new_observables , new_r_avg , new_r_max , new_step_size )
452
452
453
453
if early_stop :
454
- final_state_all , i , _ = lax .while_loop (
454
+ final_state_all , i , _ , EEVPD , EEVPD_wanted , L , entropy , equi_diag , equi_full , observables , r_avg , r_max , step_size = lax .while_loop (
455
455
lambda a : ((a [1 ] < num_steps ) & a [2 ]),
456
456
step_while ,
457
- (initial_state_all , 0 , True ),
457
+ (initial_state_all , 0 , True , EEVPD , EEVPD_wanted , L , entropy , equi_diag , equi_full , observables , r_avg , r_max , step_size ),
458
458
)
459
459
steps_done = i
460
460
info_history = {
0 commit comments