Skip to content

Commit

Permalink
Refactor code and add dispatch where both A and b are dual
Browse files Browse the repository at this point in the history
  • Loading branch information
sharanry committed Nov 13, 2023
1 parent a67d7aa commit 465e11c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 24 deletions.
36 changes: 26 additions & 10 deletions ext/LinearSolveForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,64 @@ function _solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)
LinearSolve.LinearAlgebra.LU(ForwardDiff.value.(cache.cacheval.factors), cache.cacheval.ipiv, cache.cacheval.info)

Check warning on line 17 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L8-L17

Added lines #L8 - L17 were not covered by tests
end : cache.cacheval
cache2 = remake(cache; A, b, u, reltol, abstol, cacheval)
res = LinearSolve.solve!(cache2, alg, kwargs...)
res = LinearSolve.solve!(cache2, alg, kwargs...) |> deepcopy
dresus = reduce(hcat, map(dAs, dbs) do dA, db
cache2.b = db - dA * res.u
dres = LinearSolve.solve!(cache2, alg, kwargs...)
deepcopy(dres.u)

Check warning on line 24 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L19-L24

Added lines #L19 - L24 were not covered by tests
end)
# display(dresus)
d = Dual{T}.(res.u, Tuple.(eachrow(dresus)))
LinearSolve.SciMLBase.build_linear_solution(alg, d, nothing, cache; retcode=res.retcode, iters=res.iters, stats=res.stats)

Check warning on line 27 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L26-L27

Added lines #L26 - L27 were not covered by tests
end


function LinearSolve.solve!(

Check warning on line 31 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L31

Added line #L31 was not covered by tests
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}}},
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:AbstractFloat}},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P}
@info "using solve! df/dA"
dAs = begin
dAs_ = ForwardDiff.partials.(cache.A)
dAs_ = collect.(dAs_)
dAs_ = [getindex.(dAs_, i) for i in 1:length(first(dAs_))]
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]

Check warning on line 39 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L36-L39

Added lines #L36 - L39 were not covered by tests
end
dbs = [zero(cache.b) for _=1:P]
A = ForwardDiff.value.(cache.A)
b = cache.b
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

Check warning on line 44 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L41-L44

Added lines #L41 - L44 were not covered by tests
end
function LinearSolve.solve!(

Check warning on line 46 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L46

Added line #L46 was not covered by tests
cache::LinearSolve.LinearCache{A_,<:AbstractArray{<:Dual{T,V,P}}},
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:AbstractFloat},<:AbstractArray{<:Dual{T,V,P}}},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P, A_}
@info "using solve! df/db"
dAs = [zero(cache.A) for _=1:P]
dbs = begin
dbs_ = ForwardDiff.partials.(cache.b)
dbs_ = collect.(dbs_)
dbs_ = [getindex.(dbs_, i) for i in 1:length(first(dbs_))]
t = collect.(ForwardDiff.partials.(cache.b))
[getindex.(t, i) for i in 1:P]

Check warning on line 55 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L51-L55

Added lines #L51 - L55 were not covered by tests
end
A = cache.A
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

Check warning on line 59 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L57-L59

Added lines #L57 - L59 were not covered by tests
end
function LinearSolve.solve!(

Check warning on line 61 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L61

Added line #L61 was not covered by tests
cache::LinearSolve.LinearCache{<:AbstractMatrix{<:Dual{T,V,P}},<:AbstractArray{<:Dual{T,V,P}}},
alg::LinearSolve.AbstractFactorization;
kwargs...
) where {T, V, P}
@info "using solve! df/dAb"
dAs = begin
t = collect.(ForwardDiff.partials.(cache.A))
[getindex.(t, i) for i in 1:P]

Check warning on line 69 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L66-L69

Added lines #L66 - L69 were not covered by tests
end
dbs = begin
t = collect.(ForwardDiff.partials.(cache.b))
[getindex.(t, i) for i in 1:P]

Check warning on line 73 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L71-L73

Added lines #L71 - L73 were not covered by tests
end
A = ForwardDiff.value.(cache.A)
b = ForwardDiff.value.(cache.b)
_solve!(cache, alg, dAs, dbs, A, b, T; kwargs...)

Check warning on line 77 in ext/LinearSolveForwardDiff.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveForwardDiff.jl#L75-L77

Added lines #L75 - L77 were not covered by tests
end

end # module
54 changes: 40 additions & 14 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
alg = LUFactorization()
# for alg in (
# LUFactorization(),
# # RFLUFactorization(),
# # KrylovJL_GMRES(),
# )
# alg = LUFactorization()
for alg in (
LUFactorization(),
RFLUFactorization(),
KrylovJL_GMRES(),
)
alg_str = string(alg)
@show alg_str
function fb(b)
Expand Down Expand Up @@ -51,19 +51,45 @@ alg = LUFactorization()
sum(sol1.u)
end
fA(A)
db = zero(b1)
manual_jac = map(onehot(A)) do dA
y = A \ b1
sum(inv(A) * (db - dA*y))
end |> collect
display(reduce(hcat, manual_jac))
# db = zero(b1)
# manual_jac = map(onehot(A)) do dA
# y = A \ b1
# t = inv(A) * (db - dA*y)
# end |> collect
# display(reduce(hcat, manual_jac))

fid_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
@show fid_jac

# @test_throws MethodError fod_jac = ForwardDiff.gradient(fA, A) |> vec
fod_jac = ForwardDiff.gradient(fA, A) |> vec
@show fod_jac

@test fod_jac fid_jac rtol=1e-6
# end


# function fAb(Ab)
# A = Ab[:, 1:n]
# b1 = Ab[:, n+1]
# prob = LinearProblem(A, b1)

# sol1 = solve(prob, alg)

# sum(sol1.u)
# end
# fAb(hcat(A, b1))
# # db = zero(b1)
# # manual_jac = map(onehot(A)) do dA
# # y = A \ b1
# # t = inv(A) * (db - dA*y)
# # end |> collect
# # display(reduce(hcat, manual_jac))

# fid_jac = FiniteDiff.finite_difference_jacobian(fAb, hcat(A, b1)) |> vec
# @show fid_jac

# fod_jac = ForwardDiff.gradient(fAb, hcat(A, b1)) |> vec
# @show fod_jac

# @test fod_jac ≈ fid_jac rtol=1e-6

end

0 comments on commit 465e11c

Please sign in to comment.