diff --git a/blackjax/__init__.py b/blackjax/__init__.py index ef5eabd79..f34afdcd0 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -45,7 +45,7 @@ from .vi import pathfinder as _pathfinder from .vi import schrodinger_follmer as _schrodinger_follmer from .vi import svgd as _svgd -from .vi.pathfinder import PathFinderAlgorithm +from .vi.pathfinder import MultiPathfinderAlgorithm, multi_pathfinder """ The above three classes exist as a backwards compatible way of exposing both the high level, differentiable @@ -81,10 +81,12 @@ def __call__(self, *args, **kwargs) -> VIAlgorithm: @dataclasses.dataclass class GeneratePathfinderAPI: differentiable: Callable - approximate: Callable - sample: Callable + init: Callable + pathfinder: Callable + logp: Callable + importance_sampling: Callable - def __call__(self, *args, **kwargs) -> PathFinderAlgorithm: + def __call__(self, *args, **kwargs) -> MultiPathfinderAlgorithm: return self.differentiable(*args, **kwargs) @@ -153,7 +155,11 @@ def generate_top_level_api_from(module): ) pathfinder = GeneratePathfinderAPI( - _pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample + _pathfinder.as_top_level_api, + _pathfinder.init, + _pathfinder.pathfinder, + _pathfinder.logp, + _pathfinder.importance_sampling, ) @@ -169,4 +175,5 @@ def generate_top_level_api_from(module): "adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation "ess", # diagnostics "rhat", + "multi_pathfinder", # pathfinder ] diff --git a/blackjax/adaptation/pathfinder_adaptation.py b/blackjax/adaptation/pathfinder_adaptation.py index c0b4ebc50..25ceb3420 100644 --- a/blackjax/adaptation/pathfinder_adaptation.py +++ b/blackjax/adaptation/pathfinder_adaptation.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Pathinder warmup for the HMC family of sampling algorithms.""" + from typing import Callable, NamedTuple import jax @@ -197,7 +198,12 @@ def one_step(carry, rng_key): adaptation_info_fn(new_state, info, new_adaptation_state), ) - def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400): + def run( + rng_key: PRNGKey, + position: ArrayLikeTree, + num_steps: int = 400, + num_samples_per_path: int = 1000, + ): init_key, sample_key, rng_key = jax.random.split(rng_key, 3) pathfinder_state, _ = vi.pathfinder.approximate( @@ -210,7 +216,9 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400): initial_step_size, ) - init_position, _ = vi.pathfinder.sample(sample_key, pathfinder_state) + init_position, _ = vi.pathfinder.sample( + sample_key, pathfinder_state, num_samples_per_path + ) init_state = algorithm.init(init_position, logdensity_fn) keys = jax.random.split(rng_key, num_steps) diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py index 0dd59f003..6411ef1a7 100644 --- a/blackjax/optimizers/lbfgs.py +++ b/blackjax/optimizers/lbfgs.py @@ -11,26 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
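The `__init__.py` and `pathfinder_adaptation.py` hunks above expose the new multi-path API at the top level and thread a `num_samples_per_path` argument through the warmup's `run`, forwarding it to `vi.pathfinder.sample`. A minimal usage sketch of the updated warmup, assuming this patch is applied; the standard-normal target, key, dimension, and sample count are illustrative placeholders, and the return structure of `run` is assumed unchanged from the existing adaptation API:

import jax
import jax.numpy as jnp
import blackjax

# Illustrative target density; not part of the patch.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

adapt = blackjax.pathfinder_adaptation(blackjax.nuts, logdensity_fn)
# `num_samples_per_path` is the only new keyword here; it sets how many draws
# `vi.pathfinder.sample` produces when choosing the initial warmup position.
warmup_result, info = adapt.run(
    jax.random.key(0), jnp.zeros(3), num_steps=400, num_samples_per_path=500
)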
+import functools +import logging from typing import Callable, NamedTuple import jax import jax.numpy as jnp import jax.random -import jaxopt -from jax import lax -from jax.flatten_util import ravel_pytree -from jaxopt._src.lbfgs import LbfgsState -from jaxopt.base import OptStep +import jax.scipy as jsp +import optax +import optax.tree_utils as otu -from blackjax.types import Array, ArrayLikeTree +from blackjax.types import Array + +jax.config.update("jax_enable_x64", True) + + +logger = logging.getLogger(__name__) __all__ = [ "LBFGSHistory", "minimize_lbfgs", + "lbfgs_diff_history_matrix", "lbfgs_inverse_hessian_factors", "lbfgs_inverse_hessian_formula_1", "lbfgs_inverse_hessian_formula_2", "bfgs_sample", + "lbfgs_recover_alpha", ] INIT_STEP_SIZE = 1.0 @@ -58,213 +65,165 @@ class LBFGSHistory(NamedTuple): x: Array f: Array g: Array - alpha: Array - update_mask: Array + converged: Array # for clipping history for shorter inverse hessian calcs and bfgs sampling + iter: Array # TODO: remove iter. not needed def minimize_lbfgs( - fun: Callable, - x0: ArrayLikeTree, - maxiter: int = 30, - maxcor: float = 10, - gtol: float = 1e-08, - ftol: float = 1e-05, - maxls: int = 1000, - **lbfgs_kwargs, -) -> tuple[OptStep, LBFGSHistory]: - """ - Minimize a function using L-BFGS - - Parameters - ---------- - fun: - function of the form f(x) where x is a pytree and returns a real scalar. - The function should be composed of operations with vjp defined. - x0: - initial guess - maxiter: - maximum number of iterations - maxcor: - maximum number of metric corrections ("history size") - ftol: - terminates the minimization when `(f_k - f_{k+1}) < ftol` - gtol: - terminates the minimization when `|g_k|_norm < gtol` - maxls: - maximum number of line search steps (per iteration) - **lbfgs_kwargs - other keyword arguments passed to `jaxopt.LBFGS`. - - Returns - ------- - Optimization results and optimization path - - """ - # Ravel pytree into flat array. - x0_raveled, unravel_fn = ravel_pytree(x0) - unravel_fn_mapped = jax.vmap(unravel_fn) - - # Run LBFGS optimizer on flat input. - last_step_raveled, history_raveled = _minimize_lbfgs( - lambda x: fun(unravel_fn(x)), - x0_raveled, - maxiter, - maxcor, - gtol, - ftol, - maxls, - **lbfgs_kwargs, - ) - - # Unravel final optimization step. - last_step = OptStep( - params=unravel_fn(last_step_raveled.params), - state=LbfgsState( - iter_num=last_step_raveled.state.iter_num, - value=last_step_raveled.state.value, - grad=unravel_fn(last_step_raveled.state.grad), - stepsize=last_step_raveled.state.stepsize, - error=last_step_raveled.state.error, - s_history=unravel_fn_mapped(last_step_raveled.state.s_history), - y_history=unravel_fn_mapped(last_step_raveled.state.y_history), - rho_history=last_step_raveled.state.rho_history, - gamma=last_step_raveled.state.gamma, - aux=last_step_raveled.state.aux, - ), - ) - - # Unravel optimization path history. 
- history = LBFGSHistory( - x=unravel_fn_mapped(history_raveled.x), - f=history_raveled.f, - g=unravel_fn_mapped(history_raveled.g), - alpha=unravel_fn_mapped(history_raveled.alpha), - update_mask=jax.tree.map( - lambda x: x.astype(history_raveled.update_mask.dtype), - unravel_fn_mapped(history_raveled.update_mask.astype(x0_raveled.dtype)), - ), - ) - - return last_step, history - - -def _minimize_lbfgs( fun: Callable, x0: Array, - maxiter: int, - maxcor: float, - gtol: float, - ftol: float, - maxls: int, + maxiter: int = 100, + maxcor: int = 6, + gtol: float = 1e-8, + ftol: float = 1e-5, + maxls: int = 100, **lbfgs_kwargs, -) -> tuple[OptStep, LBFGSHistory]: +): def lbfgs_one_step(carry, i): - (params, state), previous_history = carry + # state is a 3-dim tuple + (params, state), _ = carry + lbfgs_state, _, _ = state - # this is to help optimization when using log-likelihoods, especially for float 32 - # it resets stepsize of the line search algorithm back to stating value (INIT_STEP_SIZE) if - # it get stuck in very small values - state = state._replace( - stepsize=jnp.where( - state.stepsize < MIN_STEP_SIZE, INIT_STEP_SIZE, state.stepsize - ) + value, grad = value_grad_fn(params) + updates, next_state = solver.update( + grad, state, params, value=value, grad=grad, value_fn=fun ) - # LBFGS use a rolling history, getting the correct index here. - last = (state.iter_num % maxcor + maxcor) % maxcor - next_params, next_state = solver.update(params, state) - - # Recover alpha and update mask - s_l = next_state.s_history[last] - z_l = next_state.y_history[last] - alpha_lm1 = previous_history.alpha + _, _, next_ls_state = next_state - alpha_l, mask_l = lbfgs_recover_alpha(alpha_lm1, s_l, z_l) + # LBFGS use a rolling history, getting the correct index here. + iter = lbfgs_state.count + next_params = params + updates - current_grad = previous_history.g + z_l + converged = check_convergence(state, next_state, iter) history = LBFGSHistory( x=next_params, - f=next_state.value, - g=current_grad, - alpha=alpha_l, - update_mask=mask_l, + f=next_ls_state.value, + g=next_ls_state.grad, + converged=converged, + iter=jnp.asarray(iter, dtype=jnp.int32), ) - # check convergence + return ((next_params, next_state), history), converged + + def check_convergence(state, next_state, iter): + _, _, ls_state = state + _, _, next_ls_state = next_state f_delta = ( - jnp.abs(state.value - next_state.value) - / jnp.asarray([jnp.abs(state.value), jnp.abs(next_state.value), 1.0]).max() + jnp.abs(ls_state.value - next_ls_state.value) + / jnp.asarray( + [jnp.abs(ls_state.value), jnp.abs(next_ls_state.value), 1.0] + ).max() + ) + error = otu.tree_l2_norm(next_ls_state.grad) + return jnp.array( + (iter > 0) & ((error <= gtol) | (f_delta <= ftol) | (iter >= maxiter)), + dtype=bool, ) - not_converged = (next_state.error > gtol) & (f_delta > ftol) & (i < maxiter) - return (OptStep(params=next_params, state=next_state), history), not_converged - def non_op(carry, it): - return carry, False + def state_type_handler(state): + # ensure num_linesearch_steps is of the same type + lbfgs_state, _, ls_state = state + info = ls_state.info._replace( + num_linesearch_steps=jnp.asarray( + ls_state.info.num_linesearch_steps, dtype=jnp.int32 + ) + ) + return lbfgs_state, _, ls_state._replace(info=info) - def scan_body(tup, it): - carry, not_converged = tup - # When cond is met, we start doing no-ops. 
- next_tup = lax.cond(not_converged, lbfgs_one_step, non_op, carry, it) + def non_op(carry, i): + (params, state), previous_history = carry + + return ((params, state), previous_history), jnp.array(True, dtype=bool) + + def scan_body(tup, i): + carry, converged = tup + next_tup = jax.lax.cond(converged, non_op, lbfgs_one_step, carry, i) return next_tup, next_tup[0][-1] - solver = jaxopt.LBFGS( - fun=fun, - maxiter=maxiter, - maxls=maxls, - history_size=maxcor, + linesearch = optax.scale_by_zoom_linesearch( + max_linesearch_steps=maxls, + verbose=False, **lbfgs_kwargs, ) - state = solver.init_state(x0) + solver = optax.lbfgs( + memory_size=maxcor, + linesearch=linesearch, + ) + value_grad_fn = jax.value_and_grad(fun) - value0, grad0 = jax.value_and_grad(fun)(x0) - # LBFGS update overwrite value internally, here is to set the value for checking condition - state = state._replace(value=value0) - init_step = OptStep(params=x0, state=state) - initial_history = LBFGSHistory( + init_state = solver.init(x0) + init_state = state_type_handler(init_state) + + value, grad = value_grad_fn(x0) + init_history = LBFGSHistory( x=x0, - f=value0, - g=grad0, - alpha=jnp.ones_like(x0), - update_mask=jnp.zeros_like(x0, dtype=bool), + f=value, + g=grad, + converged=jnp.array(False, dtype=bool), + iter=jnp.asarray(0, dtype=jnp.int32), ) - ((last_step, _), _), history = lax.scan( - scan_body, ((init_step, initial_history), True), jnp.arange(maxiter) + # Use lax.scan to accumulate history + (((last_params, last_state), _), _), history = jax.lax.scan( + scan_body, + (((x0, init_state), init_history), False), + jnp.arange(maxiter), ) - # Append initial state to history. + last_lbfgs_state, _, last_ls_state = last_state + history = jax.tree.map( lambda x, y: jnp.concatenate([x[None, ...], y], axis=0), - initial_history, + init_history, history, ) - return last_step, history + return (last_params, (last_lbfgs_state, last_ls_state)), history -def lbfgs_recover_alpha(alpha_lm1, s_l, z_l, epsilon=1e-12): +def lbfgs_recover_alpha( + position: Array, + grad_position: Array, + not_converged_mask: Array, + epsilon=1e-12, +): """ Compute diagonal elements of the inverse Hessian approximation from optimation path. It implements the inner loop body of Algorithm 3 in :cite:p:`zhang2022pathfinder`. Parameters ---------- - alpha_lm1 - The diagonal element of the inverse Hessian approximation of the previous iteration - s_l - The update of the position (current position - previous position) - z_l - The update of the gradient (current gradient - previous gradient). Note that in :cite:p:`zhang2022pathfinder` - it is defined as the negative of the update of the gradient, but since we are optimizing - the negative log prob function taking the update of the gradient is correct here. + position + shape (L+1, N) + The position at the current iteration + grad_position + shape (L+1, N) + The gradient at the current iteration + not_converged_mask + shape (L, N) + The indicator of whether the update of position and gradient are included in the inverse-Hessian approximation or not. + epsilon + The threshold for filtering updates based on inner product of position + and gradient differences Returns ------- - alpha_l + alpha + shape (L+1, N) The diagonal element of the inverse Hessian approximation of the current iteration - mask_l - The indicator of whether the update of position and gradient are included in - the inverse-Hessian approximation or not. 
- + s + shape (L, N) + The update of the position (current position - previous position) + z + shape (L, N) + The update of the gradient (current gradient - previous gradient). Note that in :cite:p:`zhang2022pathfinder` it is defined as the negative of the update of the gradient, but since we are optimizing the negative log prob function taking the update of the gradient is correct here. + update_mask + shape (L, N) + The indicator of whether the update of position and gradient are included in the inverse-Hessian approximation or not. + + Notes + ----- + shapes: N=num_params """ - def compute_next_alpha(s_l, z_l, alpha_lm1): + def compute_next_alpha(alpha_lm1, s_l, z_l): a = z_l.T @ jnp.diag(alpha_lm1) @ z_l b = z_l.T @ s_l c = s_l.T @ jnp.diag(1.0 / alpha_lm1) @ s_l @@ -275,40 +234,119 @@ def compute_next_alpha(s_l, z_l, alpha_lm1): ) return 1.0 / inv_alpha_l - pred = s_l.T @ z_l > (epsilon * jnp.linalg.norm(z_l, 2)) - alpha_l = lax.cond( - pred, compute_next_alpha, lambda *_: alpha_lm1, s_l, z_l, alpha_lm1 - ) - mask_l = jnp.where( - pred, - jnp.ones_like(alpha_lm1, dtype=bool), - jnp.zeros_like(alpha_lm1, dtype=bool), + def non_op(alpha_lm1, s_l, z_l): + return alpha_lm1 + + def scan_body(alpha_init, tup): + update_mask_l, s_l, z_l = tup + next_tup = jax.lax.cond( + update_mask_l, + compute_next_alpha, + non_op, + alpha_init, + s_l, + z_l, + ) + return next_tup, next_tup + + nan_pos_mask = jnp.any(~jnp.isfinite(position.at[1:].get()), axis=-1) + nan_grad_mask = jnp.any(~jnp.isfinite(grad_position.at[1:].get()), axis=-1) + nan_mask = jnp.logical_not(nan_pos_mask | nan_grad_mask) + + param_dims = position.shape[-1] + s = jnp.diff(position, axis=0) + z = jnp.diff(grad_position, axis=0) + sz = jnp.sum(s * z, axis=-1) + update_mask = ( + (sz > epsilon * jnp.sqrt(jnp.sum(z**2, axis=-1))) + & not_converged_mask + & nan_mask ) - return alpha_l, mask_l + alpha_init = jnp.ones((param_dims,)) + tup = (update_mask, s, z) + + _, alpha = jax.lax.scan(scan_body, alpha_init, tup) + + return alpha, s, z, update_mask + + +def lbfgs_diff_history_matrix(diff: Array, update_mask: Array, J: int): + """ + Construct an NxJ matrix that stores the previous J differences for position or gradient updates in L-BFGS. Storage is based on the update mask. + + Parameters + ---------- + diff : Array + shape (L, N) + array of differences, where L is the number of iterations + and N is the number of parameters. + update_mask : Array + shape (L, N) + boolean array indicating which differences to include. + J : int + history size, the number of past differences to store. + + Returns + ------- + chi_mat + shape (L, N, J) + history matrix of differences. + """ + L, N = diff.shape + j_last = jnp.array(J - 1) # since indexing starts at 0 + + def chi_update(chi_lm1, diff_l): + chi_l = jnp.roll(chi_lm1, -1, axis=0) + return chi_l.at[j_last].set(diff_l) + + def non_op(chi_lm1, diff_l): + return chi_lm1 + + def scan_body(chi_init, tup): + update_mask_l, diff_l = tup + next_tup = jax.lax.cond(update_mask_l, chi_update, non_op, chi_init, diff_l) + return next_tup, next_tup + + chi_init = jnp.zeros((J, N)) + _, chi_mat = jax.lax.scan(scan_body, chi_init, (update_mask, diff)) + + chi_mat = jnp.matrix_transpose(chi_mat) + + # (L, N, J) + return chi_mat def lbfgs_inverse_hessian_factors(S, Z, alpha): """ + Calculates factors for inverse hessian factored representation. 
It implements formula II.2 of: Pathfinder: Parallel quasi-newton variational inference, Lu Zhang et al., arXiv:2108.03782 """ - param_dims = S.shape[-1] + + param_dims, J = S.shape + StZ = S.T @ Z - R = jnp.triu(StZ) + jnp.eye(param_dims) * jnp.finfo(S.dtype).eps + Ij = jnp.eye(J) + + # TODO: uncomment this + # R = jnp.triu(StZ) + Ij * jnp.finfo(S.dtype).eps + # TODO: delete this + REGULARISATION_TERM = 1e-8 + R = jnp.triu(StZ) + Ij * REGULARISATION_TERM - eta = jnp.diag(StZ) + eta = jnp.diag(R) beta = jnp.hstack([jnp.diag(alpha) @ Z, S]) - minvR = -jnp.linalg.inv(R) + # jsp.linalg.solve is more stable than jnp.linalg.inv + minvR = -jsp.linalg.solve_triangular(R, Ij) alphaZ = jnp.diag(jnp.sqrt(alpha)) @ Z block_dd = minvR.T @ (alphaZ.T @ alphaZ + jnp.diag(eta)) @ minvR - gamma = jnp.block( - [[jnp.zeros((param_dims, param_dims)), minvR], [minvR.T, block_dd]] - ) + gamma = jnp.block([[jnp.zeros((J, J)), minvR], [minvR.T, block_dd]]) + return beta, gamma @@ -339,35 +377,122 @@ def lbfgs_inverse_hessian_formula_2(alpha, beta, gamma): ) -def bfgs_sample(rng_key, num_samples, position, grad_position, alpha, beta, gamma): +def bfgs_sample( + rng_key, + num_samples, + position, + grad_position, + alpha, + beta, + gamma, + sparse: bool | None = None, +): """ Draws approximate samples of target distribution. It implements Algorithm 4 in: Pathfinder: Parallel quasi-newton variational inference, Lu Zhang et al., arXiv:2108.03782 + parameters + ---------- + rng_key : array + prng key + num_samples : int + number of samples to draw + position : array + current position in parameter space + grad_position : array + gradient at current position + alpha : array + diagonal elements of inverse hessian approximation + beta : array + first factor of inverse hessian approximation + gamma : array + second factor of inverse hessian approximation + sparse : bool | none, optional + whether to use sparse computation, by default none + if none, automatically determined based on problem size + + returns + ------- + phi : array + samples drawn from approximate distribution + logdensity : array + log density of samples """ - if not isinstance(num_samples, tuple): - num_samples = (num_samples,) - - Q, R = jnp.linalg.qr(jnp.diag(jnp.sqrt(1 / alpha)) @ beta) - param_dims = beta.shape[0] - Id = jnp.identity(R.shape[0]) - L = jnp.linalg.cholesky(Id + R @ gamma @ R.T) - - logdet = jnp.log(jnp.prod(alpha)) + 2 * jnp.log(jnp.linalg.det(L)) - mu = ( - position - + jnp.diag(alpha) @ grad_position - + beta @ gamma @ beta.T @ grad_position - ) + param_dims = position.shape[-1] + J = beta.shape[-1] // 2 # beta has shape (param_dims, 2*J) + + def _bfgs_sample_sparse(args): + rng_key, position, grad_position, alpha, beta, gamma = args + sqrt_alpha = jnp.sqrt(alpha) + Q, R = jnp.linalg.qr(jnp.diag(1.0 / sqrt_alpha) @ beta) + Id = jnp.identity(R.shape[0]) + L = jnp.linalg.cholesky(Id + R @ gamma @ R.T) + + u = jax.random.normal(rng_key, (num_samples, param_dims, 1)) + logdet = jnp.sum(jnp.log(alpha), axis=-1) + 2.0 * jnp.sum( + jnp.log(jnp.abs(jnp.diagonal(L))), axis=-1 + ) + mu = position - ( + (jnp.diag(alpha) @ grad_position) + (beta @ gamma @ beta.T @ grad_position) + ) + phi = jnp.squeeze( + mu[..., None] + jnp.diag(sqrt_alpha) @ (Q @ (L - Id) @ (Q.T @ u) + u), + axis=-1, + ) + logdensity = -0.5 * ( + logdet + + jnp.einsum("...ji,...ji->...", u, u) + + param_dims * jnp.log(2.0 * jnp.pi) + ) + return phi, logdensity + + def _bfgs_sample_dense(args): + rng_key, position, grad_position, alpha, beta, gamma = args + sqrt_alpha = 
jnp.sqrt(alpha) + sqrt_alpha_diag = jnp.diag(sqrt_alpha) + inv_sqrt_alpha_diag = jnp.diag(1.0 / sqrt_alpha) + Id = jnp.identity(param_dims) + + H_inv = ( + sqrt_alpha_diag + @ (Id + inv_sqrt_alpha_diag @ beta @ gamma @ beta.T @ inv_sqrt_alpha_diag) + @ sqrt_alpha_diag + ) + + u = jax.random.normal(rng_key, (num_samples, param_dims, 1)) + Lchol = jnp.linalg.cholesky(H_inv) + logdet = 2.0 * jnp.sum( + jnp.log(jnp.abs(jnp.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1 + ) + mu = position - (H_inv @ grad_position) + phi = jnp.squeeze(mu[..., None] + (Lchol @ u), axis=-1) + logdensity = -0.5 * ( + logdet + + jnp.einsum("...ji,...ji->...", u, u) + + param_dims * jnp.log(2.0 * jnp.pi) + ) + return phi, logdensity + + # pack args to avoid excessive parameter passing + args = (rng_key, position, grad_position, alpha, beta, gamma) - u = jax.random.normal(rng_key, num_samples + (param_dims, 1)) - phi = mu[..., None] + jnp.diag(jnp.sqrt(alpha)) @ (Q @ (L - Id) @ (Q.T @ u) + u) + sparse = jax.lax.cond( + sparse is None, lambda _: param_dims < 2 * J, lambda _: sparse, None + ) - logdensity = -0.5 * ( - logdet - + jnp.einsum("...ji,...ji->...", u, u) - + param_dims * jnp.log(2.0 * jnp.pi) + phi, logdensity = jax.lax.cond( + sparse, _bfgs_sample_sparse, _bfgs_sample_dense, args ) - return phi[..., 0], logdensity + + nan_phi_mask = jnp.any(~jnp.isfinite(phi), axis=-1) + nan_logdensity_mask = ~jnp.isfinite(logdensity) + nan_mask = nan_phi_mask | nan_logdensity_mask + + logdensity = jnp.where(nan_mask, jnp.inf, logdensity) + return phi, logdensity + + +bfgs_sample_sparse = functools.partial(bfgs_sample, sparse=True) +bfgs_sample_dense = functools.partial(bfgs_sample, sparse=False) diff --git a/blackjax/smc/resampling.py b/blackjax/smc/resampling.py index 8e701e1b2..4e89e2634 100644 --- a/blackjax/smc/resampling.py +++ b/blackjax/smc/resampling.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" All things resampling. """ +"""All things resampling.""" + from functools import partial from typing import Callable @@ -22,9 +23,7 @@ def _resampling_func(func, name, desc="", additional_params="") -> Callable: - # Decorator for resampling function - - doc = f""" + doc = """ {name} resampling. {desc} Parameters @@ -40,8 +39,9 @@ def _resampling_func(func, name, desc="", additional_params="") -> Callable: ------- idx: Array Array of size `num_samples` to use for resampling - """ - + """.format( + name=name, desc=desc + ) func.__doc__ = doc return func @@ -104,7 +104,6 @@ def residual(rng_key: PRNGKey, weights: Array, num_samples: int) -> Array: ) # Permutation is needed due to the concatenation happening at the last step. - # # I am pretty sure we can use lower variance resamplers inside here instead # of multinomial, but I am not sure yet due to the loss of exchangeability, # and as a consequence I am playing it safe. diff --git a/blackjax/vi/pathfinder.py b/blackjax/vi/pathfinder.py index c1b7dc113..2bbf663f3 100644 --- a/blackjax/vi/pathfinder.py +++ b/blackjax/vi/pathfinder.py @@ -11,76 +11,219 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
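Before the `pathfinder.py` changes below, a quick note on the sparse/dense dispatch added to `bfgs_sample` above: both branches draw from the same normal approximation N(mu, H_inv), the sparse path through the QR-based square root and the dense path through a Cholesky factor of the full H_inv. A minimal sketch assuming this patch is applied; the factors below are small synthetic arrays (not produced by a real L-BFGS run) chosen only so the factorizations succeed:

import jax
import jax.numpy as jnp
from blackjax.optimizers.lbfgs import bfgs_sample_dense, bfgs_sample_sparse

key = jax.random.key(0)
N, J = 5, 2  # param_dims and history size, illustrative only
position = jnp.zeros(N)
grad_position = 0.1 * jnp.ones(N)
alpha = jnp.ones(N)                               # diagonal inverse-Hessian estimate
beta = 1e-2 * jax.random.normal(key, (N, 2 * J))  # first low-rank factor
gamma = 1e-2 * jnp.eye(2 * J)                     # second low-rank factor

phi_s, logq_s = bfgs_sample_sparse(key, 100, position, grad_position, alpha, beta, gamma)
phi_d, logq_d = bfgs_sample_dense(key, 100, position, grad_position, alpha, beta, gamma)

# The two branches use different square roots of H_inv, so individual draws
# differ, but they share the same standard-normal input and log-determinant,
# so the reported log-densities should agree up to floating-point error.
print(jnp.allclose(logq_s, logq_d))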
-from typing import Callable, NamedTuple, Union +import logging +import warnings +from enum import IntEnum +from typing import Callable, Literal, NamedTuple, Optional + +import arviz as az import jax import jax.numpy as jnp import jax.random +import numpy as np from jax.flatten_util import ravel_pytree +from jax.scipy.special import logsumexp from blackjax.optimizers.lbfgs import ( - _minimize_lbfgs, bfgs_sample, + lbfgs_diff_history_matrix, lbfgs_inverse_hessian_factors, + lbfgs_recover_alpha, + minimize_lbfgs, ) -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import Array, PRNGKey + +logger = logging.getLogger(__name__) + +__all__ = [ + "MultiPathfinderAlgorithm", + "multi_pathfinder", + "as_top_level_api", +] -__all__ = ["PathfinderState", "approximate", "sample", "as_top_level_api"] +# TODO: set jnp.arrays to float64 -class PathfinderState(NamedTuple): +class SinglePathfinderState(NamedTuple): """State of the Pathfinder algorithm. Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the L-BFGS optimizer. - PathfinderState stores for an interation fo the L-BFGS optimizer the + PathfinderState stores for an iteration of the L-BFGS optimizer the resulting ELBO and all factors needed to sample from the approximated target density. - position: - position - grad_position: - gradient of target distribution wrt position - alpha, beta, gamma: - factored rappresentation of the inverse hessian - elbo: - ELBO of approximation wrt target distribution - + Parameters + ---------- + initial_position : Array + initial position for the optimization + position : Array, optional + current position in parameter space + grad_position : Array, optional + gradient of target distribution at current position + alpha : Array, optional + first factor of the inverse hessian representation + beta : Array, optional + second factor of the inverse hessian representation + gamma : Array, optional + third factor of the inverse hessian representation + elbo : Array, optional + evidence lower bound of approximation to target distribution + sparse : bool + whether to use sparse representation of the inverse hessian """ - elbo: Array - position: ArrayTree - grad_position: ArrayTree - alpha: Array - beta: Array - gamma: Array + initial_position: Array + position: Optional[Array] = None + grad_position: Optional[Array] = None + alpha: Optional[Array] = None + beta: Optional[Array] = None + gamma: Optional[Array] = None + elbo: Optional[Array] = None + sparse: bool = False -class PathfinderInfo(NamedTuple): +class SinglePathfinderInfo(NamedTuple): """Extra information returned by the Pathfinder algorithm.""" - path: PathfinderState + path: SinglePathfinderState + update_mask: Array + + +class ImpSamplingMethod(IntEnum): + PSIS = 0 + PSIR = 1 + IDENTITY = 2 + NONE = 3 + + +class ImportanceSamplingState(NamedTuple): + """Container for importance sampling results. + + This class stores the results of importance sampling from multiple Pathfinder + approximations, including diagnostic information about the quality of the + importance sampling. 
+ + Parameters + ---------- + num_paths : int + number of paths used in multi-pathfinder + samples : Array + importance sampled draws from the approximate distribution + pareto_k : float, optional + Pareto k diagnostic value from importance sampling + method : str, optional + importance sampling method used + """ + + num_paths: int + samples: Array + pareto_k: Optional[float] = None + method: int = ImpSamplingMethod.PSIS + + +class MultiPathfinderAlgorithm(NamedTuple): + init: Callable + pathfinder: Callable + logp: Callable + importance_sampling: Callable -class PathFinderAlgorithm(NamedTuple): - approximate: Callable - sample: Callable +def init( + rng_key: Optional[PRNGKey] = None, + base_position: Optional[Array] = None, + num_paths: Optional[int] = None, + jitter_amount: Optional[float] = None, + initial_position: Optional[Array] = None, +) -> Array: + """Initialize positions for multi-pathfinder. + + This function either returns the provided initial positions or generates them + by adding jitter to a base position. + + Parameters + ---------- + rng_key : PRNGKey, optional + JAX PRNG key for generating jittered positions + base_position : Array, optional + Unflattened base position to jitter around + num_paths : int, optional + Number of paths to generate + jitter_amount : float, optional + Scale of the uniform jitter to add to the base position + initial_position : Array, optional + Pre-specified initial positions to use. + + Returns + ------- + Array + Initial positions for multi-pathfinder + """ + + if num_paths is None: + raise ValueError("num_paths must be provided") + + if initial_position is not None: + # Check if initial_position is a single position or multiple positions + batch_size = jax.tree.leaves(initial_position)[0].shape[0] + if initial_position.ndim == 1: + logger.warning( + "Initial position is a single position, repeating for all paths. " + "This is likely to lead to poor performance. " + "Consider providing a batch of initial positions, or use base_position and jitter_amount." + ) + return jax.tree.map( + lambda x: jnp.repeat(x[jnp.newaxis, ...], num_paths, axis=0), + initial_position, + ) + else: + if num_paths == batch_size: + return initial_position + else: + raise ValueError( + f"num_paths ({num_paths}) must match batch_size ({batch_size})." 
+ ) + + if base_position is None: + raise ValueError( + "base_position must be provided if initial_position is not provided" + ) + if jitter_amount is None: + raise ValueError( + "jitter_amount must be provided if initial_position is not provided" + ) + if rng_key is None: + raise ValueError("rng_key must be provided if initial_position is not provided") + + # Generate jittered positions from base position + base_position_flatten, unravel_fn = ravel_pytree(base_position) + + jitter_value = jax.random.uniform( + rng_key, + shape=(num_paths, base_position_flatten.shape[0]), + minval=-jitter_amount, + maxval=jitter_amount, + ) + jittered_positions = base_position_flatten + jitter_value + + return jax.vmap(unravel_fn)(jittered_positions) def approximate( rng_key: PRNGKey, logdensity_fn: Callable, - initial_position: ArrayLikeTree, - num_samples: int = 200, - *, # lgbfs parameters - maxiter=30, - maxcor=10, - maxls=1000, - gtol=1e-08, - ftol=1e-05, + initial_position: Array, + num_elbo_draws: int = 15, + maxcor: int = 6, + maxiter: int = 100, + maxls: int = 100, + gtol: float = 1e-08, + ftol: float = 1e-05, + epsilon: float = 1e-8, **lbfgs_kwargs, -) -> tuple[PathfinderState, PathfinderInfo]: +) -> tuple[SinglePathfinderState, SinglePathfinderInfo]: """Pathfinder variational inference algorithm. Pathfinder locates normal approximations to the target density along a @@ -91,26 +234,26 @@ def approximate( Parameters ---------- - rng_key - PRPNG key - logdensity_fn - (un-normalized) log densify function of target distribution to take + rng_key : PRNGKey + JAX PRNG key + logdensity_fn : Callable + (un-normalized) log density function of target distribution to take approximate samples from - initial_position + initial_position : Array starting point of the L-BFGS optimization routine - num_samples - number of samples to draw to estimate ELBO - maxiter - Maximum number of iterations of the LGBFS algorithm. - maxcor + maxcor : int Maximum number of metric corrections of the LGBFS algorithm ("history size") - ftol + num_elbo_draws : int + number of samples to draw to estimate ELBO + maxiter : int + Maximum number of iterations of the LGBFS algorithm. + ftol : float The LGBFS algorithm terminates the minimization when `(f_k - f_{k+1}) < ftol` - gtol + gtol : float The LGBFS algorithm terminates the minimization when `|g_k|_norm < gtol` - maxls + maxls : int The maximum number of line search steps (per iteration) for the LGBFS algorithm **lbfgs_kwargs @@ -124,10 +267,11 @@ def approximate( contains all the states traversed. """ + initial_position_flatten, unravel_fn = ravel_pytree(initial_position) objective_fn = lambda x: -logdensity_fn(unravel_fn(x)) - (_, status), history = _minimize_lbfgs( + (_, _), history = minimize_lbfgs( objective_fn, initial_position_flatten, maxiter, @@ -138,111 +282,544 @@ def approximate( **lbfgs_kwargs, ) - # Get postions and gradients of the optimization path (including the starting point). + not_converged_mask = jnp.logical_not(history.converged.at[1:].get()) + + # jax.jit would not work with truncated history, so we keep the full history position = history.x grad_position = history.g - alpha = history.alpha - # Get the update of position and gradient. 
- update_mask = history.update_mask[1:] + + alpha, s, z, update_mask = lbfgs_recover_alpha( + position, grad_position, not_converged_mask, epsilon + ) + s = jnp.diff(position, axis=0) z = jnp.diff(grad_position, axis=0) - # Account for the mask - s_masked = jnp.where(update_mask, s, jnp.zeros_like(s)) - z_masked = jnp.where(update_mask, z, jnp.zeros_like(z)) - # Pad 0 to leading dimension so we have constant shape output - s_padded = jnp.pad(s_masked, ((maxcor, 0), (0, 0)), mode="constant") - z_padded = jnp.pad(z_masked, ((maxcor, 0), (0, 0)), mode="constant") - - def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad): + S = lbfgs_diff_history_matrix(s, update_mask, maxcor) + Z = lbfgs_diff_history_matrix(z, update_mask, maxcor) + + position = position.at[1:].get() + grad_position = grad_position.at[1:].get() + + param_dims = position.shape[-1] + sparse = param_dims < 2 * maxcor + + def pathfinder_body_fn( + rng_key, update_mask_l, S_l, Z_l, alpha_l, theta, theta_grad + ): """The for loop body in Algorithm 1 of the Pathfinder paper.""" - beta, gamma = lbfgs_inverse_hessian_factors(S.T, Z.T, alpha_l) - phi, logq = bfgs_sample( - rng_key=rng_key, - num_samples=num_samples, - position=theta, - grad_position=theta_grad, - alpha=alpha_l, - beta=beta, - gamma=gamma, - ) - logp = -jax.vmap(objective_fn)(phi) - elbo = (logp - logq).mean() # Algorithm 7 of the paper - return elbo, beta, gamma - - # Index and reshape S and Z to be sliding window view shape=(maxiter, - # maxcor, param_dim), so we can vmap over all the iterations. - # This is in effect numpy.lib.stride_tricks.sliding_window_view - path_size = maxiter + 1 - index = jnp.arange(path_size)[:, None] + jnp.arange(maxcor)[None, :] - s_j = s_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1) - z_j = z_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1) - rng_keys = jax.random.split(rng_key, path_size) - elbo, beta, gamma = jax.vmap(path_finder_body_fn)( - rng_keys, s_j, z_j, alpha, position, grad_position - ) - elbo = jnp.where( - (jnp.arange(path_size) < (status.iter_num)) & jnp.isfinite(elbo), - elbo, - -jnp.inf, + + def _pathfinder_body_fn(args): + beta, gamma = lbfgs_inverse_hessian_factors(S_l, Z_l, alpha_l) + phi, logq = bfgs_sample( + rng_key=rng_key, + num_samples=num_elbo_draws, + position=theta, + grad_position=theta_grad, + alpha=alpha_l, + beta=beta, + gamma=gamma, + sparse=sparse, + ) + logp = jax.vmap(logdensity_fn)(phi) + logp = jnp.where(~jnp.isfinite(logp), jnp.inf, logp) + elbo = jnp.mean(logp - logq) + return elbo, beta, gamma + + def _nan_op(args): + elbo = jnp.asarray(jnp.nan, dtype=jnp.float64) + beta = jnp.ones((param_dims, 2 * maxcor), dtype=jnp.float64) * jnp.nan + gamma = jnp.ones((2 * maxcor, 2 * maxcor), dtype=jnp.float64) * jnp.nan + return elbo, beta, gamma + + args = (rng_key, S_l, Z_l, alpha_l, theta, theta_grad) + return jax.lax.cond(update_mask_l, _pathfinder_body_fn, _nan_op, args) + + rng_keys = jax.random.split(rng_key, maxiter) + elbo, beta, gamma = jax.vmap(pathfinder_body_fn)( + rng_keys, update_mask, S, Z, alpha, position, grad_position ) unravel_fn_mapped = jax.vmap(unravel_fn) - pathfinder_result = PathfinderState( - elbo, + res_argmax = ( unravel_fn_mapped(position), unravel_fn_mapped(grad_position), alpha, beta, gamma, + elbo, ) - max_elbo_idx = jnp.argmax(elbo) - return jax.tree.map(lambda x: x[max_elbo_idx], pathfinder_result), PathfinderInfo( - pathfinder_result + max_elbo_idx = jnp.nanargmax(elbo) + + # keep all of PathfinderInfo, including 
masked info to make approximate jittable. + return SinglePathfinderState( + *( + initial_position, + *jax.tree.map(lambda x: x.at[max_elbo_idx].get(), res_argmax), + sparse, + ), + ), SinglePathfinderInfo( + SinglePathfinderState(initial_position, *res_argmax, sparse), update_mask ) def sample( rng_key: PRNGKey, - state: PathfinderState, - num_samples: Union[int, tuple[()], tuple[int]] = (), -) -> ArrayTree: + state: SinglePathfinderState, + num_samples_per_path: int = 1000, +) -> tuple[Array, Array]: """Draw from the Pathfinder approximation of the target distribution. Parameters ---------- - rng_key - PRNG key - state + rng_key : PRNGKey + JAX PRNG key + state : PathfinderState PathfinderState containing information for sampling - num_samples + num_samples_per_path : int Number of samples to draw Returns ------- - Samples drawn from the approximate Pathfinder distribution + tuple + Samples drawn from the approximate Pathfinder distribution and their log probabilities + Raises + ------ + ValueError + If the state contains invalid values (NaN or Inf) or is not properly initialized """ + position_flatten, unravel_fn = ravel_pytree(state.position) grad_position_flatten, _ = ravel_pytree(state.grad_position) - phi, logq = bfgs_sample( + psi, logq = bfgs_sample( rng_key, - num_samples, + num_samples_per_path, position_flatten, grad_position_flatten, state.alpha, state.beta, state.gamma, + state.sparse, ) - if num_samples == (): - return unravel_fn(phi), logq + return jax.vmap(unravel_fn)(psi), logq + + +def logp(logdensity_fn: Callable, samples: Array) -> Array: + return logdensity_fn(samples) + + +def importance_sampling( + rng_key: PRNGKey, + samples: Array, + logP: Array, + logQ: Array, + num_paths: int, + num_samples: int = 1000, + method: int = ImpSamplingMethod.PSIS, +) -> ImportanceSamplingState: + """Pareto Smoothed Importance Resampling (PSIR) + + This implements the Pareto Smooth Importance Resampling (PSIR) method, as described in + Algorithm 5 of Zhang et al. (2022). The PSIR follows a similar approach to Algorithm 1 + PSIS diagnostic from Yao et al., (2018). However, before computing the importance ratio r_s, + the logP and logQ are adjusted to account for the number of multiple estimators (or paths). + The process involves resampling from the original sample with replacement, with probabilities + proportional to the computed importance weights from PSIS. + + Parameters + ---------- + rng_key : PRNGKey + JAX random key for sampling + logdensity_fn : Callable + log density function of the target distribution + samples : Array + samples from proposal distribution, shape (L, M, N) + logQ : Array + log probability values of proposal distribution, shape (L, M) + num_paths : int + number of paths used in multi-pathfinder + num_samples : int + number of draws to return where num_samples <= samples.shape[0] + method : str, optional + importance sampling method to use. Options are "psis" (default), "psir", "identity", None. + Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable + results than Pareto Smoothed Importance Resampling (psir). identity applies the log + importance weights directly without resampling. None applies no importance sampling + weights and returns the samples as is. 
+ + Returns + ------- + ImportanceSamplingState + importance sampled draws and other info based on the specified method + """ + + # TODO: make this function jax.jit compatible + + if num_paths != samples.shape[0]: + raise ValueError( + f"num_paths ({num_paths}) must be equal to the number of rows in samples ({samples.shape[0]})" + ) + + if samples.ndim != 3: + raise ValueError(f"Samples must be a 3D array, got shape {samples.shape}") + + batch_shape, param_shape = samples.shape[:2], samples.shape[2:] + batch_size = np.prod(batch_shape).item() + + if method == ImpSamplingMethod.NONE: + logger.warning("No importance sampling method specified, returning raw samples") + return ImportanceSamplingState( + num_paths=num_paths, samples=samples, method=method + ) else: - return jax.vmap(unravel_fn)(phi), logq + log_I = jnp.log(num_paths) + logP = logP.ravel() - log_I + logP = jnp.where(~jnp.isfinite(logP), -jnp.inf, logP) + samples = samples.reshape(batch_size, *param_shape) + + logQ = logQ.ravel() - log_I + + logiw = logP - logQ + + def psislw_wrapper(logiw_array): + def psislw(logiw_array): + result_logiw, result_k = az.psislw(np.array(logiw_array)) + return np.array(result_logiw), np.array(result_k) + + return jax.pure_callback( + psislw, + (jnp.zeros_like(logiw_array), jnp.zeros((), dtype=jnp.float64)), + logiw_array, + ) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="overflow encountered in exp" + ) + match method: + case ImpSamplingMethod.PSIS: + replace = False + logiw, pareto_k = psislw_wrapper(logiw) + case ImpSamplingMethod.PSIR: + replace = True + logiw, pareto_k = psislw_wrapper(logiw) + case ImpSamplingMethod.IDENTITY: + replace = False + pareto_k = None + logger.info("Identity importance sampling (no smoothing)") + case _: + raise ValueError(f"Invalid importance sampling method: {method}. ") + + # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI. + # Pareto k may not be a good diagnostic for Pathfinder. + p = jnp.exp(logiw - logsumexp(logiw)) + + non_infinite = ~jnp.isfinite(p) + + def handle_non_infinite(p, non_infinite): + logger.warning( + "Detected NaN or Inf values in importance weights. " + "This may indicate numerical instability in the target or proposal distributions." + ) + p = jnp.where(non_infinite, 0.0, p) + return p / jnp.sum(p) + + p = jax.lax.cond( + jnp.any(non_infinite), handle_non_infinite, lambda p, _: p, p, non_infinite + ) + + try: + resampled = jax.random.choice( + rng_key, + samples, + shape=(num_samples,), + replace=replace, + p=p, + axis=0, + ) + return ImportanceSamplingState( + num_paths=num_paths, + samples=resampled, + pareto_k=pareto_k, + method=method, + ) + except ValueError as e1: + if "Cannot take a larger sample" in str(e1): + num_nonzero = jnp.sum(p > 0) + logger.warning( + f"Not enough valid samples: {num_nonzero} available out of {num_samples} requested. " + f"Switching to psir importance sampling with replacement." + ) + try: + resampled = jax.random.choice( + rng_key, + samples, + shape=(num_samples,), + replace=True, + p=p, + axis=0, + ) + return ImportanceSamplingState( + num_paths=num_paths, + samples=resampled, + pareto_k=pareto_k, + method=ImpSamplingMethod.PSIR, + ) + except ValueError as e2: + logger.error(f"Importance sampling failed: {str(e2)}") + raise ValueError( + "Importance sampling failed for both with and without replacement. 
" + ) + else: + raise e1 + + +def pathfinder( + rng_key: PRNGKey, + logdensity_fn: Callable, + initial_position: Array, + num_samples_per_path: int = 1000, + num_elbo_draws: int = 15, + maxcor: int = 6, + maxiter: int = 100, + maxls: int = 100, + gtol: float = 1e-08, + ftol: float = 1e-05, + epsilon: float = 1e-8, + **lbfgs_kwargs, +) -> tuple[Array, Array]: + approx_key, sample_key = jax.random.split(rng_key, 2) + + state, _ = approximate( + rng_key=approx_key, + logdensity_fn=logdensity_fn, + initial_position=initial_position, + num_elbo_draws=num_elbo_draws, + maxcor=maxcor, + maxiter=maxiter, + maxls=maxls, + gtol=gtol, + ftol=ftol, + epsilon=epsilon, + **lbfgs_kwargs, + ) + + samples, logq = sample( + rng_key=sample_key, state=state, num_samples_per_path=num_samples_per_path + ) + + return samples, logq + + +def _shape_handler_for_parallel( + batch_shape: tuple, + parallel_method: Literal["parallel", "vectorize"] = "parallel", +): + batch_size = np.prod(batch_shape).item() + if parallel_method == "vectorize": + return batch_shape + elif parallel_method == "parallel": + num_devices = len(jax.devices()) + if num_devices >= batch_size: + return batch_shape + elif batch_size % num_devices == 0: + return (num_devices, batch_size // num_devices) + else: + raise ValueError( + f"The batch size must be divisible by the number of devices ({num_devices}). Received batch size ({batch_size}) from the batch shape ({batch_shape})." + ) + else: + raise ValueError(f"Unsupported parallel method: {parallel_method}") + + +def multi_pathfinder( + rng_key: PRNGKey, + logdensity_fn: Callable, + base_position: Optional[Array] = None, + jitter_amount: Optional[float] = None, + num_paths: int = 4, + num_samples: int = 1000, + num_samples_per_path: int = 1000, + num_elbo_draws: int = 15, + maxcor: int = 6, + maxiter: int = 100, + maxls: int = 100, + gtol: float = 1e-08, + ftol: float = 1e-05, + epsilon: float = 1e-8, + importance_sampling_method: Literal["psis", "psir", "identity"] = "psis", + initial_position: Optional[Array] = None, + parallel_method: Literal["parallel", "vectorize"] = "parallel", + **lbfgs_kwargs, +) -> ImportanceSamplingState: + """Run the multi-pathfinder algorithm. + + This function runs multiple instances of Pathfinder in parallel and combines + the results using importance sampling. 
+ + Parameters + ---------- + rng_key : PRNGKey + JAX PRNG key + logdensity_fn : Callable + log density function of the target distribution + base_position : Array, optional + Unflattened base position to jitter around + jitter_amount : float, optional + scale of the uniform jitter to add to the base position + num_paths : int + number of parallel Pathfinder instances to run + num_samples : int + number of final draws to return after importance sampling + num_samples_per_path : int + number of samples to draw from each Pathfinder instance + num_elbo_draws : int + number of samples to draw for ELBO estimation + maxcor : int + maximum number of metric corrections for L-BFGS + maxiter : int + maximum number of iterations for L-BFGS + maxls : int + maximum number of line search steps for L-BFGS + gtol : float + gradient tolerance for L-BFGS + ftol : float + function value tolerance for L-BFGS + importance_sampling_method : str + importance sampling method to use + initial_position : Array, optional + pre-specified initial positions to use + + Returns + ------- + ImportanceSamplingState + Result of importance sampling + + Raises + ------ + ValueError + If the inputs are inconsistent or insufficient + """ + + path_batch_shape = _shape_handler_for_parallel( + (num_paths,), parallel_method=parallel_method + ) + logp_batch_shape = _shape_handler_for_parallel( + (num_paths, num_samples_per_path), parallel_method=parallel_method + ) + + # Split the random key for initialization and sampling + init_key, rng_path_key, choice_key = jax.random.split(rng_key, 3) + + # Create the multi-pathfinder algorithm + multi_pathfinder = as_top_level_api( + logdensity_fn=logdensity_fn, + num_paths=num_paths, + num_samples=num_samples, + num_samples_per_path=num_samples_per_path, + num_elbo_draws=num_elbo_draws, + maxcor=maxcor, + maxiter=maxiter, + maxls=maxls, + gtol=gtol, + ftol=ftol, + epsilon=epsilon, + importance_sampling_method=importance_sampling_method, + **lbfgs_kwargs, + ) + + # Initialize the positions + initial_positions = multi_pathfinder.init( + rng_key=init_key, + base_position=base_position, + jitter_amount=jitter_amount, + initial_position=initial_position, + ) + + param_shape = initial_positions.shape[1:] + + path_keys = jax.random.split(rng_path_key, path_batch_shape) + + path_batch_size = np.prod(path_batch_shape).item() + + if parallel_method == "vectorize": + pathfinder_pmap = jax.jit(jax.vmap(multi_pathfinder.pathfinder)) + else: # parallel_method == "parallel" + num_devices = len(jax.devices()) + if num_devices >= path_batch_size: + pathfinder_pmap = jax.pmap(multi_pathfinder.pathfinder) + elif path_batch_size % num_devices == 0: + initial_positions = initial_positions.reshape( + (*path_batch_shape, *param_shape) + ) + initial_positions = jax.pmap(jax.vmap(lambda x: x))(initial_positions) + + path_keys = jax.pmap(jax.vmap(lambda r: r))(path_keys) + + pathfinder_pmap = jax.pmap(jax.vmap(multi_pathfinder.pathfinder)) + else: + raise ValueError( + f"The batch size must be divisible by the number of devices ({num_devices}). Received batch size ({path_batch_size}) from the batch shape ({path_batch_shape})." 
+ ) + + # Run Pathfinder on each path + samples, logq = pathfinder_pmap( + rng_key=path_keys, initial_position=initial_positions + ) + + logp_batch_size = np.prod(logp_batch_shape).item() + + if parallel_method == "vectorize": + logp_pmap = jax.jit(jax.vmap(multi_pathfinder.logp)) + samples = samples.reshape((-1, *param_shape)) + else: # parallel_method == "parallel" + if num_devices >= logp_batch_size: + logp_pmap = jax.pmap(multi_pathfinder.logp) + elif logp_batch_size % num_devices == 0: + new_batch_shape = _shape_handler_for_parallel( + logp_batch_shape, + ) + + samples = samples.reshape((*new_batch_shape, *param_shape)) + samples = jax.pmap(jax.vmap(lambda x: x))(samples) + logp_pmap = jax.pmap(jax.vmap(multi_pathfinder.logp)) + else: + raise ValueError( + f"The batch size must be divisible by the number of devices ({num_devices}). Received batch size ({logp_batch_size}) from the batch shape ({logp_batch_shape})." + ) + + logp = logp_pmap(samples) + samples = samples.reshape((num_paths, num_samples_per_path, *param_shape)) + + # Perform importance sampling + result = jax.jit(multi_pathfinder.importance_sampling)( + rng_key=choice_key, + samples=samples, + logP=logp, + logQ=logq, + ) + + return result -def as_top_level_api(logdensity_fn: Callable) -> PathFinderAlgorithm: +def as_top_level_api( + logdensity_fn: Callable, + num_paths: int = 4, + num_samples: int = 1000, + num_samples_per_path: int = 1000, + num_elbo_draws: int = 15, + maxcor: int = 6, + maxiter: int = 100, + maxls: int = 100, + gtol: float = 1e-08, + ftol: float = 1e-05, + epsilon: float = 1e-8, + importance_sampling_method: Literal["psis", "psir", "identity"] | None = "psis", + **lbfgs_kwargs, +) -> MultiPathfinderAlgorithm: """Implements the (basic) user interface for the pathfinder kernel. Pathfinder locates normal approximations to the target density along a @@ -266,17 +843,74 @@ def as_top_level_api(logdensity_fn: Callable) -> PathFinderAlgorithm: """ - def approximate_fn( - rng_key: PRNGKey, - position: ArrayLikeTree, - num_samples: int = 200, - **lbfgs_parameters, - ): - return approximate( - rng_key, logdensity_fn, position, num_samples, **lbfgs_parameters + valid_isamp = {"psis", "psir", "identity", None} + if importance_sampling_method not in valid_isamp: + raise ValueError( + f"Invalid importance sampling method: {importance_sampling_method}. 
" ) - def sample_fn(rng_key: PRNGKey, state: PathfinderState, num_samples: int): - return sample(rng_key, state, num_samples) + # jittable + def init_fn( + rng_key: Optional[PRNGKey] = None, + base_position: Optional[Array] = None, + jitter_amount: Optional[float] = None, + initial_position: Optional[Array] = None, + ) -> Array: + return init( + rng_key=rng_key, + base_position=base_position, + num_paths=num_paths, + jitter_amount=jitter_amount, + initial_position=initial_position, + ) + + # jittable + def pathfinder_fn(rng_key: PRNGKey, initial_position: Array): + return pathfinder( + rng_key=rng_key, + logdensity_fn=logdensity_fn, + initial_position=initial_position, + num_samples_per_path=num_samples_per_path, + num_elbo_draws=num_elbo_draws, + maxcor=maxcor, + maxiter=maxiter, + maxls=maxls, + gtol=gtol, + ftol=ftol, + epsilon=epsilon, + **lbfgs_kwargs, + ) - return PathFinderAlgorithm(approximate_fn, sample_fn) + # jittable + def logp_fn( + samples: Array, + ): + return logp(logdensity_fn=logdensity_fn, samples=samples) + + # jittable + def importance_sampling_fn( + rng_key: PRNGKey, + samples: Array, + logP: Array, + logQ: Array, + ) -> ImportanceSamplingState: + # Inside the function, convert the enum back to the string if needed + method = { + "psis": ImpSamplingMethod.PSIS, + "psir": ImpSamplingMethod.PSIR, + "identity": ImpSamplingMethod.IDENTITY, + None: ImpSamplingMethod.NONE, + }.get(importance_sampling_method, ImpSamplingMethod.PSIS) + return importance_sampling( + rng_key=rng_key, + samples=samples, + logP=logP, + logQ=logQ, + num_paths=num_paths, + num_samples=num_samples, + method=method, + ) + + return MultiPathfinderAlgorithm( + init_fn, pathfinder_fn, logp_fn, importance_sampling_fn + ) diff --git a/requirements.txt b/requirements.txt index 4cdf22942..2acb3912a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -e ./ +arviz chex>=0.1.83 pre-commit pytest diff --git a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py index a7549842f..789793ee1 100644 --- a/tests/optimizers/test_optimizers.py +++ b/tests/optimizers/test_optimizers.py @@ -1,17 +1,16 @@ """Test optimizers.""" + import functools import chex import jax import jax.numpy as jnp -import jax.scipy.stats as stats import numpy as np from absl.testing import absltest, parameterized -from jax.flatten_util import ravel_pytree -from jaxopt._src.lbfgs import compute_gamma, inv_hessian_product from blackjax.optimizers.dual_averaging import dual_averaging from blackjax.optimizers.lbfgs import ( + lbfgs_diff_history_matrix, lbfgs_inverse_hessian_factors, lbfgs_inverse_hessian_formula_1, lbfgs_inverse_hessian_formula_2, @@ -20,6 +19,33 @@ ) +def compute_inverse_hessian_1_and_2(history, maxcor): + not_converged_mask = jnp.logical_not(history.converged.at[1:].get()) + + # jax.jit would not work with truncated history, so we keep the full history + position = history.x + grad_position = history.g + + alpha, s, z, update_mask = lbfgs_recover_alpha( + position, grad_position, not_converged_mask + ) + + s = jnp.diff(position, axis=0) + z = jnp.diff(grad_position, axis=0) + S = lbfgs_diff_history_matrix(s, update_mask, maxcor) + Z = lbfgs_diff_history_matrix(z, update_mask, maxcor) + + position = position.at[1:].get() + grad_position = grad_position.at[1:].get() + + beta, gamma = jax.vmap(lbfgs_inverse_hessian_factors)(S, Z, alpha) + + inv_hess_1 = jax.vmap(lbfgs_inverse_hessian_formula_1)(alpha, beta, gamma) + inv_hess_2 = jax.vmap(lbfgs_inverse_hessian_formula_2)(alpha, beta, gamma) + + 
return inv_hess_1, inv_hess_2 + + class OptimizerTest(chex.TestCase): def setUp(self): super().setUp() @@ -51,107 +77,90 @@ def test_dual_averaging(self): @chex.all_variants(with_pmap=False) @parameterized.parameters( - [(5, 10), (10, 2), (10, 20)], + [(4, 12), (12, 12), (4, 20), (20, 20)], ) - def test_minimize_lbfgs(self, maxiter, maxcor): + def test_minimize_lbfgs(self, maxcor, n): """Test if dot product between approximate inverse hessian and gradient is the same between two loop recursion algorthm of LBFGS and formulas of the pathfinder paper""" - def regression_logprob(log_scale, coefs, preds, x): - """Linear regression""" - scale = jnp.exp(log_scale) - scale_prior = stats.expon.logpdf(scale, 0, 1) + log_scale - coefs_prior = stats.norm.logpdf(coefs, 0, 5) - y = jnp.dot(x, coefs) - logpdf = stats.norm.logpdf(preds, y, scale) - return sum(x.sum() for x in [scale_prior, coefs_prior, logpdf]) - - def regression_model(key): - init_key0, init_key1 = jax.random.split(key, 2) - x_data = jax.random.normal(init_key0, shape=(10_000, 1)) - y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) - - logposterior_fn_ = functools.partial( - regression_logprob, x=x_data, preds=y_data - ) - logposterior_fn = lambda x: logposterior_fn_(**x) - - return logposterior_fn - - fn = regression_model(self.key) - b0 = {"log_scale": 0.0, "coefs": 2.0} - b0_flatten, unravel_fn = ravel_pytree(b0) - objective_fn = lambda x: -fn(unravel_fn(x)) - (_, status), history = self.variant( - functools.partial( - minimize_lbfgs, objective_fn, maxiter=maxiter, maxcor=maxcor - ) - )(b0_flatten) - history = jax.tree.map(lambda x: x[: status.iter_num + 1], history) - - # Test recover alpha - S = jnp.diff(history.x, axis=0) - Z = jnp.diff(history.g, axis=0) - alpha0 = history.alpha[0] - - def scan_fn(alpha, val): - alpha_l, mask_l = lbfgs_recover_alpha(alpha, *val) - return alpha_l, (alpha_l, mask_l) - - _, (alpha, mask) = jax.lax.scan(scan_fn, alpha0, (S, Z)) - np.testing.assert_array_almost_equal(alpha, history.alpha[1:]) - np.testing.assert_array_equal(mask, history.update_mask[1:]) - - # Test inverse hessian product - S_partial = S[-maxcor:].T - Z_partial = Z[-maxcor:].T - alpha = history.alpha[-1] - - beta, gamma = lbfgs_inverse_hessian_factors(S_partial, Z_partial, alpha) - inv_hess_1 = lbfgs_inverse_hessian_formula_1(alpha, beta, gamma) - inv_hess_2 = lbfgs_inverse_hessian_formula_2(alpha, beta, gamma) - - gamma = compute_gamma(S_partial, Z_partial, -1) - pk = inv_hessian_product( - -history.g[-1], - status.s_history, - status.y_history, - status.rho_history, - gamma, - status.iter_num % maxcor, + def quadratic( + x: np.ndarray, A: np.ndarray, b: np.ndarray, c: float = 0.0 + ) -> float: + """ + Quadratic function: f(x) = 0.5 * x^T * A * x - b^T * x + c + + Parameters: + x: Input vector of shape (n,) + A: Symmetric positive definite matrix of shape (n, n) + b: Vector of shape (n,) + c: Scalar constant + + Returns: + Function value at x + """ + return 0.5 * x.dot(A).dot(x) - b.dot(x) + c + + def create_spd_matrix(rng_key, n): + """Create a symmetric positive definite matrix of shape (n, n).""" + rand = jax.random.normal(rng_key, (n, n)) + A = jnp.dot(rand, rand.T) + n * jnp.eye(n) + assert np.all(jnp.linalg.eigh(A)[0] > 0) + return A + + spd_key, b_key, init_key = jax.random.split(self.key, 3) + + A = create_spd_matrix(spd_key, n) + b = jax.random.normal(b_key, (n,)) + + # initial guess + x0 = jax.random.normal(init_key, shape=(n,)) + + # run the optimizer + quadratic_fn = functools.partial(quadratic, A=A, 
b=b) + (result, (last_lbfgs_state, last_ls_state)), history = self.variant( + functools.partial(minimize_lbfgs, quadratic_fn, maxcor=maxcor) + )(x0) + + # check if the result is close to the expected minimum + gt_minimum = np.linalg.solve(A, b) + np.testing.assert_allclose( + result, + gt_minimum, + atol=1e-2, + err_msg=f"Expected {gt_minimum}, got {result}", ) - np.testing.assert_allclose(pk, -inv_hess_1 @ history.g[-1], atol=1e-3) - np.testing.assert_allclose(pk, -inv_hess_2 @ history.g[-1], atol=1e-3) + gt_inverse_hessian = jnp.linalg.inv(A) + inv_hess_1, inv_hess_2 = compute_inverse_hessian_1_and_2(history, maxcor=maxcor) + + np.testing.assert_allclose(gt_inverse_hessian, inv_hess_1[-1], atol=1e-1) + + np.testing.assert_allclose(gt_inverse_hessian, inv_hess_2[-1], atol=1e-1) @chex.all_variants(with_pmap=False) def test_recover_diag_inv_hess(self): "Compare inverse Hessian estimation from LBFGS with known groundtruth." nd = 5 + maxcor = 6 + mean = np.linspace(3.0, 50.0, nd) cov = np.diag(np.linspace(1.0, 10.0, nd)) def loss_fn(x): - return -stats.multivariate_normal.logpdf(x, mean, cov) - - (result, status), history = self.variant( - functools.partial(minimize_lbfgs, loss_fn, maxiter=50) - )(np.zeros(nd)) - history = jax.tree.map(lambda x: x[: status.iter_num + 1], history) + return -jax.scipy.stats.multivariate_normal.logpdf(x, mean, cov) - np.testing.assert_allclose(result, mean, rtol=0.01) + x0 = jnp.zeros(nd) + (result, (last_lbfgs_state, last_ls_state)), history = self.variant( + functools.partial(minimize_lbfgs, loss_fn, maxcor=maxcor) + )(x0) - S_partial = jnp.diff(history.x, axis=0)[-10:].T - Z_partial = jnp.diff(history.g, axis=0)[-10:].T - alpha = history.alpha[-1] + np.testing.assert_allclose(result, mean, rtol=0.05) - beta, gamma = lbfgs_inverse_hessian_factors(S_partial, Z_partial, alpha) - inv_hess_1 = lbfgs_inverse_hessian_formula_1(alpha, beta, gamma) - inv_hess_2 = lbfgs_inverse_hessian_formula_2(alpha, beta, gamma) + inv_hess_1, inv_hess_2 = compute_inverse_hessian_1_and_2(history, maxcor=maxcor) - np.testing.assert_allclose(np.diag(inv_hess_1), np.diag(cov), rtol=0.01) - np.testing.assert_allclose(inv_hess_1, inv_hess_2, rtol=0.01) + np.testing.assert_allclose(np.diag(inv_hess_1[-1]), np.diag(cov), rtol=0.05) + np.testing.assert_allclose(inv_hess_1[-1], inv_hess_2[-1], rtol=0.05) if __name__ == "__main__": diff --git a/tests/optimizers/test_pathfinder.py b/tests/optimizers/test_pathfinder.py index b9b9c69be..173748244 100644 --- a/tests/optimizers/test_pathfinder.py +++ b/tests/optimizers/test_pathfinder.py @@ -1,14 +1,14 @@ """Test the pathfinder algorithm.""" + import functools import chex import jax import jax.numpy as jnp -import jax.scipy.stats as stats +import numpy as np from absl.testing import absltest, parameterized import blackjax -from blackjax.optimizers.lbfgs import bfgs_sample class PathfinderTest(chex.TestCase): @@ -16,7 +16,7 @@ def setUp(self): super().setUp() self.key = jax.random.key(1) - @chex.all_variants(without_device=False, with_pmap=False) + @chex.all_variants(with_pmap=False) @parameterized.parameters( [(1,), (2,)], ) @@ -36,16 +36,20 @@ def logp_posterior_conjugate_normal_model( + n * true_prec @ observed.mean(0)[:, None] ) )[:, 0] - return stats.multivariate_normal.logpdf(x, posterior_mu, posterior_cov) + return jax.scipy.stats.multivariate_normal.logpdf( + x, posterior_mu, posterior_cov + ) def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov): logp = 0.0 - logp += stats.multivariate_normal.logpdf(x, prior_mu, 
prior_prec) - logp += stats.multivariate_normal.logpdf(observed, x, true_cov).sum() + logp += jax.scipy.stats.multivariate_normal.logpdf(x, prior_mu, prior_prec) + logp += jax.scipy.stats.multivariate_normal.logpdf( + observed, x, true_cov + ).sum() return logp - rng_key_chol, rng_key_observed, rng_key_pathfinder = jax.random.split( - self.key, 3 + rng_key_chol, rng_key_observed, rng_key_path, rng_key_choice = jax.random.split( + self.key, 4 ) L = jnp.tril(jax.random.normal(rng_key_chol, (ndim, ndim))) @@ -69,27 +73,94 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov): ) x0 = jnp.ones(ndim) - pathfinder = blackjax.pathfinder(logp_model) - out, _ = self.variant(pathfinder.approximate)(rng_key_pathfinder, x0) - - sim_p, log_p = bfgs_sample( - rng_key_pathfinder, - 10_000, - out.position, - out.grad_position, - out.alpha, - out.beta, - out.gamma, + num_paths = 4 + + pathfinder = blackjax.pathfinder(logp_model, num_paths=num_paths) + initial_positions = self.variant(pathfinder.init)(initial_position=x0) + + path_keys = jax.random.split(rng_key_path, num_paths) + + samples, logq = self.variant(jax.vmap(pathfinder.pathfinder))( + path_keys, initial_positions ) - log_q = logp_posterior_conjugate_normal_model( - sim_p, observed, prior_mu, prior_prec, true_prec + samples = samples.reshape((-1, ndim)) + logq = logq.ravel() + + logp = logp_posterior_conjugate_normal_model( + samples, observed, prior_mu, prior_prec, true_prec ) - kl = (log_p - log_q).mean() + kl = (logp - logq).mean() # TODO(junpenglao): Make this test more robust. self.assertAlmostEqual(kl, 0.0, delta=2.5) + result = blackjax.multi_pathfinder( + rng_key=self.key, + logdensity_fn=logp_model, + initial_position=x0, + # jitter_amount=12.0, + num_paths=num_paths, + parallel_method="vectorize", + ) + + self.assertAlmostEqual(result.samples.mean(), 0.0, delta=2.5) + + @chex.all_variants(with_pmap=False) + @parameterized.parameters( + [(2,), (6,)], + ) + def test_recover_posterior_eight_schools(self, maxcor): + J = 8 + y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) + sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) + + def eight_schools_log_density(y, sigma, mu, tau, theta): + logp = 0.0 + # Prior for mu + logp += jax.scipy.stats.norm.logpdf(mu, loc=0.0, scale=10.0) + # Prior for tau + logp += jax.scipy.stats.gamma.logpdf(tau, 5, 1) + # Prior for theta + logp += jax.scipy.stats.norm.logpdf(theta, loc=0.0, scale=1.0).sum() + # Likelihood + logp += jax.scipy.stats.norm.logpdf( + y, loc=mu + tau * theta, scale=sigma + ).sum() + return logp + + def logdensity_fn(param): + def inner(param): + mu, tau, *theta = param + mu = jnp.atleast_1d(mu) + tau = jnp.atleast_1d(tau) + theta = jnp.array(theta) + return eight_schools_log_density(y, sigma, mu, tau, theta) + + return inner(param).squeeze() + + mu_prior = jnp.array([0.0]) + tau_prior = jnp.array([5.0]) + theta_prior = jnp.array([0.0] * J) + base_position = jnp.concatenate([mu_prior, tau_prior, theta_prior]) + + mp = functools.partial( + blackjax.multi_pathfinder, + logdensity_fn=logdensity_fn, + num_paths=20, + maxcor=maxcor, + parallel_method="vectorize", + ) + + result = self.variant(mp)( + rng_key=self.key, + base_position=base_position, + jitter_amount=12.0, + ) + + np.testing.assert_allclose(result.samples[:, 0].mean(), 5.0, atol=1.6) + np.testing.assert_allclose(result.samples[:, 1].mean(), 4.15, atol=1.5) + if __name__ == "__main__": absltest.main()
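Note on the new quadratic L-BFGS test above (editorial aside, not part of the patch): for f(x) = 0.5 * x^T A x - b^T x + c with A symmetric positive definite, the gradient is A x - b and the Hessian is A, so the expected minimizer is x* = A^{-1} b (computed via np.linalg.solve(A, b)) and the expected inverse Hessian is A^{-1}; those are exactly the quantities the atol=1e-2 and atol=1e-1 assertions compare against.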
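Usage sketch (illustrative, not part of the patch): a minimal end-to-end walk through the staged multi-path API wired up above, mirroring how tests/optimizers/test_pathfinder.py drives it. The toy standard-normal target, the key names, and the printed shape are assumptions for the sake of the sketch; only the blackjax.pathfinder / blackjax.multi_pathfinder entry points and their arguments come from this diff.

    import jax
    import jax.numpy as jnp

    import blackjax


    def logdensity_fn(x):
        # Toy standard-normal target; stands in for a real model log-density.
        return -0.5 * jnp.sum(x**2)


    num_paths = 4
    ndim = 2
    rng_key = jax.random.key(0)
    path_key, is_key = jax.random.split(rng_key)

    # Staged API: init -> per-path pathfinder -> logp -> importance sampling.
    algo = blackjax.pathfinder(logdensity_fn, num_paths=num_paths)

    # One starting point per path; base_position + jitter_amount is the
    # alternative way to seed the paths.
    initial_positions = algo.init(initial_position=jnp.ones(ndim))

    # Run the single-path pathfinder over all paths with vmap.
    path_keys = jax.random.split(path_key, num_paths)
    samples, logq = jax.vmap(algo.pathfinder)(path_keys, initial_positions)

    samples = samples.reshape((-1, ndim))
    logq = logq.ravel()

    # Log-density of the draws under the target.
    logp = algo.logp(samples)

    # Combine draws from all paths via importance sampling.
    resampled = algo.importance_sampling(is_key, samples, logp, logq)

    # One-shot driver that bundles the stages above behind a single call.
    result = blackjax.multi_pathfinder(
        rng_key=rng_key,
        logdensity_fn=logdensity_fn,
        initial_position=jnp.ones(ndim),
        num_paths=num_paths,
        parallel_method="vectorize",
    )
    print(result.samples.shape)

The staged form is what jit- or vmap-based callers compose themselves; blackjax.multi_pathfinder wraps the same steps, with parallel_method (e.g. "vectorize", as used in the tests) appearing to control how the individual paths are executed.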