Skip to content

Commit

Permalink
More tests and some safety
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 24, 2024
1 parent 7c1f1b2 commit 2493dca
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
typeof(__issquare(assumptions), typeof(sensealg))
typeof(__issquare(assumptions)), typeof(sensealg)
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
Expand Down
6 changes: 1 addition & 5 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
A_ = alias_A ? deepcopy(A) : A
end
else
if alg isa DefaultLinearSolver
A_ = deepcopy(A)
else
A_ = alias_A ? deepcopy(A) : A
end
A_ = deepcopy(A)
end

sol = solve!(cache)
Expand Down
31 changes: 25 additions & 6 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,34 @@ end

dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)

#= Needs ForwardDiff rules
dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
dA2 = FiniteDiff.finite_difference_gradient(
x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = FiniteDiff.finite_difference_gradient(
x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = FiniteDiff.finite_difference_gradient(
x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))

@test dAdA2 atol=5e-5
@test db1 db12
@test db2 db22

function f4(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
prob = LinearProblem(A, b2)
sol2 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_GMRES()))
norm(sol1.u .+ sol2.u)
end

dA, db1, db2 = Zygote.gradient(f4, A, b1, b1)

dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))

@test dAdA2 atol=5e-5
@test dAdA2 atol=5e-5
@test db1 db12
@test db2 db22
=#

A = rand(n, n);
b1 = rand(n);
Expand Down

0 comments on commit 2493dca

Please sign in to comment.