Skip to content

Commit a164c63

Browse files
author
Reuben Harry Cohn-Gordon
committed
windows for unadjusted
1 parent fe28c3c commit a164c63

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

blackjax/adaptation/mclmc_adaptation.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def mclmc_find_L_and_step_size(
5353
num_effective_samples=150,
5454
params=None,
5555
diagonal_preconditioning=True,
56+
num_windows=1,
5657
euclidean=False
5758
):
5859
"""
@@ -122,6 +123,8 @@ def mclmc_find_L_and_step_size(
122123
)
123124

124125

126+
127+
125128
part1_key, part2_key = jax.random.split(rng_key, 2)
126129
total_num_tuning_integrator_steps = 0
127130

@@ -131,17 +134,21 @@ def mclmc_find_L_and_step_size(
131134
num_steps2 += diagonal_preconditioning * (num_steps2 // 3)
132135
num_steps3 = round(num_steps * frac_tune3)
133136

134-
state, params = make_L_step_size_adaptation(
135-
kernel=mclmc_kernel,
136-
dim=dim,
137-
frac_tune1=frac_tune1,
138-
frac_tune2=frac_tune2,
139-
desired_energy_var=desired_energy_var,
140-
trust_in_estimate=trust_in_estimate,
141-
num_effective_samples=num_effective_samples,
142-
diagonal_preconditioning=diagonal_preconditioning,
143-
euclidean=euclidean
144-
)(state, params, num_steps, part1_key)
137+
for i in range(num_windows):
138+
window_key = jax.random.fold_in(part1_key, i)
139+
140+
state, params = make_L_step_size_adaptation(
141+
kernel=mclmc_kernel,
142+
dim=dim,
143+
frac_tune1=frac_tune1/num_windows,
144+
frac_tune2=frac_tune2/num_windows,
145+
desired_energy_var=desired_energy_var,
146+
trust_in_estimate=trust_in_estimate,
147+
num_effective_samples=num_effective_samples,
148+
diagonal_preconditioning=diagonal_preconditioning,
149+
euclidean=euclidean
150+
)(state, params, num_steps, window_key)
151+
145152
total_num_tuning_integrator_steps += num_steps1 + num_steps2
146153

147154
if num_steps3 >= 2: # at least 2 samples for ESS estimation

0 commit comments

Comments
 (0)