diff --git a/Project.toml b/Project.toml index 3cf05941..3f7c6d19 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,6 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" @@ -52,8 +51,8 @@ LinearSolveBandedMatricesExt = "BandedMatrices" LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveCUDSSExt = "CUDSS" -LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"] -LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"] +LinearSolveEnzymeExt = "EnzymeCore" +LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" @@ -84,7 +83,7 @@ GPUArraysCore = "0.1.6" HYPRE = "1.4.0" InteractiveUtils = "1.10" IterativeSolvers = "0.9.3" -JET = "0.8.28" +JET = "0.8.28, 0.9" KLU = "0.6" KernelAbstractions = "0.9.16" Krylov = "0.9" diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 84884c04..2e3b3adc 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -2,13 +2,9 @@ module LinearSolveEnzymeExt using LinearSolve using LinearSolve.LinearAlgebra -isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) - -using Enzyme - using EnzymeCore -function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, +function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} @assert !(prob isa Const) @@ -41,7 +37,8 @@ function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, end end -function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, +function EnzymeCore.EnzymeRules.forward( + config::EnzymeCore.EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} @assert !(linsolve isa Const)