Skip to content

Commit 7e4241f

Browse files
ciguaranjunpenglao
andauthored
SMC: Joint tuning and pretuning (#776)
* impl * rename * docs --------- Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
1 parent 3f0cbb7 commit 7e4241f

File tree

4 files changed

+201
-45
lines changed

4 files changed

+201
-45
lines changed

blackjax/smc/inner_kernel_tuning.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Callable, Dict, NamedTuple, Tuple
22

3+
import jax
4+
35
from blackjax.base import SamplingAlgorithm
46
from blackjax.smc.base import SMCInfo, SMCState
57
from blackjax.types import ArrayTree, PRNGKey
@@ -28,8 +30,11 @@ def build_kernel(
2830
mcmc_step_fn: Callable,
2931
mcmc_init_fn: Callable,
3032
resampling_fn: Callable,
31-
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]],
33+
mcmc_parameter_update_fn: Callable[
34+
[PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree]
35+
],
3236
num_mcmc_steps: int = 10,
37+
smc_returns_state_with_parameter_override=False,
3338
**extra_parameters,
3439
) -> Callable:
3540
"""In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner
@@ -40,7 +45,8 @@ def build_kernel(
4045
----------
4146
smc_algorithm
4247
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
43-
a sampling algorithm that returns an SMCState and SMCInfo pair).
48+
a sampling algorithm that returns an SMCState and SMCInfo pair). It is also possible for this
49+
to return an StateWithParameterOverride, in such case smc_returns_state_with_parameter_override needs to be True
4450
logprior_fn
4551
A function that computes the log density of the prior distribution
4652
loglikelihood_fn
@@ -54,7 +60,30 @@ def build_kernel(
5460
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration.
5561
extra_parameters:
5662
parameters to be used for the creation of the smc_algorithm.
63+
smc_returns_state_with_parameter_override:
64+
a boolean indicating that the underlying smc_algorithm returns a smc_returns_state_with_parameter_override.
65+
this is used in order to compose different adaptation mechanisms, such as pretuning with tuning.
5766
"""
67+
if smc_returns_state_with_parameter_override:
68+
69+
def extract_state_for_delegate(state):
70+
return state
71+
72+
def compose_new_state(new_state, new_parameter_override):
73+
composed_parameter_override = (
74+
new_state.parameter_override | new_parameter_override
75+
)
76+
return StateWithParameterOverride(
77+
new_state.sampler_state, composed_parameter_override
78+
)
79+
80+
else:
81+
82+
def extract_state_for_delegate(state):
83+
return state.sampler_state
84+
85+
def compose_new_state(new_state, new_parameter_override):
86+
return StateWithParameterOverride(new_state, new_parameter_override)
5887

5988
def kernel(
6089
rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters
@@ -69,9 +98,14 @@ def kernel(
6998
num_mcmc_steps=num_mcmc_steps,
7099
**extra_parameters,
71100
).step
72-
new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters)
73-
new_parameter_override = mcmc_parameter_update_fn(new_state, info)
74-
return StateWithParameterOverride(new_state, new_parameter_override), info
101+
parameter_update_key, step_key = jax.random.split(rng_key, 2)
102+
new_state, info = step_fn(
103+
step_key, extract_state_for_delegate(state), **extra_step_parameters
104+
)
105+
new_parameter_override = mcmc_parameter_update_fn(
106+
parameter_update_key, new_state, info
107+
)
108+
return compose_new_state(new_state, new_parameter_override), info
75109

76110
return kernel
77111

@@ -83,9 +117,12 @@ def as_top_level_api(
83117
mcmc_step_fn: Callable,
84118
mcmc_init_fn: Callable,
85119
resampling_fn: Callable,
86-
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]],
120+
mcmc_parameter_update_fn: Callable[
121+
[PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree]
122+
],
87123
initial_parameter_value,
88124
num_mcmc_steps: int = 10,
125+
smc_returns_state_with_parameter_override=False,
89126
**extra_parameters,
90127
) -> SamplingAlgorithm:
91128
"""In the context of an SMC sampler (whose step_fn returning state
@@ -130,6 +167,7 @@ def as_top_level_api(
130167
resampling_fn,
131168
mcmc_parameter_update_fn,
132169
num_mcmc_steps,
170+
smc_returns_state_with_parameter_override,
133171
**extra_parameters,
134172
)
135173

blackjax/smc/pretuning.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,21 @@ def update_parameter_distribution(
9999
)
100100

101101

102+
def default_measure_factory(state):
103+
inverse_mass_matrix = state.parameter_override["inverse_mass_matrix"]
104+
if not (len(inverse_mass_matrix.shape) == 3 and inverse_mass_matrix.shape[0] == 1):
105+
raise ValueError("ESJD only works if chains share the inverse_mass_matrix.")
106+
107+
return esjd(inverse_mass_matrix[0])
108+
109+
102110
def build_pretune(
103111
mcmc_init_fn: Callable,
104112
mcmc_step_fn: Callable,
105113
alpha: float,
106114
sigma_parameters: ArrayLikeTree,
107115
n_particles: int,
108-
performance_of_chain_measure_factory: Callable = lambda state: esjd(
109-
state.parameter_override["inverse_mass_matrix"]
110-
),
116+
performance_of_chain_measure_factory: Callable = default_measure_factory,
111117
natural_parameters: Optional[List[str]] = None,
112118
positive_parameters: Optional[List[str]] = None,
113119
):

blackjax/smc/tuning/from_particles.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"particles_means",
1313
"particles_stds",
1414
"particles_covariance_matrix",
15-
"mass_matrix_from_particles",
15+
"inverse_mass_matrix_from_particles",
1616
]
1717

1818

@@ -28,18 +28,16 @@ def particles_covariance_matrix(particles):
2828
return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False)
2929

3030

31-
def mass_matrix_from_particles(particles) -> Array:
31+
def inverse_mass_matrix_from_particles(particles) -> Array:
3232
"""
3333
Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf
34-
Computing a mass matrix to be used in HMC from particles.
35-
Given the particles covariance matrix, set all non-diagonal elements as zero,
36-
take the inverse, and keep the diagonal.
34+
Computing an inverse mass matrix to be used in HMC from particles.
3735
3836
Returns
3937
-------
40-
A mass Matrix
38+
An inverse mass matrix
4139
"""
42-
return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0))
40+
return jnp.diag(jnp.var(particles_as_rows(particles), axis=0))
4341

4442

4543
def particles_as_rows(particles):

0 commit comments

Comments
 (0)