15
15
from typing import Callable , NamedTuple
16
16
17
17
import jax
18
+ import jax .numpy as jnp
18
19
19
20
from blackjax .base import SamplingAlgorithm
20
21
from blackjax .mcmc .integrators import (
@@ -60,7 +61,13 @@ def init(position: ArrayLike, logdensity_fn, rng_key):
60
61
)
61
62
62
63
63
- def build_kernel (logdensity_fn , inverse_mass_matrix , integrator ):
64
+ def build_kernel (
65
+ logdensity_fn ,
66
+ inverse_mass_matrix ,
67
+ integrator ,
68
+ desired_energy_var_max_ratio = jnp .inf ,
69
+ desired_energy_var = 5e-4 ,
70
+ ):
64
71
"""Build a HMC kernel.
65
72
66
73
Parameters
@@ -91,14 +98,33 @@ def kernel(
91
98
state , step_size , L , rng_key
92
99
)
93
100
94
- return IntegratorState (
95
- position , momentum , logdensity , logdensitygrad
96
- ), MCLMCInfo (
97
- logdensity = logdensity ,
98
- energy_change = kinetic_change - logdensity + state .logdensity ,
99
- kinetic_change = kinetic_change ,
101
+ energy_error = kinetic_change - logdensity + state .logdensity
102
+
103
+ eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var
104
+ ndims = pytree_size (position )
105
+
106
+ new_state , new_info = jax .lax .cond (
107
+ jnp .abs (energy_error ) > jnp .sqrt (ndims * eev_max_per_dim ),
108
+ lambda : (
109
+ state ,
110
+ MCLMCInfo (
111
+ logdensity = state .logdensity ,
112
+ energy_change = 0.0 ,
113
+ kinetic_change = 0.0 ,
114
+ ),
115
+ ),
116
+ lambda : (
117
+ IntegratorState (position , momentum , logdensity , logdensitygrad ),
118
+ MCLMCInfo (
119
+ logdensity = logdensity ,
120
+ energy_change = energy_error ,
121
+ kinetic_change = kinetic_change ,
122
+ ),
123
+ ),
100
124
)
101
125
126
+ return new_state , new_info
127
+
102
128
return kernel
103
129
104
130
@@ -108,6 +134,7 @@ def as_top_level_api(
108
134
step_size ,
109
135
integrator = isokinetic_mclachlan ,
110
136
inverse_mass_matrix = 1.0 ,
137
+ desired_energy_var_max_ratio = jnp .inf ,
111
138
) -> SamplingAlgorithm :
112
139
"""The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
113
140
cumbersome to manipulate. Since most users only need to specify the kernel
@@ -155,7 +182,12 @@ def as_top_level_api(
155
182
A ``SamplingAlgorithm``.
156
183
"""
157
184
158
- kernel = build_kernel (logdensity_fn , inverse_mass_matrix , integrator )
185
+ kernel = build_kernel (
186
+ logdensity_fn ,
187
+ inverse_mass_matrix ,
188
+ integrator ,
189
+ desired_energy_var_max_ratio = desired_energy_var_max_ratio ,
190
+ )
159
191
160
192
def init_fn (position : ArrayLike , rng_key : PRNGKey ):
161
193
return init (position , logdensity_fn , rng_key )
0 commit comments