From 8fc4ae3b598aefd733efa551bcba9821a9466129 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 24 Sep 2023 21:36:43 -0400 Subject: [PATCH] WIP: fix KrylovJL_GMRES with Enzyme --- ext/LinearSolveEnzymeExt.jl | 4 +++- test/enzyme.jl | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index fea332dcd..f1bc4dbd1 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -4,11 +4,13 @@ using LinearSolve using LinearSolve.LinearAlgebra isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme) - using Enzyme using EnzymeCore +@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.KrylovJL}) = true +@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.Krylov.GmresSolver}) = true + function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem} res = func.val(prob.val, alg.val; kwargs...) dres = if EnzymeRules.width(config) == 1 diff --git a/test/enzyme.jl b/test/enzyme.jl index 62904c055..7196f38f2 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -107,7 +107,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), @test db1 ≈ db12 @test db2 ≈ db22 -#= + function f3(A, b1, b2; alg = KrylovJL_GMRES()) prob = LinearProblem(A, b1) cache = init(prob, alg) @@ -117,9 +117,11 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES()) norm(s1 + s2) end +dA = zeros(n, n); +db1 = zeros(n); +db2 = zeros(n); Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) @test dA ≈ dA2 atol=5e-5 @test db1 ≈ db12 -@test db2 ≈ db22 -=# \ No newline at end of file +@test db2 ≈ db22 \ No newline at end of file