
Gradients and Hessian-vector products #152

Open
omalled opened this issue Jul 5, 2022 · 1 comment

Comments

omalled commented Jul 5, 2022

I saw in the docs that "the current algorithms should support automatic differentiation," and played around with it a bit. My ultimate goal is to get a Hessian-vector product working (similar to this issue in AlgebraicMultigrid, which was never resolved despite some effort being made). However, I wasn't able to get a Hessian-vector product or even a gradient working in a relatively simple example:

```julia
using Test
import ForwardDiff
import LinearAlgebra
import LinearSolve
import SparseArrays
import Zygote

# Forward-over-reverse: Hv = d/ds ∇f(x + s*v) evaluated at s = 0.
hessian_vector_product(f, x, v) = ForwardDiff.jacobian(s -> Zygote.gradient(f, x + s[1] * v)[1], [0.0])[:]

n = 4
A = randn(n, n)
hessian = A + A'
f(x) = LinearAlgebra.dot(x, A * x)
x = randn(n)
v = randn(n)
hvp1 = hessian_vector_product(f, x, v)
hvp2 = hessian * v
@test hvp1 ≈ hvp2 # the hessian_vector_product plausibly works!

function g(x)
	k = x[1:n + 1]
	B = SparseArrays.spdiagm(0 => k[1:end - 1] + k[2:end], -1 => -k[2:end - 1], 1 => -k[2:end - 1])
	prob = LinearSolve.LinearProblem(B, x[n + 2:end])
	sol = LinearSolve.solve(prob)
	return sum(sol.u)
end
x = randn(2 * n + 1)
v = randn(2 * n + 1)
Zygote.gradient(g, x) # ERROR: Can't differentiate foreigncall expression
hessian_vector_product(g, x, v) # LoadError: MethodError: no method matching SuiteSparse.UMFPACK.UmfpackLU...
```

Is there any chance to get these derivatives working with LinearSolve? It would be really great, especially the Hessian-vector products. Thanks for your help and your great work on this package! Please let me know if there's something I can do to help get this working 😄

ChrisRackauckas (Member) commented

Yeah... that's why it said "should" in the "Roadmap" section of the docs 😅. A lot of cases end up working out since it just differentiates the algorithm, and things like `lu`/`qr`/`svd` have ChainRules defined on them, so a lot of cases "accidentally" work. But what we need to do is lower to `solve_ab(A, b, sensealg, alg)` etc. and then define the chain rule on that, which is the same as https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/arraymath.jl#L336-L359 .

Then it just needs a solve on the adjoint, i.e. #92
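Until such a rule lands in LinearSolve itself, a workaround in the spirit of the approach described above is to wrap the solve in a small function and give it a hand-written adjoint, so Zygote never descends into the factorization internals. A minimal sketch, with the caveat that the `linsolve` wrapper name and the dense `\` fallback are illustrative assumptions and not LinearSolve API:

```julia
import Zygote
using LinearAlgebra

# Hypothetical wrapper around a linear solve; not part of the LinearSolve API.
linsolve(A, b) = A \ b

# Hand-written reverse rule: for x = A⁻¹ b,
#   b̄ = A⁻ᵀ x̄   (one adjoint solve)
#   Ā = -b̄ xᵀ
Zygote.@adjoint function linsolve(A, b)
    x = A \ b
    function linsolve_pullback(x̄)
        b̄ = A' \ collect(x̄)   # collect: x̄ may arrive as a lazy fill
        Ā = -b̄ * x'
        return (Ā, b̄)
    end
    return x, linsolve_pullback
end

# Zygote now differentiates through the solve via the custom pullback.
A = [2.0 1.0; 0.0 4.0]
b = [1.0, 2.0]
gA, gb = Zygote.gradient((A, b) -> sum(linsolve(A, b)), A, b)
```

Since the pullback itself is just another linear solve with `A'`, the same forward-over-reverse `hessian_vector_product` trick should compose with it once the gradient works.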
