Skip to content

Commit a612d39

Browse files
author
Reuben Harry Cohn-Gordon
committed
fix emaus code
1 parent fb72f34 commit a612d39

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

blackjax/adaptation/ensemble_umclmc.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@ def nan_reject(nonans, old, new):
4444
def build_kernel(logdensity_fn):
4545
"""MCLMC kernel (with nan rejection)"""
4646

47-
kernel = mclmc.build_kernel(
48-
logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet
49-
)
47+
# kernel = mclmc.build_kernel(
48+
# logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet
49+
# )
5050

5151
def sequential_kernel(key, state, adap):
52-
new_state, info = kernel(key, state, adap.L, adap.step_size)
52+
new_state, info = mclmc.build_kernel(
53+
logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet, inverse_mass_matrix=adap.inverse_mass_matrix
54+
)(key, state, adap.L, adap.step_size)
5355

5456
# reject the new state if there were nans
5557
nonans = no_nans(new_state)

blackjax/util.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -432,29 +432,29 @@ def all_steps(initial_state, keys_sampling, keys_adaptation):
432432
step_size = jnp.zeros((num_steps,))
433433

434434
def step_while(a):
435-
x, i, _ = a
435+
x, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size = a
436436

437437
auxilliary_input = (xs[0][i], xs[1][i], xs[2][i])
438438

439439
output, (info, pos) = step(x, auxilliary_input)
440-
EEVPD.at[i].set(info.get("EEVPD"))
441-
EEVPD_wanted.at[i].set(info.get("EEVPD_wanted"))
442-
L.at[i].set(info.get("L"))
443-
entropy.at[i].set(info.get("entropy"))
444-
equi_diag.at[i].set(info.get("equi_diag"))
445-
equi_full.at[i].set(info.get("equi_full"))
446-
observables.at[i].set(info.get("observables"))
447-
r_avg.at[i].set(info.get("r_avg"))
448-
r_max.at[i].set(info.get("r_max"))
449-
step_size.at[i].set(info.get("step_size"))
450-
451-
return (output, i + 1, info.get("while_cond"))
440+
new_EEVPD = EEVPD.at[i].set(info.get("EEVPD"))
441+
new_EEVPD_wanted = EEVPD_wanted.at[i].set(info.get("EEVPD_wanted"))
442+
new_L = L.at[i].set(info.get("L"))
443+
new_entropy = entropy.at[i].set(info.get("entropy"))
444+
new_equi_diag = equi_diag.at[i].set(info.get("equi_diag"))
445+
new_equi_full = equi_full.at[i].set(info.get("equi_full"))
446+
new_observables = observables.at[i].set(info.get("observables"))
447+
new_r_avg = r_avg.at[i].set(info.get("r_avg"))
448+
new_r_max = r_max.at[i].set(info.get("r_max"))
449+
new_step_size = step_size.at[i].set(info.get("step_size"))
450+
451+
return (output, i + 1, info.get("while_cond"), new_EEVPD, new_EEVPD_wanted, new_L, new_entropy, new_equi_diag, new_equi_full, new_observables, new_r_avg, new_r_max, new_step_size)
452452

453453
if early_stop:
454-
final_state_all, i, _ = lax.while_loop(
454+
final_state_all, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size = lax.while_loop(
455455
lambda a: ((a[1] < num_steps) & a[2]),
456456
step_while,
457-
(initial_state_all, 0, True),
457+
(initial_state_all, 0, True, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size),
458458
)
459459
steps_done = i
460460
info_history = {

0 commit comments

Comments
 (0)