@@ -53,6 +53,7 @@ def mclmc_find_L_and_step_size(
53
53
num_effective_samples = 150 ,
54
54
params = None ,
55
55
diagonal_preconditioning = True ,
56
+ num_windows = 1 ,
56
57
euclidean = False
57
58
):
58
59
"""
@@ -122,6 +123,8 @@ def mclmc_find_L_and_step_size(
122
123
)
123
124
124
125
126
+
127
+
125
128
part1_key , part2_key = jax .random .split (rng_key , 2 )
126
129
total_num_tuning_integrator_steps = 0
127
130
@@ -131,17 +134,21 @@ def mclmc_find_L_and_step_size(
131
134
num_steps2 += diagonal_preconditioning * (num_steps2 // 3 )
132
135
num_steps3 = round (num_steps * frac_tune3 )
133
136
134
- state , params = make_L_step_size_adaptation (
135
- kernel = mclmc_kernel ,
136
- dim = dim ,
137
- frac_tune1 = frac_tune1 ,
138
- frac_tune2 = frac_tune2 ,
139
- desired_energy_var = desired_energy_var ,
140
- trust_in_estimate = trust_in_estimate ,
141
- num_effective_samples = num_effective_samples ,
142
- diagonal_preconditioning = diagonal_preconditioning ,
143
- euclidean = euclidean
144
- )(state , params , num_steps , part1_key )
137
+ for i in range (num_windows ):
138
+ window_key = jax .random .fold_in (part1_key , i )
139
+
140
+ state , params = make_L_step_size_adaptation (
141
+ kernel = mclmc_kernel ,
142
+ dim = dim ,
143
+ frac_tune1 = frac_tune1 / num_windows ,
144
+ frac_tune2 = frac_tune2 / num_windows ,
145
+ desired_energy_var = desired_energy_var ,
146
+ trust_in_estimate = trust_in_estimate ,
147
+ num_effective_samples = num_effective_samples ,
148
+ diagonal_preconditioning = diagonal_preconditioning ,
149
+ euclidean = euclidean
150
+ )(state , params , num_steps , window_key )
151
+
145
152
total_num_tuning_integrator_steps += num_steps1 + num_steps2
146
153
147
154
if num_steps3 >= 2 : # at least 2 samples for ESS estimation
0 commit comments