-
Notifications
You must be signed in to change notification settings - Fork 118
Multi-Path Pathfinder #783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This update adds the multi pathfinder capabilities, parallel Pathfinder runs, LBFGS optimisation using Optax, and importance sampling. - Multi-Pathfinder: Supports parallel and vectorized sampling strategies, improving scalability. - LBFGS Optimizer: Refactored using Optax for better numerical stability and performance, with enhanced inverse Hessian estimation. - Importance Sampling: Added support for various methods, improving sampling accuracy. - Testing: Expanded tests for both Pathfinder and LBFGS to ensure correctness and stability.
def psislw_wrapper(logiw_array): | ||
def psislw(logiw_array): | ||
result_logiw, result_k = az.psislw(np.array(logiw_array)) | ||
return np.array(result_logiw), np.array(result_k) | ||
|
||
return jax.pure_callback( | ||
psislw, | ||
(jnp.zeros_like(logiw_array), jnp.zeros((), dtype=jnp.float64)), | ||
logiw_array, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add a pure JAX implementation of psislw
(in a seperate PR), as pure_callback will be quite inefficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can. I have found a JAX implementation here https://gist.github.com/adamhaber/0556671340e0daa9e2c6e3fd535cd992 by @adamhaber as a good starting point
Hi @aphc14, could you isolate the change replacing jaxopt with optax in a seperate PR? |
Should this PR include the optax code changes while referencing the separate optax PR? Or should this PR revert the changes from optax to jaxopt? |
Implement Multi-Pathfinder and Enhance Pathfinder and LBFGS Optimizers
High-Level Description
Multi-Pathfinder: Introduces the
multi_pathfinder
function, enabling parallel and vectorized sampling strategies. Initial tests indicate thatmulti_pathfinder
is functioning as expected. Feel free to test it out and provide feedback.LBFGS Optimizer: The optimizer has been refactored to use Optax.
single pathfinder fix: Fixed inaccurate calculations of S Z matrices, phi, log densities.
alpha_recover: Decoupling of alpha recover out of LBFGS optimisation.
Importance Sampling: Added support for various importance sampling methods.
Testing: Expanded and updated tests for both Pathfinder and LBFGS.
Current Status
This is a draft PR that requires some tidying up. There are existing linting errors from mypy that need to be addressed. Your feedback and testing are welcome to help refine these changes.
Checklist
main
commit.pre-commit
to check for any issues.resolves #763, #213, #461, #749, #704
related #465, #387