
Thin kernel and sampling algorithm #791


Open
wants to merge 2 commits into main

Conversation

hsimonfroy
Contributor

PR discussed in #738

  • Add transformations that thin a SamplingAlgorithm or a kernel: they take a thinning integer and a SamplingAlgorithm/kernel and return the same SamplingAlgorithm/kernel, but with each step iterating the original one thinning times (see the sketch below).

  • This is useful for reducing the computation and memory cost of high-throughput samplers, especially in high dimension. While the thin_algorithm function operates on a top-level API SamplingAlgorithm, the thin_kernel version is relevant for adaptation algorithms. For instance, when estimating the autocorrelation length to tune the momentum decoherence length in mclmc_adaptation, using the states from every step is computationally prohibitive in high dimension; see Subsampling for MCLMC tuning #738.

  • Both transformations have an additional info_transform Callable parameter that defines how to aggregate the sampler information across the thinning steps. For instance, we might want to average the logdensities and take the root mean square of the energy_changes, which can easily be done with tree.map or tree.map_with_path.
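
To make the intended behavior concrete, here is a minimal sketch of what the two transformations and a per-field info_transform could look like. This is an illustration only, not necessarily the PR's actual implementation: the kernel signature, the default info_transform (keep the last step's info), and mixed_info_transform are assumptions.

import jax
from jax import lax, tree

def thin_kernel(kernel, thinning, info_transform=lambda info: tree.map(lambda x: x[-1], info)):
    # Wrap `kernel` so that each call advances the chain `thinning` steps.
    def thinned_kernel(rng_key, state, *args, **kwargs):
        def one_step(state, key):
            return kernel(key, state, *args, **kwargs)

        keys = jax.random.split(rng_key, thinning)
        state, infos = lax.scan(one_step, state, keys)
        # `infos` stacks the per-step info pytrees along a leading axis;
        # `info_transform` decides how to aggregate them (last, mean, RMS, ...).
        return state, info_transform(infos)

    return thinned_kernel

def thin_algorithm(algorithm, thinning, info_transform=lambda info: tree.map(lambda x: x[-1], info)):
    # Same idea one level up: a SamplingAlgorithm is a NamedTuple(init, step),
    # so it suffices to thin its step function.
    return algorithm._replace(step=thin_kernel(algorithm.step, thinning, info_transform))

# Hypothetical per-field aggregation as described above: average everything
# (e.g. logdensity), but take the root mean square of energy_change.
def mixed_info_transform(infos):
    def aggregate(path, x):
        if "energy_change" in jax.tree_util.keystr(path):
            return (x**2).mean() ** 0.5
        return x.mean()
    return jax.tree_util.tree_map_with_path(aggregate, infos)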

  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide;
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;

@hsimonfroy
Contributor Author

hsimonfroy commented May 18, 2025

  • Here is an example of how thin_algorithm and thin_kernel could be used for MCLMC:
import jax.numpy as jnp
import jax.random as jr
from jax import tree

import blackjax
from blackjax.mcmc.integrators import isokinetic_mclachlan
from blackjax.util import run_inference_algorithm
# thin_kernel and thin_algorithm are the transformations introduced by this PR

logdf = lambda x: -(x**2).sum()
init_pos = jnp.ones(2)
init_key, tune_key, run_key = jr.split(jr.key(42), 3)

state = blackjax.mcmc.mclmc.init(
            position=init_pos,
            logdensity_fn=logdf,
            rng_key=init_key
            )

kernel = lambda inverse_mass_matrix: thin_kernel(
            blackjax.mcmc.mclmc.build_kernel(
                                logdensity_fn=logdf,
                                integrator=isokinetic_mclachlan,
                                inverse_mass_matrix=inverse_mass_matrix,
                                ),
            thinning=16,
            # Adequately aggregate info.energy_change (root mean square)
            info_transform=lambda info: tree.map(lambda x: (x**2).mean()**.5, info),
            )

state, params, n_steps = blackjax.mclmc_find_L_and_step_size(
            mclmc_kernel=kernel,
            num_steps=100,
            state=state,
            rng_key=tune_key,
            )
    
sampler = blackjax.mclmc(
            logdensity_fn=logdf,
            L=params.L,
            step_size=params.step_size,
            inverse_mass_matrix=params.inverse_mass_matrix,
            )

sampler = thin_algorithm(
            sampler,
            thinning=16,
            info_transform=lambda info: tree.map(jnp.mean, info),
            )

state, history = run_inference_algorithm(
            rng_key=run_key,
            initial_state=state,
            inference_algorithm=sampler,
            num_steps=100,
            )
    
  • NB: I exposed the Lfactor=0.4 parameter in mclmc_find_L_and_step_size because, if thinning is too high, the ESS computed on the thinned samples is larger than on the non-thinned samples, which leads to underestimating the autocorrelation length and therefore L. This can simply be compensated by increasing Lfactor. In practice, during my tests, I only found minor changes in the L estimate (with vs. without thinning) for reasonable thinning values, so I am not sure this option is necessary. In any case, one shouldn't thin to the point that it deteriorates the ESS, since one could instead just take fewer sampling steps. A small numerical illustration of the bias is sketched below.
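
To illustrate the point about autocorrelation on thinned samples (a toy example, unrelated to the PR's code): thinning an AR(1) chain with lag-1 autocorrelation rho by a factor t gives a chain with lag-1 autocorrelation rho**t, so an autocorrelation length estimated from the thinned samples alone looks much shorter than the true one measured in kernel steps.

import numpy as np

rng = np.random.default_rng(0)
rho, n, thinning = 0.95, 100_000, 16

# Simulate an AR(1) chain with autocorrelation length ~ 1 / (1 - rho) = 20 steps.
noise = rng.normal(size=n)
chain = np.zeros(n)
for t in range(1, n):
    chain[t] = rho * chain[t - 1] + noise[t]

def lag1_autocorr(x):
    x = x - x.mean()
    return (x[:-1] * x[1:]).mean() / (x**2).mean()

print(lag1_autocorr(chain))              # ~0.95: autocorrelation length ~20 steps
print(lag1_autocorr(chain[::thinning]))  # ~0.95**16 ~ 0.44: looks much shorter per stored sample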

@junpenglao
Member

Overall LGTM, could you add some tests?
