From ac5693c674460b584f9a23dd73f98459a8491d98 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sun, 15 Sep 2024 20:21:27 -0500 Subject: [PATCH 1/6] Adapt to pending Enzyme breaking change --- Project.toml | 2 +- ext/LinearSolveEnzymeExt.jl | 40 +++++++++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 1d9172b0e..6917c9f95 100644 --- a/Project.toml +++ b/Project.toml @@ -74,7 +74,7 @@ ChainRulesCore = "1.22" ConcreteStructs = "0.2.3" DocStringExtensions = "0.9.3" EnumX = "1.0.4" -Enzyme = "0.11.15, 0.12" +Enzyme = "0.13" EnzymeCore = "0.6.5, 0.7" FastAlmostBandedMatrices = "0.1" FastLapackInterface = "2" diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index c8d89e874..84884c040 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -8,13 +8,17 @@ using Enzyme using EnzymeCore -function EnzymeCore.EnzymeRules.forward( +function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @assert !(prob isa Const) res = func.val(prob.val, alg.val; kwargs...) if RT <: Const - return res + if EnzymeRules.needs_primal(config) + return res + else + return nothing + end end dres = func.val(prob.dval, alg.val; kwargs...) dres.b .= res.b == dres.b ? zero(dres.b) : dres.b @@ -25,9 +29,19 @@ function EnzymeCore.EnzymeRules.forward( return Duplicated(res, dres) end error("Unsupported return type $RT") + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + Duplicated(res, dres) + elseif EnzymeRules.needs_shadow(config) + dres + elseif EnzymeRules.needs_primal(config) + res + else + nothing + end end -function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, +function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @assert !(linsolve isa Const) @@ -35,7 +49,11 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, res = func.val(linsolve.val; kwargs...) if RT <: Const - return res + if EnzymeRules.needs_primal(config) + return res + else + return nothing + end end if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") @@ -50,13 +68,15 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, linsolve.val.b = b - if RT <: DuplicatedNoNeed - return dres - elseif RT <: Duplicated - return Duplicated(res, dres) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + Duplicated(res, dres) + elseif EnzymeRules.needs_shadow(config) + dres + elseif EnzymeRules.needs_primal(config) + res + else + nothing end - - return Duplicated(res, dres) end function EnzymeCore.EnzymeRules.augmented_primal( From 574b0d8e356bdadc7940700c3f4f57c2bf96520b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Sep 2024 05:40:36 -0400 Subject: [PATCH 2/6] fix: no ConfigWidth --- Project.toml | 7 +++---- ext/LinearSolveEnzymeExt.jl | 9 +++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 3cf059417..3f7c6d198 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,6 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" @@ -52,8 +51,8 @@ LinearSolveBandedMatricesExt = "BandedMatrices" LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveCUDSSExt = "CUDSS" -LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"] -LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"] +LinearSolveEnzymeExt = "EnzymeCore" +LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" @@ -84,7 +83,7 @@ GPUArraysCore = "0.1.6" HYPRE = "1.4.0" InteractiveUtils = "1.10" IterativeSolvers = "0.9.3" -JET = "0.8.28" +JET = "0.8.28, 0.9" KLU = "0.6" KernelAbstractions = "0.9.16" Krylov = "0.9" diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 84884c040..2e3b3adc5 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -2,13 +2,9 @@ module LinearSolveEnzymeExt using LinearSolve using LinearSolve.LinearAlgebra -isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) - -using Enzyme - using EnzymeCore -function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, +function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @assert !(prob isa Const) @@ -41,7 +37,8 @@ function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, end end -function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, +function EnzymeCore.EnzymeRules.forward( + config::EnzymeCore.EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @assert !(linsolve isa Const) From ef14bab2f8b2962f9c009878475034bd8f94644d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Sep 2024 05:59:23 -0400 Subject: [PATCH 3/6] test: separate out the enzyme testing --- .github/workflows/Downgrade.yml | 1 + .github/workflows/Tests.yml | 1 + ext/LinearSolveEnzymeExt.jl | 41 +++++++++++++++------------------ test/enzyme.jl | 1 - test/runtests.jl | 5 +++- 5 files changed, 24 insertions(+), 25 deletions(-) diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index ccfa12ae9..09507cf54 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -18,6 +18,7 @@ jobs: version: ['1'] group: - Core + - Enzyme steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index b85feb422..33741cd5c 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -31,6 +31,7 @@ jobs: - "LinearSolveHYPRE" - "LinearSolvePardiso" - "LinearSolveBandedMatrices" + - "Enzyme" uses: "SciML/.github/.github/workflows/tests.yml@v1" with: group: "${{ matrix.group }}" diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 2e3b3adc5..ddc37f630 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -3,8 +3,9 @@ module LinearSolveEnzymeExt using LinearSolve using LinearSolve.LinearAlgebra using EnzymeCore +using EnzymeCore: EnzymeRules -function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfigWidth{1}, +function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @assert !(prob isa Const) @@ -19,26 +20,20 @@ function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfig dres = func.val(prob.dval, alg.val; kwargs...) dres.b .= res.b == dres.b ? zero(dres.b) : dres.b dres.A .= res.A == dres.A ? zero(dres.A) : dres.A - if RT <: DuplicatedNoNeed - return dres - elseif RT <: Duplicated - return Duplicated(res, dres) - end - error("Unsupported return type $RT") if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) - Duplicated(res, dres) + return Duplicated(res, dres) elseif EnzymeRules.needs_shadow(config) - dres + return dres elseif EnzymeRules.needs_primal(config) - res + return res else - nothing + return nothing end end -function EnzymeCore.EnzymeRules.forward( - config::EnzymeCore.EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @assert !(linsolve isa Const) @@ -66,17 +61,17 @@ function EnzymeCore.EnzymeRules.forward( linsolve.val.b = b if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) - Duplicated(res, dres) + return Duplicated(res, dres) elseif EnzymeRules.needs_shadow(config) - dres + return dres elseif EnzymeRules.needs_primal(config) - res + return res else - nothing + return nothing end end -function EnzymeCore.EnzymeRules.augmented_primal( +function EnzymeRules.augmented_primal( config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @@ -111,10 +106,10 @@ function EnzymeCore.EnzymeRules.augmented_primal( (dval.b for dval in prob.dval) end - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b)) + return EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b)) end -function EnzymeCore.EnzymeRules.reverse( +function EnzymeRules.reverse( config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @@ -148,7 +143,7 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeCore.EnzymeRules.augmented_primal( +function EnzymeRules.augmented_primal( config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @@ -201,10 +196,10 @@ function EnzymeCore.EnzymeRules.augmented_primal( cachesolve = deepcopy(linsolve.val) cache = (copy(res.u), resvals, cachesolve, dAs, dbs) - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) + return EnzymeRules.AugmentedReturn(res, dres, cache) end -function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, +function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} y, dys, _linsolve, dAs, dbs = cache diff --git a/test/enzyme.jl b/test/enzyme.jl index ac552a45a..323f3e607 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,7 +1,6 @@ using Enzyme, ForwardDiff using LinearSolve, LinearAlgebra, Test using FiniteDiff -using SafeTestsets n = 4 A = rand(n, n); diff --git a/test/runtests.jl b/test/runtests.jl index 8d7b626fa..44be59058 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,13 +15,16 @@ if GROUP == "All" || GROUP == "Core" @time @safetestset "Non-Square Tests" include("nonsquare.jl") @time @safetestset "SparseVector b Tests" include("sparse_vector.jl") @time @safetestset "Default Alg Tests" include("default_algs.jl") - @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") @time @safetestset "Adjoint Sensitivity" include("adjoint.jl") @time @safetestset "Traits" include("traits.jl") @time @safetestset "BandedMatrices" include("banded.jl") @time @safetestset "Static Arrays" include("static_arrays.jl") end +if GROUP == "All" || GROUP == "Enzyme" + @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") +end + if GROUP == "LinearSolveCUDA" Pkg.activate("gpu") Pkg.develop(PackageSpec(path = dirname(@__DIR__))) From f5aefbe9811b77766c79866cf127023b9d43e886 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Sep 2024 12:10:08 -0400 Subject: [PATCH 4/6] test: enable runtime activity for now --- test/enzyme.jl | 42 +++++++++++++++++------------------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index 323f3e607..9192b63a6 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -177,47 +177,39 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), A = rand(n, n); dA = zeros(n, n); b1 = rand(n); -for alg in ( + +function fnice(A, b, alg) + prob = LinearProblem(A, b) + sol1 = solve(prob, alg) + return sum(sol1.u) +end + +@testset for alg in ( LUFactorization(), RFLUFactorization() # KrylovJL_GMRES(), fails ) - @show alg - function fb(b) - prob = LinearProblem(A, b) - - sol1 = solve(prob, alg) + fb_closure = b -> fnice(A, b, alg) - sum(sol1.u) - end - fb(b1) - - fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec + fd_jac = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec @show fd_jac en_jac = map(onehot(b1)) do db1 - eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1)) - eres[1] + return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, + Const(A), Duplicated(b1, db1), Const(alg))) end |> collect @show en_jac @test en_jac≈fd_jac rtol=1e-4 - function fA(A) - prob = LinearProblem(A, b1) - - sol1 = solve(prob, alg) + fA_closure = A -> fnice(A, b1, alg) - sum(sol1.u) - end - fA(A) - - fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec + fd_jac = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec @show fd_jac en_jac = map(onehot(A)) do dA - eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA)) - eres[1] - end |> collect + return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, + Duplicated(A, dA), Const(b1), Const(alg))) + end |> collect |> (x -> reshape(x, n, n)) @show en_jac @test en_jac≈fd_jac rtol=1e-4 From 0d8a47a4046574944f423d36460b615b5695b48c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Sep 2024 12:28:39 -0400 Subject: [PATCH 5/6] fix: forward rules aliasing issue --- ext/LinearSolveEnzymeExt.jl | 18 +++++++++++------- test/enzyme.jl | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index ddc37f630..abd2232e1 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -17,9 +17,15 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1}, return nothing end end + dres = func.val(prob.dval, alg.val; kwargs...) - dres.b .= res.b == dres.b ? zero(dres.b) : dres.b - dres.A .= res.A == dres.A ? zero(dres.A) : dres.A + + if dres.b == res.b + dres.b .= false + end + if dres.A == res.A + dres.A .= false + end if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return Duplicated(res, dres) @@ -50,14 +56,12 @@ function EnzymeRules.forward( if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") end - b = deepcopy(linsolve.val.b) - db = linsolve.dval.b - dA = linsolve.dval.A + res = deepcopy(res) # Without this copy, the next solve will end up mutating the result - linsolve.val.b = db - dA * res.u + b = linsolve.val.b + linsolve.val.b = linsolve.dval.b - linsolve.dval.A * res.u dres = func.val(linsolve.val; kwargs...) - linsolve.val.b = b if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) diff --git a/test/enzyme.jl b/test/enzyme.jl index 9192b63a6..b09c0de54 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -209,7 +209,7 @@ end en_jac = map(onehot(A)) do dA return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, Duplicated(A, dA), Const(b1), Const(alg))) - end |> collect |> (x -> reshape(x, n, n)) + end |> collect @show en_jac @test en_jac≈fd_jac rtol=1e-4 From b2e24b4c13c567522b694a81a480406fb4186f08 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Sep 2024 12:34:44 -0400 Subject: [PATCH 6/6] fix: bump minimum versions --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 3f7c6d198..f62a7915d 100644 --- a/Project.toml +++ b/Project.toml @@ -74,7 +74,7 @@ ConcreteStructs = "0.2.3" DocStringExtensions = "0.9.3" EnumX = "1.0.4" Enzyme = "0.13" -EnzymeCore = "0.8" +EnzymeCore = "0.8.1" FastAlmostBandedMatrices = "0.1" FastLapackInterface = "2" FiniteDiff = "2.22" @@ -85,10 +85,10 @@ InteractiveUtils = "1.10" IterativeSolvers = "0.9.3" JET = "0.8.28, 0.9" KLU = "0.6" -KernelAbstractions = "0.9.16" +KernelAbstractions = "0.9.27" Krylov = "0.9" KrylovKit = "0.8" -KrylovPreconditioners = "0.2" +KrylovPreconditioners = "0.3" LazyArrays = "1.8, 2" Libdl = "1.10" LinearAlgebra = "1.10"