Skip to content

Commit 9b00e28

Browse files
committed
bug fix
1 parent e64b7f4 commit 9b00e28

File tree

4 files changed

+15
-25
lines changed

4 files changed

+15
-25
lines changed

blackjax/adaptation/ensemble_mclmc.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,15 @@ def __init__(
101101
def summary_statistics_fn(self, state, info, rng_key):
102102
return {
103103
"acceptance_probability": info.acceptance_rate,
104-
"equipartition_diagonal": equipartition_diagonal(
105-
state
106-
), # metric for bias: equipartition theorem gives todo...
104+
"equipartition_diagonal": equipartition_diagonal(state),
107105
"observables": self.observables(state.position),
108106
"observables_for_bias": self.observables_for_bias(state.position),
109107
}
110108

111109
def update(self, adaptation_state, Etheta):
112110
acc_prob = Etheta["acceptance_probability"]
113111
equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"])
114-
true_bias = self.contract(Etheta["observables_for_bias"])
112+
true_bias = self.contract(Etheta["observables_for_bias"])
115113

116114
info_to_be_stored = {
117115
"L": adaptation_state.step_size * adaptation_state.steps_per_sample,
@@ -179,7 +177,7 @@ def emaus(
179177
integrator_coefficients=None,
180178
steps_per_sample=15,
181179
acc_prob=None,
182-
observables_for_bias=lambda x: 0.0,
180+
observables_for_bias=lambda x: x,
183181
ensemble_observables=None,
184182
diagnostics=True,
185183
contract=lambda x: 0.0,
@@ -205,7 +203,6 @@ def emaus(
205203
diagnostics: whether to return diagnostics
206204
"""
207205

208-
# observables_for_bias, contract = bias(model)
209206
key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3)
210207

211208
# initialize the chains
@@ -297,7 +294,6 @@ def emaus(
297294
observables_for_bias=observables_for_bias,
298295
)
299296

300-
301297
final_state, final_adaptation_state, info2 = run_eca(
302298
key_mclmc,
303299
initial_state,

blackjax/adaptation/ensemble_umclmc.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh):
7272
def sequential_init(key, x, args):
7373
"""initialize the position using sample_init and the velocity along the gradient"""
7474
position = sample_init(key)
75-
75+
7676
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
7777
flat_g, unravel_fn = ravel_pytree(logdensity_grad)
7878
velocity = unravel_fn(
@@ -83,8 +83,12 @@ def sequential_init(key, x, args):
8383

8484
def summary_statistics_fn(state):
8585
"""compute the diagonal elements of the equipartition matrix"""
86-
return -state.position * state.logdensity_grad
86+
flat_pos, unflatten = jax.flatten_util.ravel_pytree(state.position)
87+
flat_g, unravel_fn = ravel_pytree(state.logdensity_grad)
88+
return unravel_fn(-flat_pos * flat_g)
89+
# return 0
8790

91+
# -state.position # * state.logdensity_grad
8892

8993
def ensemble_init(key, state, signs):
9094
"""flip the velocity, depending on the equipartition condition"""
@@ -113,7 +117,9 @@ def ensemble_init(key, state, signs):
113117
summary_statistics_fn=summary_statistics_fn,
114118
)
115119

116-
signs = -2.0 * (equipartition < 1.0) + 1.0
120+
flat_equi, _ = ravel_pytree(equipartition)
121+
122+
signs = -2.0 * (flat_equi < 1.0) + 1.0
117123
initial_state, _ = ensemble_execute_fn(
118124
ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs
119125
)
@@ -122,7 +128,6 @@ def ensemble_init(key, state, signs):
122128

123129

124130
def update_history(new_vals, history):
125-
new_vals, _ = jax.flatten_util.ravel_pytree(new_vals)
126131
return jnp.concatenate((new_vals[None, :], history[:-1, :]))
127132

128133

@@ -258,7 +263,6 @@ def update(self, adaptation_state, Etheta):
258263
history_observables = update_history(
259264
Etheta["observables_for_bias"], adaptation_state.history.observables
260265
)
261-
# history_observables = adaptation_state.history.observables
262266

263267
history_weights = update_history_scalar(1.0, adaptation_state.history.weights)
264268
fluctuations = contract_history(history_observables, history_weights)

blackjax/util.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from jax.random import normal, split
1111
from jax.sharding import NamedSharding, PartitionSpec
1212
from jax.tree_util import tree_leaves, tree_map
13-
import jax
13+
1414
from blackjax.base import SamplingAlgorithm, VIAlgorithm
1515
from blackjax.progress_bar import gen_scan_fn
1616
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
@@ -352,14 +352,11 @@ def _step(state_all, xs):
352352
adaptation_state, info_to_be_stored = adaptation_update(
353353
adaptation_state, Etheta
354354
)
355-
356355

357356
return (state, adaptation_state), info_to_be_stored
358-
359357

360358
if ensemble_info is not None:
361359

362-
363360
def step(state_all, xs):
364361
(state, adaptation_state), info_to_be_stored = _step(state_all, xs)
365362
return (state, adaptation_state), (
@@ -384,7 +381,6 @@ def run_eca(
384381
ensemble_info=None,
385382
early_stop=False,
386383
):
387-
388384
"""
389385
Run ensemble chain adaptation (eca) in parallel on multiple devices.
390386
-----------------------------------------------------
@@ -417,7 +413,6 @@ def all_steps(initial_state, keys_sampling, keys_adaptation):
417413

418414
initial_state_all = (initial_state, adaptation.initial_state)
419415

420-
421416
# run sampling
422417
xs = (
423418
jnp.arange(num_steps),
@@ -446,16 +441,13 @@ def step_while(a):
446441
else:
447442
final_state_all, info_history = lax.scan(step, initial_state_all, xs)
448443

449-
450-
451444
final_state, final_adaptation_state = final_state_all
452445
return (
453446
final_state,
454447
final_adaptation_state,
455448
info_history,
456449
) # info history is composed of averages over all chains, so it is a couple of scalars
457450

458-
459451
p, pscalar = PartitionSpec("chains"), PartitionSpec()
460452
parallel_execute = shard_map(
461453
all_steps,

tests/mcmc/test_sampling.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import chex
77
import jax
8+
89
# jax.config.update("jax_traceback_filtering", "off")
910
import jax.numpy as jnp
1011
import jax.scipy.stats as stats
@@ -296,11 +297,10 @@ def run_emaus(
296297
sample_init,
297298
logdensity_fn,
298299
ndims,
299-
transform,
300300
key,
301301
diagonal_preconditioning,
302302
):
303-
mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains")
303+
mesh = jax.sharding.Mesh(devices=jax.devices(), axis_names="chains")
304304

305305
from blackjax.mcmc.integrators import mclachlan_coefficients
306306

@@ -309,7 +309,6 @@ def run_emaus(
309309
info, grads_per_step, _acc_prob, final_state = emaus(
310310
logdensity_fn=logdensity_fn,
311311
sample_init=sample_init,
312-
transform=transform,
313312
ndims=ndims,
314313
num_steps1=100,
315314
num_steps2=300,
@@ -602,7 +601,6 @@ def sample_init(key):
602601
samples = self.run_emaus(
603602
sample_init=sample_init,
604603
logdensity_fn=logdensity_fn,
605-
transform=lambda x: x,
606604
ndims=2,
607605
key=inference_key,
608606
diagonal_preconditioning=True,

0 commit comments

Comments
 (0)