@@ -420,15 +420,35 @@ def all_steps(initial_state, keys_sampling, keys_adaptation):
420
420
keys_adaptation ,
421
421
) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, )
422
422
423
- # ((a, Int) -> (a, Int))
423
+ EEVPD = jnp .zeros ((num_steps ,))
424
+ EEVPD_wanted = jnp .zeros ((num_steps ,))
425
+ L = jnp .zeros ((num_steps ,))
426
+ entropy = jnp .zeros ((num_steps ,))
427
+ equi_diag = jnp .zeros ((num_steps ,))
428
+ equi_full = jnp .zeros ((num_steps ,))
429
+ observables = jnp .zeros ((num_steps ,))
430
+ r_avg = jnp .zeros ((num_steps ,))
431
+ r_max = jnp .zeros ((num_steps ,))
432
+ step_size = jnp .zeros ((num_steps ,))
433
+
424
434
def step_while (a ):
425
435
x , i , _ = a
426
436
427
437
auxilliary_input = (xs [0 ][i ], xs [1 ][i ], xs [2 ][i ])
428
438
429
- output , info = step (x , auxilliary_input )
439
+ output , (info , pos ) = step (x , auxilliary_input )
440
+ EEVPD .at [i ].set (info .get ("EEVPD" ))
441
+ EEVPD_wanted .at [i ].set (info .get ("EEVPD_wanted" ))
442
+ L .at [i ].set (info .get ("L" ))
443
+ entropy .at [i ].set (info .get ("entropy" ))
444
+ equi_diag .at [i ].set (info .get ("equi_diag" ))
445
+ equi_full .at [i ].set (info .get ("equi_full" ))
446
+ observables .at [i ].set (info .get ("observables" ))
447
+ r_avg .at [i ].set (info .get ("r_avg" ))
448
+ r_max .at [i ].set (info .get ("r_max" ))
449
+ step_size .at [i ].set (info .get ("step_size" ))
430
450
431
- return (output , i + 1 , info [ 0 ] .get ("while_cond" ))
451
+ return (output , i + 1 , info .get ("while_cond" ))
432
452
433
453
if early_stop :
434
454
final_state_all , i , _ = lax .while_loop (
@@ -437,7 +457,19 @@ def step_while(a):
437
457
(initial_state_all , 0 , True ),
438
458
)
439
459
steps_done = i
440
- info_history = None
460
+ info_history = {
461
+ "EEVPD" : EEVPD ,
462
+ "EEVPD_wanted" : EEVPD_wanted ,
463
+ "L" : L ,
464
+ "entropy" : entropy ,
465
+ "equi_diag" : equi_diag ,
466
+ "equi_full" : equi_full ,
467
+ "observables" : observables ,
468
+ "r_avg" : r_avg ,
469
+ "r_max" : r_max ,
470
+ "step_size" : step_size ,
471
+ "steps_done" : steps_done ,
472
+ }
441
473
442
474
else :
443
475
final_state_all , info_history = lax .scan (step , initial_state_all , xs )
0 commit comments