Skip to content

Commit

Permalink
Merge pull request #543 from SciML/ap/enz
Browse files Browse the repository at this point in the history
Adapt to pending Enzyme breaking change
  • Loading branch information
ChrisRackauckas authored Sep 27, 2024
2 parents 8d54518 + b2e24b4 commit 174b51a
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 65 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
version: ['1']
group:
- Core
- Enzyme
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- "LinearSolveHYPRE"
- "LinearSolvePardiso"
- "LinearSolveBandedMatrices"
- "Enzyme"
uses: "SciML/.github/.github/workflows/tests.yml@v1"
with:
group: "${{ matrix.group }}"
Expand Down
15 changes: 7 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
Expand All @@ -52,8 +51,8 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveCUDSSExt = "CUDSS"
LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"]
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
LinearSolveEnzymeExt = "EnzymeCore"
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand All @@ -74,8 +73,8 @@ ChainRulesCore = "1.22"
ConcreteStructs = "0.2.3"
DocStringExtensions = "0.9.3"
EnumX = "1.0.4"
Enzyme = "0.11.15, 0.12, 0.13"
EnzymeCore = "0.6.5, 0.7, 0.8"
Enzyme = "0.13"
EnzymeCore = "0.8.1"
FastAlmostBandedMatrices = "0.1"
FastLapackInterface = "2"
FiniteDiff = "2.22"
Expand All @@ -84,12 +83,12 @@ GPUArraysCore = "0.1.6"
HYPRE = "1.4.0"
InteractiveUtils = "1.10"
IterativeSolvers = "0.9.3"
JET = "0.8.28"
JET = "0.8.28, 0.9"
KLU = "0.6"
KernelAbstractions = "0.9.16"
KernelAbstractions = "0.9.27"
Krylov = "0.9"
KrylovKit = "0.8"
KrylovPreconditioners = "0.2"
KrylovPreconditioners = "0.3"
LazyArrays = "1.8, 2"
Libdl = "1.10"
LinearAlgebra = "1.10"
Expand Down
76 changes: 46 additions & 30 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,80 @@ module LinearSolveEnzymeExt

using LinearSolve
using LinearSolve.LinearAlgebra
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)

using Enzyme

using EnzymeCore
using EnzymeCore: EnzymeRules

function EnzymeCore.EnzymeRules.forward(
function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@assert !(prob isa Const)
res = func.val(prob.val, alg.val; kwargs...)
if RT <: Const
return res
if EnzymeRules.needs_primal(config)
return res
else
return nothing
end
end

dres = func.val(prob.dval, alg.val; kwargs...)
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated

if dres.b == res.b
dres.b .= false
end
if dres.A == res.A
dres.A .= false
end

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(res, dres)
elseif EnzymeRules.needs_shadow(config)
return dres
elseif EnzymeRules.needs_primal(config)
return res
else
return nothing
end
error("Unsupported return type $RT")
end

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
@assert !(linsolve isa Const)

res = func.val(linsolve.val; kwargs...)

if RT <: Const
return res
if EnzymeRules.needs_primal(config)
return res
else
return nothing
end
end
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end
b = deepcopy(linsolve.val.b)

db = linsolve.dval.b
dA = linsolve.dval.A
res = deepcopy(res) # Without this copy, the next solve will end up mutating the result

linsolve.val.b = db - dA * res.u
b = linsolve.val.b
linsolve.val.b = linsolve.dval.b - linsolve.dval.A * res.u
dres = func.val(linsolve.val; kwargs...)

linsolve.val.b = b

if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(res, dres)
elseif EnzymeRules.needs_shadow(config)
return dres
elseif EnzymeRules.needs_primal(config)
return res
else
return nothing
end

return Duplicated(res, dres)
end

function EnzymeCore.EnzymeRules.augmented_primal(
function EnzymeRules.augmented_primal(
config, func::Const{typeof(LinearSolve.init)},
::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const;
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand Down Expand Up @@ -94,10 +110,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
(dval.b for dval in prob.dval)
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
return EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
end

function EnzymeCore.EnzymeRules.reverse(
function EnzymeRules.reverse(
config, func::Const{typeof(LinearSolve.init)}, ::Type{RT},
cache, prob::EnzymeCore.Annotation{LP}, alg::Const;
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand Down Expand Up @@ -131,7 +147,7 @@ end
# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(
function EnzymeRules.augmented_primal(
config, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
Expand Down Expand Up @@ -184,10 +200,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
cachesolve = deepcopy(linsolve.val)

cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
return EnzymeRules.AugmentedReturn(res, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys, _linsolve, dAs, dbs = cache
Expand Down
43 changes: 17 additions & 26 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using SafeTestsets

n = 4
A = rand(n, n);
Expand Down Expand Up @@ -178,47 +177,39 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
for alg in (

function fnice(A, b, alg)
prob = LinearProblem(A, b)
sol1 = solve(prob, alg)
return sum(sol1.u)
end

@testset for alg in (
LUFactorization(),
RFLUFactorization() # KrylovJL_GMRES(), fails
)
@show alg
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)
fb_closure = b -> fnice(A, b, alg)

sum(sol1.u)
end
fb(b1)

fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
fd_jac = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec
@show fd_jac

en_jac = map(onehot(b1)) do db1
eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1))
eres[1]
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
Const(A), Duplicated(b1, db1), Const(alg)))
end |> collect
@show en_jac

@test en_jacfd_jac rtol=1e-4

function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)
fA_closure = A -> fnice(A, b1, alg)

sum(sol1.u)
end
fA(A)

fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
fd_jac = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
@show fd_jac

en_jac = map(onehot(A)) do dA
eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA))
eres[1]
end |> collect
return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice,
Duplicated(A, dA), Const(b1), Const(alg)))
end |> collect
@show en_jac

@test en_jacfd_jac rtol=1e-4
Expand Down
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ if GROUP == "All" || GROUP == "Core"
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
@time @safetestset "Traits" include("traits.jl")
@time @safetestset "BandedMatrices" include("banded.jl")
@time @safetestset "Static Arrays" include("static_arrays.jl")
end

if GROUP == "All" || GROUP == "Enzyme"
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
end

if GROUP == "LinearSolveCUDA"
Pkg.activate("gpu")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Expand Down

0 comments on commit 174b51a

Please sign in to comment.