diff --git a/ext/LinearSolveHYPREExt.jl b/ext/LinearSolveHYPREExt.jl index 0a3dcb1cc..279aba75d 100644 --- a/ext/LinearSolveHYPREExt.jl +++ b/ext/LinearSolveHYPREExt.jl @@ -90,7 +90,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm, cache = LinearCache{ typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), - typeof(__issquare(assumptions), typeof(sensealg)) + typeof(__issquare(assumptions)), typeof(sensealg) }(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) return cache diff --git a/src/adjoint.jl b/src/adjoint.jl index 3d46d8048..550bb2bd6 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -47,11 +47,7 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem, A_ = alias_A ? deepcopy(A) : A end else - if alg isa DefaultLinearSolver - A_ = deepcopy(A) - else - A_ = alias_A ? deepcopy(A) : A - end + A_ = deepcopy(A) end sol = solve!(cache) diff --git a/test/adjoint.jl b/test/adjoint.jl index ecc9714eb..26a72016f 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -51,15 +51,34 @@ end dA, db1, db2 = Zygote.gradient(f3, A, b1, b1) -#= Needs ForwardDiff rules -dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) -db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) -db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1)) +dA2 = FiniteDiff.finite_difference_gradient( + x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) +db12 = FiniteDiff.finite_difference_gradient( + x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) +db22 = FiniteDiff.finite_difference_gradient( + x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1)) + +@test dA≈dA2 atol=5e-5 +@test db1 ≈ db12 +@test db2 ≈ db22 + +function f4(A, b1, b2; alg = LUFactorization()) + prob = LinearProblem(A, b1) + sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR())) + prob = LinearProblem(A, b2) + sol2 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_GMRES())) + norm(sol1.u .+ sol2.u) +end + +dA, db1, db2 = Zygote.gradient(f4, A, b1, b1) + +dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A)) +db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1)) +db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1)) -@test dA ≈ dA2 atol=5e-5 +@test dA≈dA2 atol=5e-5 @test db1 ≈ db12 @test db2 ≈ db22 -=# A = rand(n, n); b1 = rand(n);