Skip to content

Commit e64b7f4

Browse files
committed
Merge branch 'emaus' of github.com:reubenharry/blackjax into emaus
2 parents f404898 + a5eb4f4 commit e64b7f4

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

blackjax/adaptation/ensemble_mclmc.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
acc_prob_target=0.8,
7272
observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep
7373
observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains
74-
contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions
74+
contract=lambda x: 0.0, # just for diagnostics: observables for bias, contracted over dimensions
7575
):
7676
self.num_adaptation_samples = num_adaptation_samples
7777
self.observables = observables
@@ -111,7 +111,7 @@ def summary_statistics_fn(self, state, info, rng_key):
111111
def update(self, adaptation_state, Etheta):
112112
acc_prob = Etheta["acceptance_probability"]
113113
equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"])
114-
true_bias = self.contract(Etheta["observables_for_bias"]) # remove
114+
true_bias = self.contract(Etheta["observables_for_bias"])
115115

116116
info_to_be_stored = {
117117
"L": adaptation_state.step_size * adaptation_state.steps_per_sample,
@@ -164,7 +164,6 @@ def while_steps_num(cond):
164164
def emaus(
165165
logdensity_fn,
166166
sample_init,
167-
transform,
168167
ndims,
169168
num_steps1,
170169
num_steps2,
@@ -180,9 +179,10 @@ def emaus(
180179
integrator_coefficients=None,
181180
steps_per_sample=15,
182181
acc_prob=None,
183-
observables=lambda x: None,
182+
observables_for_bias=lambda x: 0.0,
184183
ensemble_observables=None,
185184
diagnostics=True,
185+
contract=lambda x: 0.0,
186186
):
187187
"""
188188
model: the target density object
@@ -215,7 +215,7 @@ def emaus(
215215

216216
# burn-in with the unadjusted method #
217217
kernel = umclmc.build_kernel(logdensity_fn)
218-
save_num = (int)(jnp.rint(save_frac * num_steps1))
218+
save_num = (jnp.rint(save_frac * num_steps1)).astype(int)
219219
adap = umclmc.Adaptation(
220220
ndims,
221221
alpha=alpha,
@@ -224,9 +224,8 @@ def emaus(
224224
C=C,
225225
power=3.0 / 8.0,
226226
r_end=r_end,
227-
observables_for_bias=lambda position: jnp.square(
228-
transform(jax.flatten_util.ravel_pytree(position)[0])
229-
),
227+
observables_for_bias=observables_for_bias,
228+
contract=contract,
230229
)
231230

232231
final_state, final_adaptation_state, info1 = run_eca(
@@ -294,8 +293,11 @@ def emaus(
294293
num_adaptation_samples,
295294
steps_per_sample,
296295
_acc_prob,
296+
contract=contract,
297+
observables_for_bias=observables_for_bias,
297298
)
298299

300+
299301
final_state, final_adaptation_state, info2 = run_eca(
300302
key_mclmc,
301303
initial_state,

blackjax/adaptation/ensemble_umclmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh):
7272
def sequential_init(key, x, args):
7373
"""initialize the position using sample_init and the velocity along the gradient"""
7474
position = sample_init(key)
75+
7576
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
7677
flat_g, unravel_fn = ravel_pytree(logdensity_grad)
7778
velocity = unravel_fn(
@@ -82,9 +83,8 @@ def sequential_init(key, x, args):
8283

8384
def summary_statistics_fn(state):
8485
"""compute the diagonal elements of the equipartition matrix"""
85-
return 0 # -state.position * state.logdensity_grad
86+
return -state.position * state.logdensity_grad
8687

87-
# TODO: restore!
8888

8989
def ensemble_init(key, state, signs):
9090
"""flip the velocity, depending on the equipartition condition"""

blackjax/adaptation/mclmc_adaptation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def mclmc_find_L_and_step_size(
5050
desired_energy_var=5e-4,
5151
trust_in_estimate=1.5,
5252
num_effective_samples=150,
53-
diagonal_preconditioning=True,
5453
params=None,
54+
diagonal_preconditioning=True,
5555
):
5656
"""
5757
Finds the optimal value of the parameters for the MCLMC algorithm.

blackjax/util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from jax.random import normal, split
1111
from jax.sharding import NamedSharding, PartitionSpec
1212
from jax.tree_util import tree_leaves, tree_map
13-
13+
import jax
1414
from blackjax.base import SamplingAlgorithm, VIAlgorithm
1515
from blackjax.progress_bar import gen_scan_fn
1616
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
@@ -352,11 +352,14 @@ def _step(state_all, xs):
352352
adaptation_state, info_to_be_stored = adaptation_update(
353353
adaptation_state, Etheta
354354
)
355+
355356

356357
return (state, adaptation_state), info_to_be_stored
358+
357359

358360
if ensemble_info is not None:
359361

362+
360363
def step(state_all, xs):
361364
(state, adaptation_state), info_to_be_stored = _step(state_all, xs)
362365
return (state, adaptation_state), (
@@ -381,6 +384,7 @@ def run_eca(
381384
ensemble_info=None,
382385
early_stop=False,
383386
):
387+
384388
"""
385389
Run ensemble chain adaptation (eca) in parallel on multiple devices.
386390
-----------------------------------------------------
@@ -413,6 +417,7 @@ def all_steps(initial_state, keys_sampling, keys_adaptation):
413417

414418
initial_state_all = (initial_state, adaptation.initial_state)
415419

420+
416421
# run sampling
417422
xs = (
418423
jnp.arange(num_steps),
@@ -441,13 +446,16 @@ def step_while(a):
441446
else:
442447
final_state_all, info_history = lax.scan(step, initial_state_all, xs)
443448

449+
450+
444451
final_state, final_adaptation_state = final_state_all
445452
return (
446453
final_state,
447454
final_adaptation_state,
448455
info_history,
449456
) # info history is composed of averages over all chains, so it is a couple of scalars
450457

458+
451459
p, pscalar = PartitionSpec("chains"), PartitionSpec()
452460
parallel_execute = shard_map(
453461
all_steps,

0 commit comments

Comments
 (0)