diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index ccfa12ae..09507cf5 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -18,6 +18,7 @@ jobs: version: ['1'] group: - Core + - Enzyme steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index b85feb42..33741cd5 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -31,6 +31,7 @@ jobs: - "LinearSolveHYPRE" - "LinearSolvePardiso" - "LinearSolveBandedMatrices" + - "Enzyme" uses: "SciML/.github/.github/workflows/tests.yml@v1" with: group: "${{ matrix.group }}" diff --git a/Project.toml b/Project.toml index 58db07d3..f62a7915 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,6 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" @@ -52,8 +51,8 @@ LinearSolveBandedMatricesExt = "BandedMatrices" LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveCUDSSExt = "CUDSS" -LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"] -LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"] +LinearSolveEnzymeExt = "EnzymeCore" +LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" @@ -74,8 +73,8 @@ ChainRulesCore = "1.22" ConcreteStructs = "0.2.3" DocStringExtensions = "0.9.3" EnumX = "1.0.4" -Enzyme = "0.11.15, 0.12, 0.13" -EnzymeCore = "0.6.5, 0.7, 0.8" +Enzyme = "0.13" +EnzymeCore = "0.8.1" FastAlmostBandedMatrices = "0.1" FastLapackInterface = "2" FiniteDiff = "2.22" @@ -84,12 +83,12 @@ GPUArraysCore = "0.1.6" HYPRE = "1.4.0" InteractiveUtils = "1.10" IterativeSolvers = "0.9.3" -JET = "0.8.28" +JET = "0.8.28, 0.9" KLU = "0.6" -KernelAbstractions = "0.9.16" +KernelAbstractions = "0.9.27" Krylov = "0.9" KrylovKit = "0.8" -KrylovPreconditioners = "0.2" +KrylovPreconditioners = "0.3" LazyArrays = "1.8, 2" Libdl = "1.10" LinearAlgebra = "1.10" diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index c8d89e87..abd2232e 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -2,32 +2,44 @@ module LinearSolveEnzymeExt using LinearSolve using LinearSolve.LinearAlgebra -isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) - -using Enzyme - using EnzymeCore +using EnzymeCore: EnzymeRules -function EnzymeCore.EnzymeRules.forward( +function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @assert !(prob isa Const) res = func.val(prob.val, alg.val; kwargs...) if RT <: Const - return res + if EnzymeRules.needs_primal(config) + return res + else + return nothing + end end + dres = func.val(prob.dval, alg.val; kwargs...) - dres.b .= res.b == dres.b ? zero(dres.b) : dres.b - dres.A .= res.A == dres.A ? zero(dres.A) : dres.A - if RT <: DuplicatedNoNeed - return dres - elseif RT <: Duplicated + + if dres.b == res.b + dres.b .= false + end + if dres.A == res.A + dres.A .= false + end + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return Duplicated(res, dres) + elseif EnzymeRules.needs_shadow(config) + return dres + elseif EnzymeRules.needs_primal(config) + return res + else + return nothing end - error("Unsupported return type $RT") end -function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @assert !(linsolve isa Const) @@ -35,31 +47,35 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, res = func.val(linsolve.val; kwargs...) if RT <: Const - return res + if EnzymeRules.needs_primal(config) + return res + else + return nothing + end end if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling") end - b = deepcopy(linsolve.val.b) - db = linsolve.dval.b - dA = linsolve.dval.A + res = deepcopy(res) # Without this copy, the next solve will end up mutating the result - linsolve.val.b = db - dA * res.u + b = linsolve.val.b + linsolve.val.b = linsolve.dval.b - linsolve.dval.A * res.u dres = func.val(linsolve.val; kwargs...) - linsolve.val.b = b - if RT <: DuplicatedNoNeed - return dres - elseif RT <: Duplicated + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) return Duplicated(res, dres) + elseif EnzymeRules.needs_shadow(config) + return dres + elseif EnzymeRules.needs_primal(config) + return res + else + return nothing end - - return Duplicated(res, dres) end -function EnzymeCore.EnzymeRules.augmented_primal( +function EnzymeRules.augmented_primal( config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @@ -94,10 +110,10 @@ function EnzymeCore.EnzymeRules.augmented_primal( (dval.b for dval in prob.dval) end - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b)) + return EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b)) end -function EnzymeCore.EnzymeRules.reverse( +function EnzymeRules.reverse( config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @@ -131,7 +147,7 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeCore.EnzymeRules.augmented_primal( +function EnzymeRules.augmented_primal( config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @@ -184,10 +200,10 @@ function EnzymeCore.EnzymeRules.augmented_primal( cachesolve = deepcopy(linsolve.val) cache = (copy(res.u), resvals, cachesolve, dAs, dbs) - return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) + return EnzymeRules.AugmentedReturn(res, dres, cache) end -function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, +function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} y, dys, _linsolve, dAs, dbs = cache diff --git a/test/enzyme.jl b/test/enzyme.jl index ac552a45..b09c0de5 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,7 +1,6 @@ using Enzyme, ForwardDiff using LinearSolve, LinearAlgebra, Test using FiniteDiff -using SafeTestsets n = 4 A = rand(n, n); @@ -178,47 +177,39 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), A = rand(n, n); dA = zeros(n, n); b1 = rand(n); -for alg in ( + +function fnice(A, b, alg) + prob = LinearProblem(A, b) + sol1 = solve(prob, alg) + return sum(sol1.u) +end + +@testset for alg in ( LUFactorization(), RFLUFactorization() # KrylovJL_GMRES(), fails ) - @show alg - function fb(b) - prob = LinearProblem(A, b) - - sol1 = solve(prob, alg) + fb_closure = b -> fnice(A, b, alg) - sum(sol1.u) - end - fb(b1) - - fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec + fd_jac = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec @show fd_jac en_jac = map(onehot(b1)) do db1 - eres = Enzyme.autodiff(Forward, fb, Duplicated(copy(b1), db1)) - eres[1] + return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, + Const(A), Duplicated(b1, db1), Const(alg))) end |> collect @show en_jac @test en_jac≈fd_jac rtol=1e-4 - function fA(A) - prob = LinearProblem(A, b1) - - sol1 = solve(prob, alg) + fA_closure = A -> fnice(A, b1, alg) - sum(sol1.u) - end - fA(A) - - fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec + fd_jac = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec @show fd_jac en_jac = map(onehot(A)) do dA - eres = Enzyme.autodiff(Forward, fA, Duplicated(copy(A), dA)) - eres[1] - end |> collect + return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, + Duplicated(A, dA), Const(b1), Const(alg))) + end |> collect @show en_jac @test en_jac≈fd_jac rtol=1e-4 diff --git a/test/runtests.jl b/test/runtests.jl index 8d7b626f..44be5905 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,13 +15,16 @@ if GROUP == "All" || GROUP == "Core" @time @safetestset "Non-Square Tests" include("nonsquare.jl") @time @safetestset "SparseVector b Tests" include("sparse_vector.jl") @time @safetestset "Default Alg Tests" include("default_algs.jl") - @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") @time @safetestset "Adjoint Sensitivity" include("adjoint.jl") @time @safetestset "Traits" include("traits.jl") @time @safetestset "BandedMatrices" include("banded.jl") @time @safetestset "Static Arrays" include("static_arrays.jl") end +if GROUP == "All" || GROUP == "Enzyme" + @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") +end + if GROUP == "LinearSolveCUDA" Pkg.activate("gpu") Pkg.develop(PackageSpec(path = dirname(@__DIR__)))