Skip to content

Commit

Permalink
Merge pull request #449 from SciML/ap/adjoint
Browse files Browse the repository at this point in the history
Adjoints for Linear Solve
  • Loading branch information
ChrisRackauckas authored Feb 25, 2024
2 parents a206054 + e937e67 commit 7b090b4
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 21 deletions.
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.24.0"
version = "2.25.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expand All @@ -16,6 +17,7 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Expand Down Expand Up @@ -64,6 +66,7 @@ ArrayInterface = "7.7"
BandedMatrices = "1.5"
BlockDiagonals = "0.1.42"
CUDA = "5"
ChainRulesCore = "1.22"
ConcreteStructs = "0.2.3"
DocStringExtensions = "0.9.3"
EnumX = "1.0.4"
Expand All @@ -85,6 +88,7 @@ KrylovKit = "0.6"
Libdl = "1.10"
LinearAlgebra = "1.10"
MPI = "0.20"
Markdown = "1.10"
Metal = "0.5"
MultiFloats = "1"
Pardiso = "0.5"
Expand All @@ -96,7 +100,7 @@ RecursiveArrayTools = "3.8"
RecursiveFactorization = "0.2.14"
Reexport = "1"
SafeTestsets = "0.1"
SciMLBase = "2.23.0"
SciMLBase = "2.26.3"
SciMLOperators = "0.3.7"
Setfield = "1"
SparseArrays = "1.10"
Expand All @@ -106,6 +110,7 @@ StaticArrays = "1.5"
StaticArraysCore = "1.4.2"
Test = "1"
UnPack = "1"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Expand Down Expand Up @@ -133,6 +138,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
8 changes: 4 additions & 4 deletions ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using HYPRE.LibHYPRE: HYPRE_Complex
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning
__conditioning, LinearSolveAdjoint
using SciMLBase: LinearProblem, SciMLBase
using UnPack: @unpack
using Setfield: @set!
Expand Down Expand Up @@ -68,6 +68,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
Pl = LinearAlgebra.I,
Pr = LinearAlgebra.I,
assumptions = OperatorAssumptions(),
sensealg = LinearSolveAdjoint(),
kwargs...)
@unpack A, b, u0, p = prob

Expand All @@ -89,10 +90,9 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
typeof(__issquare(assumptions))
typeof(__issquare(assumptions)), typeof(sensealg)
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters,
verbose, assumptions)
maxiters, verbose, assumptions, sensealg)
return cache
end

Expand Down
7 changes: 7 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ PrecompileTools.@recompile_invalidations begin
using FastLapackInterface
using DocStringExtensions
using EnumX
using Markdown
using ChainRulesCore
import InteractiveUtils

import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix
Expand All @@ -42,6 +44,8 @@ PrecompileTools.@recompile_invalidations begin
import Preferences
end

const CRC = ChainRulesCore

if Preferences.@load_preference("LoadMKL_JLL", true)
using MKL_jll
const usemkl = MKL_jll.is_available()
Expand Down Expand Up @@ -125,6 +129,7 @@ include("solve_function.jl")
include("default.jl")
include("init.jl")
include("extension_algs.jl")
include("adjoint.jl")
include("deprecated.jl")

@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
Expand Down Expand Up @@ -240,4 +245,6 @@ export MetalLUFactorization

export OperatorAssumptions, OperatorCondition

export LinearSolveAdjoint

end
93 changes: 93 additions & 0 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# TODO: Preconditioners? Should Pl be transposed and made Pr and similar for Pr.

@doc doc"""
    LinearSolveAdjoint(; linsolve = missing)

Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as:

```math
\begin{align}
A^T \lambda &= \partial x \\
\partial A &= -\lambda x^T \\
\partial b &= \lambda
\end{align}
```

For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf).

## Choice of Linear Solver

Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
forward solve (this is done by keeping the linsolve as `missing`). For example, if the
forward solve was performed via a Factorization, then we can reuse the factorization for the
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
"""
@kwdef struct LinearSolveAdjoint{L} <:
              SciMLBase.AbstractSensitivityAlgorithm{0, false, :central}
    # Algorithm for the adjoint (transposed) solve; `missing` reuses the forward
    # algorithm / cached factorization.
    linsolve::L = missing
end

# Reverse-mode rule for `solve(prob::LinearProblem, alg, ...)`.
#
# Forward pass: builds the `LinearCache` explicitly (rather than calling `solve`
# directly) so the factorization / cache can be reused for the adjoint solve in
# the pullback. Reverse pass: solves `Aᵀ λ = ∂u`, then `∂A = -λ xᵀ`, `∂b = λ`
# (see the notes referenced in the `LinearSolveAdjoint` docstring).
function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A(
            alg, prob.A, prob.b), kwargs...)
    cache = init(prob, alg, args...; kwargs...)
    (; A, sensealg) = cache

    # Input validation, not an internal invariant: `@assert` can be compiled out,
    # so raise an explicit error for unsupported sensitivity algorithms.
    sensealg isa LinearSolveAdjoint ||
        throw(ArgumentError("Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."))

    # Decide if we need to cache `A` for the reverse pass.
    if sensealg.linsolve === missing
        # We can reuse the factorization so no copy is needed.
        # Krylov Methods don't modify `A`, so it's safe to just reuse it.
        # No copy is needed even for the default case.
        if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
             alg isa DefaultLinearSolver)
            A_ = alias_A ? deepcopy(A) : A
        end
    else
        # A custom adjoint solver was requested; `solve!` below may mutate `A`.
        A_ = deepcopy(A)
    end

    sol = solve!(cache)

    function ∇linear_solve(∂sol)
        ∂∅ = NoTangent()

        ∂u = ∂sol.u
        # Solve the adjoint system `Aᵀ λ = ∂u`, reusing the forward factorization
        # whenever possible.
        if sensealg.linsolve === missing
            λ = if cache.cacheval isa Factorization
                cache.cacheval' \ ∂u
            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
                first(cache.cacheval)' \ ∂u
            elseif alg isa AbstractKrylovSubspaceMethod
                invprob = LinearProblem(transpose(cache.A), ∂u)
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            elseif alg isa DefaultLinearSolver
                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
            else
                invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
            end
        else
            invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
            λ = solve(
                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
        end

        ∂A = -λ * transpose(sol.u)
        ∂b = λ
        # Tangent for the problem argument; the parameter slot gets no tangent.
        ∂prob = LinearProblem(∂A, ∂b, ∂∅)

        # (solve, prob, alg, args...) — no tangents for the function, alg, or extras.
        return (∂∅, ∂prob, ∂∅, ntuple(_ -> ∂∅, length(args))...)
    end

    return sol, ∇linear_solve
end

# Reverse-mode rule for the `LinearProblem(A, b, p)` constructor. Construction is
# a pure repackaging of `(A, b, p)`, so the pullback simply forwards the matching
# tangent fields back to the constructor arguments.
function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
    prob = LinearProblem(A, b, p)
    function ∇LinearProblem(Δprob)
        return (NoTangent(), Δprob.A, Δprob.b, Δprob.p)
    end
    return prob, ∇LinearProblem
end
19 changes: 13 additions & 6 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ end
__issquare(assump::OperatorAssumptions) = assump.issq
__conditioning(assump::OperatorAssumptions) = assump.condition

mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S}
A::TA
b::Tb
u::Tu
Expand All @@ -80,6 +80,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issq}
sensealg::S
end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
Expand Down Expand Up @@ -138,6 +139,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Pl = IdentityOperator(size(prob.A)[1]),
Pr = IdentityOperator(size(prob.A)[2]),
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)
@unpack A, b, u0, p = prob

Expand Down Expand Up @@ -171,17 +173,22 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Tc = typeof(cacheval)

cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
end

function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
solve!(init(prob, nothing, args...; kwargs...))
return solve(prob, nothing, args...; kwargs...)
end

function SciMLBase.solve(prob::LinearProblem,
alg::Union{SciMLLinearSolveAlgorithm, Nothing},
function SciMLBase.solve(prob::LinearProblem, ::Nothing, args...;
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...)
end

function SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...; kwargs...)
solve!(init(prob, alg, args...; kwargs...))
end
Expand Down
16 changes: 10 additions & 6 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,26 +779,30 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
cacheval.colptr &&
SparseArrays.decrement(SparseArrays.getrowval(A)) ==
cacheval.rowval)
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)), check=false)
fact = lu(
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
check = false)
else
fact = lu!(cacheval,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)), check=false)
nonzeros(A)), check = false)
end
else
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), check=false)
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
end
cache.cacheval = fact
cache.isfresh = false
end

F = @get_cacheval(cache, :UMFPACKFactorization)
F = @get_cacheval(cache, :UMFPACKFactorization)
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
y = ldiv!(cache.u, F, cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
else
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache; retcode=ReturnCode.Infeasible)
SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
end
end

Expand Down
Loading

0 comments on commit 7b090b4

Please sign in to comment.