1
1
from typing import Callable , Dict , NamedTuple , Tuple
2
2
3
+ import jax
4
+
3
5
from blackjax .base import SamplingAlgorithm
4
6
from blackjax .smc .base import SMCInfo , SMCState
5
7
from blackjax .types import ArrayTree , PRNGKey
@@ -28,8 +30,11 @@ def build_kernel(
28
30
mcmc_step_fn : Callable ,
29
31
mcmc_init_fn : Callable ,
30
32
resampling_fn : Callable ,
31
- mcmc_parameter_update_fn : Callable [[SMCState , SMCInfo ], Dict [str , ArrayTree ]],
33
+ mcmc_parameter_update_fn : Callable [
34
+ [PRNGKey , SMCState , SMCInfo ], Dict [str , ArrayTree ]
35
+ ],
32
36
num_mcmc_steps : int = 10 ,
37
+ smc_returns_state_with_parameter_override = False ,
33
38
** extra_parameters ,
34
39
) -> Callable :
35
40
"""In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner
@@ -40,7 +45,8 @@ def build_kernel(
40
45
----------
41
46
smc_algorithm
42
47
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
43
- a sampling algorithm that returns an SMCState and SMCInfo pair).
48
+ a sampling algorithm that returns an SMCState and SMCInfo pair). It is also possible for this
49
+ to return an StateWithParameterOverride, in such case smc_returns_state_with_parameter_override needs to be True
44
50
logprior_fn
45
51
A function that computes the log density of the prior distribution
46
52
loglikelihood_fn
@@ -54,7 +60,30 @@ def build_kernel(
54
60
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration.
55
61
extra_parameters:
56
62
parameters to be used for the creation of the smc_algorithm.
63
+ smc_returns_state_with_parameter_override:
64
+ a boolean indicating that the underlying smc_algorithm returns a smc_returns_state_with_parameter_override.
65
+ this is used in order to compose different adaptation mechanisms, such as pretuning with tuning.
57
66
"""
67
+ if smc_returns_state_with_parameter_override :
68
+
69
+ def extract_state_for_delegate (state ):
70
+ return state
71
+
72
+ def compose_new_state (new_state , new_parameter_override ):
73
+ composed_parameter_override = (
74
+ new_state .parameter_override | new_parameter_override
75
+ )
76
+ return StateWithParameterOverride (
77
+ new_state .sampler_state , composed_parameter_override
78
+ )
79
+
80
+ else :
81
+
82
+ def extract_state_for_delegate (state ):
83
+ return state .sampler_state
84
+
85
+ def compose_new_state (new_state , new_parameter_override ):
86
+ return StateWithParameterOverride (new_state , new_parameter_override )
58
87
59
88
def kernel (
60
89
rng_key : PRNGKey , state : StateWithParameterOverride , ** extra_step_parameters
@@ -69,9 +98,14 @@ def kernel(
69
98
num_mcmc_steps = num_mcmc_steps ,
70
99
** extra_parameters ,
71
100
).step
72
- new_state , info = step_fn (rng_key , state .sampler_state , ** extra_step_parameters )
73
- new_parameter_override = mcmc_parameter_update_fn (new_state , info )
74
- return StateWithParameterOverride (new_state , new_parameter_override ), info
101
+ parameter_update_key , step_key = jax .random .split (rng_key , 2 )
102
+ new_state , info = step_fn (
103
+ step_key , extract_state_for_delegate (state ), ** extra_step_parameters
104
+ )
105
+ new_parameter_override = mcmc_parameter_update_fn (
106
+ parameter_update_key , new_state , info
107
+ )
108
+ return compose_new_state (new_state , new_parameter_override ), info
75
109
76
110
return kernel
77
111
@@ -83,9 +117,12 @@ def as_top_level_api(
83
117
mcmc_step_fn : Callable ,
84
118
mcmc_init_fn : Callable ,
85
119
resampling_fn : Callable ,
86
- mcmc_parameter_update_fn : Callable [[SMCState , SMCInfo ], Dict [str , ArrayTree ]],
120
+ mcmc_parameter_update_fn : Callable [
121
+ [PRNGKey , SMCState , SMCInfo ], Dict [str , ArrayTree ]
122
+ ],
87
123
initial_parameter_value ,
88
124
num_mcmc_steps : int = 10 ,
125
+ smc_returns_state_with_parameter_override = False ,
89
126
** extra_parameters ,
90
127
) -> SamplingAlgorithm :
91
128
"""In the context of an SMC sampler (whose step_fn returning state
@@ -130,6 +167,7 @@ def as_top_level_api(
130
167
resampling_fn ,
131
168
mcmc_parameter_update_fn ,
132
169
num_mcmc_steps ,
170
+ smc_returns_state_with_parameter_override ,
133
171
** extra_parameters ,
134
172
)
135
173
0 commit comments