Option to return auxiliary data from the primal #720
Comments
The issue I see is that most backends want a single-output function. For instance, to get the gradient of |
Fair point! I guess the only way to get the extra_data out of a single call without simply returning it is indeed to extract it through a side effect (e.g. global state, saving to a file, ...). That does seem tricky to do in a nice way for all AD backends.
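To make the side-effect idea concrete, here is a minimal sketch in plain Python rather than Julia. Everything in it is hypothetical and chosen for illustration: `central_difference` stands in for a backend that only accepts a single-output function, and `strip_aux` is an assumed helper name, not part of any AD package.

```python
def central_difference(f, x, h=1e-6):
    # Hypothetical stand-in for an AD backend: it only accepts a function
    # returning a single scalar, and it may call f more than once.
    return (f(x + h) - f(x - h)) / (2 * h)


def objective_with_aux(x):
    # Toy primal: the value to differentiate plus some auxiliary data.
    value = x ** 3
    aux = {"evaluated_at": x}
    return value, aux


def strip_aux(f_with_aux):
    # Wrap f so it returns only the value and stores aux as a side effect.
    cache = {}

    def value_only(x):
        value, aux = f_with_aux(x)
        cache["aux"] = aux  # overwritten on every call
        return value

    return value_only, cache


value_only, cache = strip_aux(objective_with_aux)
derivative = central_difference(value_only, 2.0)
aux = cache["aux"]  # aux from whichever call the backend made last
```

Note that the cache holds the auxiliary data from the last call the backend happened to make, which here is at the perturbed point `x - h` rather than at `x` itself. How many times and at which inputs the function is called depends on the backend, which is part of what makes this hard to do uniformly.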
Idea in passing: maybe we could define a new type of |
FWIW, |
I don't think @niklasschmitz meant returning the primal value? If so, we do of course return it efficiently, making use of Lines 111 to 117 in 13f1859
I think the question was more about side-products of the primal computation, like solver timing statistics, that sort of thing, which is separate from the value of the function itself |
That's right? |
Oh right, my bad, I thought your mention of withgradient only referred to its straightforward use. But yes, I was aware that Zygote supports this. I didn't know about Enzyme, but I suspect these are pretty much the only backends where this is possible (probably along with Mooncake). If we write the auxiliary data to a cache as suggested above, more possibilities might open up though. |
A very common use case is that one wants to not only differentiate an objective, but also get some auxiliary output (intermediate results, the predictions of an ML model, data structures of a PDE solver, etc.)
For example, in JAX there is the has_aux keyword option in jax.value_and_grad, which is actually the most common usage pattern of AD in JAX I have seen. The pattern looks like this (see e.g. the flax docs for a full example in context).

I typically use some hacky workarounds to achieve similar behavior in Julia, but maybe it is common enough to solve it at the interface level?
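For readers unfamiliar with the JAX pattern referred to above, a minimal sketch follows. The toy linear model, the parameter names, and the data are assumptions made up for illustration, not the original snippet from this issue; only `jax.value_and_grad` and its `has_aux` keyword come from the text.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy linear model (an assumption for illustration).
    pred = params["w"] * x + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    # With has_aux=True the function returns (scalar_to_differentiate, aux).
    return loss, pred


params = {"w": jnp.array(2.0), "b": jnp.array(0.5)}
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([3.0, 5.0, 7.0])

# Differentiates only the first output w.r.t. params; aux is passed through.
(loss, pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
```

Here `grads` has the same structure as `params`, while `pred` is returned alongside the loss from the same forward pass, so no second evaluation is needed to recover it.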