Skip to content

Commit 060c99a

Browse files
authored
Energy error monitoring (#784)
* energy error monitoring * energy error monitoring * jnp abs
1 parent 7e4241f commit 060c99a

File tree

1 file changed

+40
-8
lines changed

1 file changed

+40
-8
lines changed

blackjax/mcmc/mclmc.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Callable, NamedTuple
1616

1717
import jax
18+
import jax.numpy as jnp
1819

1920
from blackjax.base import SamplingAlgorithm
2021
from blackjax.mcmc.integrators import (
@@ -60,7 +61,13 @@ def init(position: ArrayLike, logdensity_fn, rng_key):
6061
)
6162

6263

63-
def build_kernel(logdensity_fn, inverse_mass_matrix, integrator):
64+
def build_kernel(
65+
logdensity_fn,
66+
inverse_mass_matrix,
67+
integrator,
68+
desired_energy_var_max_ratio=jnp.inf,
69+
desired_energy_var=5e-4,
70+
):
6471
"""Build a HMC kernel.
6572
6673
Parameters
@@ -91,14 +98,33 @@ def kernel(
9198
state, step_size, L, rng_key
9299
)
93100

94-
return IntegratorState(
95-
position, momentum, logdensity, logdensitygrad
96-
), MCLMCInfo(
97-
logdensity=logdensity,
98-
energy_change=kinetic_change - logdensity + state.logdensity,
99-
kinetic_change=kinetic_change,
101+
energy_error = kinetic_change - logdensity + state.logdensity
102+
103+
eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var
104+
ndims = pytree_size(position)
105+
106+
new_state, new_info = jax.lax.cond(
107+
jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim),
108+
lambda: (
109+
state,
110+
MCLMCInfo(
111+
logdensity=state.logdensity,
112+
energy_change=0.0,
113+
kinetic_change=0.0,
114+
),
115+
),
116+
lambda: (
117+
IntegratorState(position, momentum, logdensity, logdensitygrad),
118+
MCLMCInfo(
119+
logdensity=logdensity,
120+
energy_change=energy_error,
121+
kinetic_change=kinetic_change,
122+
),
123+
),
100124
)
101125

126+
return new_state, new_info
127+
102128
return kernel
103129

104130

@@ -108,6 +134,7 @@ def as_top_level_api(
108134
step_size,
109135
integrator=isokinetic_mclachlan,
110136
inverse_mass_matrix=1.0,
137+
desired_energy_var_max_ratio=jnp.inf,
111138
) -> SamplingAlgorithm:
112139
"""The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
113140
cumbersome to manipulate. Since most users only need to specify the kernel
@@ -155,7 +182,12 @@ def as_top_level_api(
155182
A ``SamplingAlgorithm``.
156183
"""
157184

158-
kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator)
185+
kernel = build_kernel(
186+
logdensity_fn,
187+
inverse_mass_matrix,
188+
integrator,
189+
desired_energy_var_max_ratio=desired_energy_var_max_ratio,
190+
)
159191

160192
def init_fn(position: ArrayLike, rng_key: PRNGKey):
161193
return init(position, logdensity_fn, rng_key)

0 commit comments

Comments
 (0)