diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 26d530d19..304e302d2 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -9,6 +9,53 @@ using Enzyme using EnzymeCore +function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} + @assert !(prob isa Const) + res = func.val(prob.val, alg.val; kwargs...) + if RT <: Const + return res + end + dres = func.val(prob.dval, alg.val; kwargs...) + dres.b .= res.b == dres.b ? zero(dres.b) : dres.b + dres.A .= res.A == dres.A ? zero(dres.A) : dres.A + if RT <: DuplicatedNoNeed + return dres + elseif RT <: Duplicated + return Duplicated(res, dres) + end + error("Unsupported return type $RT") +end + +function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} + @assert !(linsolve isa Const) + + res = func.val(linsolve.val; kwargs...) + + if RT <: Const + return res + end + if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod + error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") + end + b = deepcopy(linsolve.val.b) + + db = linsolve.dval.b + dA = linsolve.dval.A + + linsolve.val.b = db - dA * res.u + dres = func.val(linsolve.val; kwargs...) + + linsolve.val.b = b + + if RT <: DuplicatedNoNeed + return dres + elseif RT <: Duplicated + return Duplicated(res, dres) + end + + return Duplicated(res, dres) +end + function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} res = func.val(prob.val, alg.val; kwargs...) dres = if EnzymeRules.width(config) == 1 diff --git a/test/enzyme.jl b/test/enzyme.jl index 02e071d41..89903a858 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,5 +1,7 @@ using Enzyme, ForwardDiff using LinearSolve, LinearAlgebra, Test +using FiniteDiff +using SafeTestsets n = 4 A = rand(n, n); @@ -161,4 +163,54 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 @test db2 ≈ db22 -=# \ No newline at end of file +=# + +A = rand(n, n); +dA = zeros(n, n); +b1 = rand(n); +for alg in ( + LUFactorization(), + RFLUFactorization(), + # KrylovJL_GMRES(), fails + ) + @show alg + function fb(b) + prob = LinearProblem(A, b) + + sol1 = solve(prob, alg) + + sum(sol1.u) + end + fb(b1) + + fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec + @show fd_jac + + en_jac = map(onehot(b1)) do db1 + eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1)) + eres[1] + end |> collect + @show en_jac + + @test en_jac ≈ fd_jac rtol=1e-4 + + function fA(A) + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + sum(sol1.u) + end + fA(A) + + fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec + @show fd_jac + + en_jac = map(onehot(A)) do dA + eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA)) + eres[1] + end |> collect + @show en_jac + + @test en_jac ≈ fd_jac rtol=1e-4 +end \ No newline at end of file