Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add forward enzyme rules for init and solve #416

Merged
merged 11 commits into from
Nov 8, 2023
47 changes: 47 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,53 @@

using EnzymeCore

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@assert !(prob isa Const)
res = func.val(prob.val, alg.val; kwargs...)
if RT <: Const
return res

Check warning on line 16 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L12-L16

Added lines #L12 - L16 were not covered by tests
end
dres = func.val(prob.dval, alg.val; kwargs...)
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
return Duplicated(res, dres)

Check warning on line 24 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L18-L24

Added lines #L18 - L24 were not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't handle batching atm

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throw a good error for now? It would be good to get something merged even if it doesn't handle every case, as long as the errors are clear.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tried to address this in 585cbb8. We can make a followup PR for batch support.

end
error("Unsupported return type $RT")

Check warning on line 26 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L26

Added line #L26 was not covered by tests
end

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
@assert !(linsolve isa Const)

Check warning on line 30 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L29-L30

Added lines #L29 - L30 were not covered by tests

res = func.val(linsolve.val; kwargs...)

Check warning on line 32 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L32

Added line #L32 was not covered by tests

if RT <: Const
return res

Check warning on line 35 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
end
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")

Check warning on line 38 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L37-L38

Added lines #L37 - L38 were not covered by tests
end
b = deepcopy(linsolve.val.b)

Check warning on line 40 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L40

Added line #L40 was not covered by tests

db = linsolve.dval.b
dA = linsolve.dval.A

Check warning on line 43 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L42-L43

Added lines #L42 - L43 were not covered by tests

linsolve.val.b = db - dA * res.u
dres = func.val(linsolve.val; kwargs...)

Check warning on line 46 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L45-L46

Added lines #L45 - L46 were not covered by tests

linsolve.val.b = b

Check warning on line 48 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L48

Added line #L48 was not covered by tests

if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
return Duplicated(res, dres)

Check warning on line 53 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L50-L53

Added lines #L50 - L53 were not covered by tests
end

return Duplicated(res, dres)

Check warning on line 56 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L56

Added line #L56 was not covered by tests
end

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
Expand Down
54 changes: 53 additions & 1 deletion test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using SafeTestsets

n = 4
A = rand(n, n);
Expand Down Expand Up @@ -161,4 +163,54 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
@test dA ≈ dA2 atol=5e-5
@test db1 ≈ db12
@test db2 ≈ db22
=#
=#

A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
for alg in (
LUFactorization(),
RFLUFactorization(),
# KrylovJL_GMRES(), fails
)
@show alg
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fb(b1)

fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fd_jac

en_jac = map(onehot(b1)) do db1
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
eres[1]
end |> collect
@show en_jac

@test en_jac ≈ fd_jac rtol=1e-4

function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fA(A)

fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fd_jac

en_jac = map(onehot(A)) do dA
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
eres[1]
end |> collect
@show en_jac

@test en_jac ≈ fd_jac rtol=1e-4
end
Loading