Add ForwardDiff rules #434

Draft
wants to merge 9 commits into
base: main
3 changes: 3 additions & 0 deletions Project.toml
@@ -35,6 +35,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -48,6 +49,7 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
LinearSolveForwardDiff = "ForwardDiff"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
@@ -66,6 +68,7 @@ DocStringExtensions = "0.9"
EnumX = "1"
EnzymeCore = "0.6"
FastLapackInterface = "2"
ForwardDiff = "0.10"
GPUArraysCore = "0.1"
HYPRE = "1.4.0"
InteractiveUtils = "1.6"
88 changes: 88 additions & 0 deletions ext/LinearSolveForwardDiff.jl
@@ -0,0 +1,88 @@
module LinearSolveForwardDiff

using LinearSolve
using InteractiveUtils
isdefined(Base, :get_extension) ?
(import ForwardDiff; using ForwardDiff: Dual) :
(import ..ForwardDiff; using ..ForwardDiff: Dual)
Comment on lines +5 to +7
Member

Only 1.9+ is supported now

Contributor Author

I am not sure I understand. What do you mean?

Member

Basically, you don't need to do this anymore; just the first import line works.


function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
# eltype returns a type, so the check must use <: (isa Dual is always false here)
@assert !(eltype(first(dAs)) <: Dual)
@assert !(eltype(first(dbs)) <: Dual)
@assert !(eltype(A) <: Dual)
@assert !(eltype(b) <: Dual)
reltol = cache.reltol isa Dual ? ForwardDiff.value(cache.reltol) : cache.reltol
abstol = cache.abstol isa Dual ? ForwardDiff.value(cache.abstol) : cache.abstol
u = eltype(cache.u) <: Dual ? ForwardDiff.value.(cache.u) : cache.u
cacheval = cache.cacheval isa Tuple ? cache.cacheval[1] : cache.cacheval
cacheval = eltype(cacheval.factors) <: Dual ? begin
LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cacheval.factors), cacheval.ipiv, cacheval.info)

end : cacheval
cacheval = cache.cacheval isa Tuple ? (cacheval, cache.cacheval[2]) : cacheval


cache2 = remake(cache; A, b, u, reltol, abstol, cacheval)
Contributor Author

I am forced to remake the cache in order to solve the non-dual version. Is there some other way to replace the Dual array with a regular array?

Member

I think you want to hook into init. In theory, init can strip the duals from any user inputs that are dual and tag the cache so that solve! ends up doing two solves (or, in general, chunk size + 1 solves) and reconstructs the resulting dual numbers in the output.

Member

Or rather, it's just one solve! call but in a batched form.
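A minimal sketch of the idea in this thread, assuming dense inputs and a plain LU factorization from LinearAlgebra rather than the LinearSolve cache. Differentiating A x = b gives A dx = db - dA x, which is the same right-hand side the PR builds in its loop; here all P right-hand sides are stacked as columns and solved in one call. The helper name dual_solve_sketch, the test matrices, and the functions f_sketch and f_ref are hypothetical and only for illustration; this is not how LinearSolve.jl implements it.

```julia
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical helper (not part of LinearSolve.jl): solve A x = b when the
# entries of A and b are Duals with the same tag T and P partials each.
function dual_solve_sketch(A::AbstractMatrix{<:Dual{T, V, P}},
                           b::AbstractVector{<:Dual{T, V, P}}) where {T, V, P}
    Av = value.(A)   # strip the duals: primal matrix
    bv = value.(b)   # strip the duals: primal right-hand side
    F = lu(Av)       # factorize once
    x = F \ bv       # primal solution
    # One right-hand-side column per partial direction: db_i - dA_i * x
    R = reduce(hcat, [partials.(b, i) .- partials.(A, i) * x for i in 1:P])
    dX = F \ R       # a single batched solve for all tangents
    # Reassemble duals: row j of dX holds the P partials of x[j]
    return [Dual{T}(x[j], Tuple(dX[j, :])) for j in eachindex(x)]
end

# Illustrative data: p[1] shifts the diagonal of A0, p[2] shifts every entry of b0.
A0 = [4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 2.0]
b0 = [1.0, 2.0, 3.0]
E = Matrix{Float64}(I, 3, 3)

f_sketch(p) = sum(dual_solve_sketch(A0 + p[1] * E, b0 .+ p[2]))
f_ref(p) = sum((A0 + p[1] * E) \ (b0 .+ p[2]))  # ForwardDiff through the generic solve

p0 = [0.1, 0.2]
g_sketch = ForwardDiff.gradient(f_sketch, p0)
g_ref = ForwardDiff.gradient(f_ref, p0)
```

The two gradients should agree up to floating-point error. The sketch factorizes the primal matrix once and reuses it for every tangent column, which is the saving the batched form is after.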

res = LinearSolve.solve!(cache2, alg; kwargs...) |> deepcopy
dresus = reduce(hcat, map(dAs, dbs) do dA, db
cache2.b = db - dA * res.u
dres = LinearSolve.solve!(cache2, alg; kwargs...)
deepcopy(dres.u)

end)
Comment on lines +24 to +29
Contributor Author (@sharanry, Nov 13, 2023)

The results of the solves need to be deepcopied because subsequent solves overwrite them when the cache is reused.

Member

I think if you hook into init and do a single batched solve then this is handled.

Contributor Author

Is there any documentation on how to do batched solves? I am unable to find it anywhere. The closest thing I could find was https://discourse.julialang.org/t/batched-lu-solves-or-factorizations-with-sparse-matrices/106019/2; however, I couldn't find the right function call.

Member

It's just A \ B with a matrix B instead of A \ b with a vector b.

Contributor Author (@sharanry, Dec 29, 2023)

I am not entirely sure what you mean in the context of LinearSolve.jl.

using LinearSolve

n = 4
A = rand(n, n)
B = rand(n, n)

A \ B  # works

mapreduce(hcat, eachcol(B)) do b
    A \ b
end # works

mapreduce(hcat, eachcol(B)) do b
    prob = LinearProblem(A, b)
    sol = solve(prob)
    sol.u
end # works

begin
    prob = LinearProblem(A, B)
    sol = solve(prob)  # errors
    sol.u
end

Error:

ERROR: MethodError: no method matching ldiv!(::Vector{Float64}, ::LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, ::Matrix{Float64})

Closest candidates are:
  ldiv!(::Any, ::Sparspak.SpkSparseSolver.SparseSolver{IT, FT}, ::Any) where {IT, FT}
   @ Sparspak ~/.julia/packages/Sparspak/oqBYl/src/SparseCSCInterface/SparseCSCInterface.jl:263
  ldiv!(::Any, ::LinearSolve.InvPreconditioner, ::Any)
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/preconditioners.jl:30
  ldiv!(::Any, ::LinearSolve.ComposePreconditioner, ::Any)
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/preconditioners.jl:17
  ...

Stacktrace:
 [1] _ldiv!(x::Vector{Float64}, A::LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, b::Matrix{Float64})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/factorization.jl:11
 [2] macro expansion
   @ ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:135 [inlined]
 [3] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, alg::LUFactorization{LinearAlgebra.RowMaximum}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:127
 [4] solve!(cache::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, alg::LUFactorization{LinearAlgebra.RowMaximum})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/LinearSolve.jl:127
 [5] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:218
 [6] solve!(::LinearSolve.LinearCache{Matrix{Float64}, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{LinearAlgebra.RowMaximum}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:217
 [7] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Matrix{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::LUFactorization{LinearAlgebra.RowMaximum}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:214
 [8] solve(::LinearProblem{Nothing, true, Matrix{Float64}, Matrix{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::LUFactorization{LinearAlgebra.RowMaximum})
   @ LinearSolve ~/code/enzyme_playground/LS_FD/src/common.jl:211
 [9] top-level scope
   @ REPL[24]:3

Member

@avik-pal I thought you handled something with this?

Contributor Author

@avik-pal A ping on this. Is there another way to do this if we do not yet have batch dispatch?

Member

Not for this case, only for the case where A and b are both batched. Here you will have to see how Base handles it; there are special LAPACK routines for these.
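For reference, a small sketch of what a matrix right-hand side looks like with only the LinearAlgebra standard library, outside of LinearSolve. Judging from the signature in the MethodError above, the failing call writes a Matrix right-hand side into a Vector output buffer; the sketch below assumes the output buffer has the same shape as B. The matrices are arbitrary illustrative values.

```julia
using LinearAlgebra

A = [4.0 1.0 0.0 0.0;
     1.0 4.0 1.0 0.0;
     0.0 1.0 4.0 1.0;
     0.0 0.0 1.0 4.0]
B = [1.0 0.0;
     0.0 1.0;
     2.0 0.0;
     0.0 2.0]          # two right-hand sides stored as columns

F = lu(A)              # factorize once
X = F \ B              # matrix right-hand side: every column is solved in one call

# In-place variant: the output buffer has the same shape as B.
X2 = similar(B)
ldiv!(X2, F, B)

# Solving column by column with the same factorization gives the same result.
Xcols = reduce(hcat, [F \ B[:, j] for j in axes(B, 2)])
```

Whether LinearSolve's cache can hold a matrix-shaped u for this purpose is a separate question that this sketch does not answer.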

d = Dual{T}.(res.u, Tuple.(eachrow(dresus)))
LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats)

end


for ALG in subtypes(LinearSolve, LinearSolve.AbstractFactorization)
@eval begin
function LinearSolve.solve!(

cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}, B},
alg::$ALG;
kwargs...
) where {T, V, P, B}
# @info "using solve! df/dA"
dAs = begin
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]

end
dbs = [zero(cache.b) for _=1:P]
A = ForwardDiff.value.(cache.A)
b = cache.b
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

end
function LinearSolve.solve!(

cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}},
alg::$ALG;
kwargs...
) where {T, V, P, A_}
# @info "using solve! df/db"
dAs = [zero(cache.A) for _=1:P]
dbs = begin
t = collect.(ForwardDiff.partials.(cache.b))
[getindex.(t, i) for i in 1:P]

end
A = cache.A
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

end
function LinearSolve.solve!(

cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}},
alg::$ALG;
kwargs...
) where {T, V, P}
# @info "using solve! df/dAb"
dAs = begin
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]

end
dbs = begin
t = collect.(ForwardDiff.partials.(cache.b))
[getindex.(t, i) for i in 1:P]

end
A = ForwardDiff.value.(cache.A)
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

end
end
end

end # module
9 changes: 9 additions & 0 deletions src/common.jl
@@ -82,6 +82,15 @@
assumptions::OperatorAssumptions{issq}
end

function SciMLBase.remake(cache::LinearCache;

A::TA=cache.A, b::TB=cache.b, u::TU=cache.u, p::TP=cache.p, alg::Talg=cache.alg,
cacheval::Tc=cache.cacheval, isfresh::Bool=cache.isfresh, Pl::Tl=cache.Pl, Pr::Tr=cache.Pr,
abstol::Ttol=cache.abstol, reltol::Ttol=cache.reltol, maxiters::Int=cache.maxiters,
verbose::Bool=cache.verbose, assumptions::OperatorAssumptions{issq}=cache.assumptions) where {TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}
LinearCache{TA, TB, TU, TP, Talg, Tc, Tl, Tr, Ttol, issq}(A,b,u,p,alg,cacheval,isfresh,Pl,Pr,abstol,reltol,

maxiters,verbose,assumptions)
end

Comment on lines +85 to +93
Contributor Author

Need to check if there is a way to avoid redefining this by providing a better constructor for LinearCache.

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
if name === :A
setfield!(cache, :isfresh, true)
74 changes: 74 additions & 0 deletions test/forwarddiff.jl
@@ -0,0 +1,74 @@
using Test
using ForwardDiff
using LinearSolve
using FiniteDiff
using Enzyme
using Random
Random.seed!(1234)

n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
for alg in (
LUFactorization(),
RFLUFactorization(),
# KrylovJL_GMRES(), dispatch fails
)
alg_str = string(alg)
@show alg_str
function fb(b)
prob = LinearProblem(A, b)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fb(b1)

fid_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
@show fid_jac

fod_jac = ForwardDiff.gradient(fb, b1) |> vec
@show fod_jac

@test fod_jac ≈ fid_jac rtol=1e-6

function fA(A)
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fA(A)

fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fid_jac

fod_jac = ForwardDiff.gradient(fA, A) |> vec
@show fod_jac

@test fod_jac ≈ fid_jac rtol=1e-6


function fAb(Ab)
A = Ab[:, 1:n]
b1 = Ab[:, n+1]
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

sum(sol1.u)
end
fAb(hcat(A, b1))

fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec
@show fid_jac

fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec
@show fod_jac

@test fod_jac ≈ fid_jac rtol=1e-6

end