
Move to SciPy optimizers, made differentiable through jax.lax.custom_root #4

Open · phinate opened this issue Mar 1, 2020 · 4 comments
Labels: dependencies, enhancement, help wanted

Comments

@phinate (Member) commented Mar 1, 2020

The current implementation of the maximum likelihood fits uses gradient descent to converge to the optimal parameter values. In principle, switching to SciPy optimizers is preferable, both for comparison with the optimization implementation in pyhf and for more robust minimization.

To do this, one needs to differentiate through the optimizer using implicit differentiation. It's probably possible to do this using fax like we do now, but this issue on the jax repo discusses the possibility of wrapping SciPy optimizers using jax.lax.custom_root, which would remove a dependency and (probably) make for simpler code.
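For concreteness, here is a minimal sketch (not pyhf or fax code) of the mechanism being discussed: jax.lax.custom_root makes the output of a root-finding solve differentiable via the implicit function theorem, applied here to the stationarity condition of a toy objective. The forward solver is a plain Newton iteration written in JAX purely for illustration; dropping a SciPy optimizer in its place would additionally require calling non-JAX code from inside the solve, a limitation that comes up later in this thread.

```python
import jax
import jax.numpy as jnp

def objective(x, theta):
    # toy "likelihood": minimised at x = theta
    return (x - theta) ** 2

def fit(theta):
    # stationarity condition whose root is the best-fit x
    residual = lambda x: jax.grad(objective)(x, theta)

    def solve(f, x):
        # stand-in forward solver: a few Newton steps on f(x) = 0
        for _ in range(20):
            x = x - f(x) / jax.grad(f)(x)
        return x

    def tangent_solve(g, y):
        # scalar linear solve g(x) = y, needed for the implicit gradient
        return y / g(1.0)

    return jax.lax.custom_root(residual, 1.0, solve, tangent_solve)

print(fit(3.0))            # ~3.0
print(jax.grad(fit)(3.0))  # ~1.0: the gradient flows through the solver implicitly
```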

@phinate added the enhancement, help wanted, and dependencies labels on Mar 1, 2020
@gehring (Contributor) commented Mar 11, 2020

You can easily achieve this in fax.implicit.twophase.two_phase_solver by specifying a forward_solver. The documentation for that argument seems to have disappeared (probably during our last refactor), but you can look at the signature of the default solver to see how to use it.

We originally implemented this before jax.lax.custom_root was added, but I'm looking at adding a version based on it. There seem to be a few advantages over the current jax.defvjp approach, though this might change once the jax devs reimplement the defvjp API. I believe custom_root supports both forward and backward differentiation, which might be the most significant advantage over a defvjp approach, which only supports backward differentiation.

Otherwise, the two phase approach is essentially a more specific instance of jax.lax.custom_root; they are both a direct application of the implicit function theorem. The two phase approach adds the assumption that the constraints defining the implicit function (i.e., the root f(solution) = 0 in custom_root) correspond to an attractive fixed point (i.e., f(x) := h(x) - x, where two_phase_solver expects to be given h).

The two phase approach leverages the attractive fixed-point assumption to solve for the derivatives by simply iterating (which is guaranteed to converge in this case), while custom_root needs to be given a solver. If you were to define a solver using a simple fixed-point iteration on the residual, you'd be doing exactly the same thing as two_phase_solver. As a side note, two_phase_solver used to support custom solvers for the backward pass, but we never really used it so we dropped it from the API. Adding that back in would be trivial.
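As a hand-rolled illustration of that two-phase idea on a scalar toy problem (this is not fax's implementation, just a sketch of the mechanism): phase one iterates x ← h(x) to the attractive fixed point, and phase two solves for the derivative by iterating the linearised fixed-point equation, which converges for the same reason.

```python
import jax

def h(x, theta):
    # attractive map (|dh/dx| < 1) with fixed point x* = theta
    return 0.5 * x + 0.5 * theta

def two_phase(theta, n_iter=60):
    # phase 1: forward solve for the fixed point by iterating x <- h(x)
    x = 0.0
    for _ in range(n_iter):
        x = h(x, theta)

    # phase 2: iterate the linearised equation u <- (dh/dx) u + dh/dtheta
    dh_dx = jax.grad(h, argnums=0)(x, theta)
    dh_dtheta = jax.grad(h, argnums=1)(x, theta)
    u = 0.0
    for _ in range(n_iter):
        u = dh_dx * u + dh_dtheta
    return x, u

print(two_phase(2.0))  # (~2.0, ~1.0): x* = theta and dx*/dtheta = 1
```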

One thing to note is that, last we checked, it's not obvious how to "extract" intermediate information from the backward pass when using jax.lax.custom_root, e.g., the number of iterations needed to solve for the gradient or the residual error. I haven't had the time to check whether DeepMind's newly released haiku can handle this with their hk.set_state method, but this issue is the reason we've split up the core behavior of two_phase_solver. You'll notice that two_phase_solver is little more than a wrapper over those two calls, which, in our work, we often ended up calling directly in order to have access to that extra info.

I hope this helps clarify the differences and similarities between custom_root and what fax does with two_phase_solver. If you'd like us to look more seriously into a custom_root implementation, it would be useful if you opened an issue in fax for tracking purposes.

@gehring (Contributor) commented Aug 2, 2020

FYI, we're discussing re-implementing fax's implicit differentiation backend as a root-finding problem instead of a fixed-point problem (gehring/fax#19). We plan on implementing our own version of jax.lax.custom_root (most likely with a similar API to theirs) which would support calling non-jax code in the solvers. If you have an opinion on the topic, I'd encourage you to let us know!

@phinate (Member, Author) commented Aug 3, 2020

Thanks for letting us know @gehring! From what you’re saying, it sounds like one could outright use SciPy optimisers/root-finders without writing the corresponding jax code, which would cover the use case in this issue (and another one that I haven’t written on GH yet) — my thoughts are only positive!

@lukasheinrich, do you think this could end up baked straight into pyhf’s optimise module?

(Also thanks for your previous comment @gehring, it clarified some things for me!)

@gehring (Contributor) commented Aug 3, 2020

From what you’re saying, it sounds like one could outright use SciPy optimisers/root-finders without writing the corresponding jax code [...]

I think you are referring only to the solver and, in that case, yes. Note that you can already do that with fax.implicit.two_phase_solve and we've recently added an example in our readme. The only limitation is that you won't be able to wrap the whole thing in jax.jit until jax implements support for XLA's CustomCall (or until someone implements a jax primitive for it that we feel comfortable depending on). This applies to the current two_phase_solve and will also apply to our version of custom_root.

However, just to make sure things are extra clear, I'll mention that you will still need to provide jax code when specifying the function whose root is being solved. This is because we'll still need to use automatic differentiation through jax.vjp in order to solve for the derivatives via implicit differentiation. That said, you have a lot of freedom as to what this jax function is: the only requirement is that the output of the solver is a root of this jax function; otherwise, it doesn't need to be related to anything used by the solver.

For an example of what I mean in the fixed-point case (e.g., using two_phase_solve), you could be looking for the fixed-point of some simple gradient step (e.g., x + grad(f)(x)) but solve for that optimum using ADAM, Newton's method, or even a sequential quadratic solver of some kind. The choice of solver won't matter as long as they all agree on the optimum.
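To make that concrete, here is a rough sketch of the separation being described (illustrative names like fit_x are assumed here; this is not fax's or pyhf's code): the forward solve uses SciPy's BFGS as a black box, while the implicit gradient only needs JAX derivatives of the residual evaluated at the returned optimum.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy import optimize

def objective(x, theta):
    # toy likelihood-style objective; optimum at x = theta
    return jnp.sum((x - theta) ** 2)

def fit_x(theta):
    # forward solve with a non-JAX optimiser (so this part cannot be jit-ed, as noted above)
    res = optimize.minimize(lambda x: float(objective(jnp.asarray(x), theta)),
                            x0=np.zeros(2), method="BFGS")
    return jnp.asarray(res.x)

def implicit_grad_theta(theta):
    # implicit function theorem at the optimum x*:
    #   d x*/d theta = -[d^2 f/dx^2]^{-1} [d^2 f/(dx dtheta)]
    x_star = fit_x(theta)
    hess_xx = jax.hessian(objective, argnums=0)(x_star, theta)
    jac_xtheta = jax.jacobian(jax.grad(objective, argnums=0), argnums=1)(x_star, theta)
    return -jnp.linalg.solve(hess_xx, jac_xtheta)

print(implicit_grad_theta(jnp.array([1.0, 2.0])))  # ~ identity matrix: x* tracks theta
```

The choice of forward solver is irrelevant to the gradient, exactly as described above, as long as its output really is a root of the JAX residual.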
