Skip to content

Commit fb72f34

Browse files
author
Reuben Harry Cohn-Gordon
committed
Merge branch 'emaus' into working_branch
2 parents 6448735 + 3906c0e commit fb72f34

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

blackjax/adaptation/ensemble_mclmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def emaus(
306306
)
307307

308308
if diagnostics:
309-
info = {"phase_1": {"steps_done": steps_done_phase_1}, "phase_2": info2}
309+
info = {"phase_1": info1, "phase_2": info2}
310310
else:
311311
info = None
312312

blackjax/util.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,15 +420,35 @@ def all_steps(initial_state, keys_sampling, keys_adaptation):
420420
keys_adaptation,
421421
) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, )
422422

423-
# ((a, Int) -> (a, Int))
423+
EEVPD = jnp.zeros((num_steps,))
424+
EEVPD_wanted = jnp.zeros((num_steps,))
425+
L = jnp.zeros((num_steps,))
426+
entropy = jnp.zeros((num_steps,))
427+
equi_diag = jnp.zeros((num_steps,))
428+
equi_full = jnp.zeros((num_steps,))
429+
observables = jnp.zeros((num_steps,))
430+
r_avg = jnp.zeros((num_steps,))
431+
r_max = jnp.zeros((num_steps,))
432+
step_size = jnp.zeros((num_steps,))
433+
424434
def step_while(a):
425435
x, i, _ = a
426436

427437
auxilliary_input = (xs[0][i], xs[1][i], xs[2][i])
428438

429-
output, info = step(x, auxilliary_input)
439+
output, (info, pos) = step(x, auxilliary_input)
440+
EEVPD.at[i].set(info.get("EEVPD"))
441+
EEVPD_wanted.at[i].set(info.get("EEVPD_wanted"))
442+
L.at[i].set(info.get("L"))
443+
entropy.at[i].set(info.get("entropy"))
444+
equi_diag.at[i].set(info.get("equi_diag"))
445+
equi_full.at[i].set(info.get("equi_full"))
446+
observables.at[i].set(info.get("observables"))
447+
r_avg.at[i].set(info.get("r_avg"))
448+
r_max.at[i].set(info.get("r_max"))
449+
step_size.at[i].set(info.get("step_size"))
430450

431-
return (output, i + 1, info[0].get("while_cond"))
451+
return (output, i + 1, info.get("while_cond"))
432452

433453
if early_stop:
434454
final_state_all, i, _ = lax.while_loop(
@@ -437,7 +457,19 @@ def step_while(a):
437457
(initial_state_all, 0, True),
438458
)
439459
steps_done = i
440-
info_history = None
460+
info_history = {
461+
"EEVPD": EEVPD,
462+
"EEVPD_wanted": EEVPD_wanted,
463+
"L": L,
464+
"entropy": entropy,
465+
"equi_diag": equi_diag,
466+
"equi_full": equi_full,
467+
"observables": observables,
468+
"r_avg": r_avg,
469+
"r_max": r_max,
470+
"step_size": step_size,
471+
"steps_done": steps_done,
472+
}
441473

442474
else:
443475
final_state_all, info_history = lax.scan(step, initial_state_all, xs)

0 commit comments

Comments
 (0)