Skip to content

Commit

Permalink
Merge pull request #514 from oscardssmith/os/solver-precs
Browse files Browse the repository at this point in the history
make preconditioners part of the solver rather than a random extra
  • Loading branch information
ChrisRackauckas authored Aug 8, 2024
2 parents f68086b + a456104 commit 4d67571
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 31 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ JET = "0.8.28"
KLU = "0.6"
KernelAbstractions = "0.9.16"
Krylov = "0.9"
KrylovPreconditioners = "0.2"
KrylovKit = "0.8"
LazyArrays = "1.8, 2"
Libdl = "1.10"
Expand Down Expand Up @@ -135,6 +136,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
KrylovPreconditioners = "45d422c2-293f-44ce-8315-2cb988662dec"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Expand All @@ -148,4 +150,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
11 changes: 5 additions & 6 deletions docs/src/basics/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ a few ways:

## How do I use IterativeSolvers solvers with a weighted tolerance vector?

IterativeSolvers.jl computes the norm after the application of the left preconditioner
`Pl`. Thus, in order to use a vector tolerance `weights`, one can mathematically
IterativeSolvers.jl computes the norm after the application of the left preconditioner.
Thus, in order to use a vector tolerance `weights`, one can mathematically
hack the system via the following formulation:

```@example FAQPrec
Expand All @@ -57,11 +57,10 @@ A = rand(n, n)
b = rand(n)
weights = [1e-1, 1]
Pl = LinearSolve.InvPreconditioner(Diagonal(weights))
Pr = Diagonal(weights)
precs = Returns((LinearSolve.InvPreconditioner(Diagonal(weights)), Diagonal(weights)))
prob = LinearProblem(A, b)
sol = solve(prob, KrylovJL_GMRES(), Pl = Pl, Pr = Pr)
sol = solve(prob, KrylovJL_GMRES(precs))
sol.u
```
Expand All @@ -84,5 +83,5 @@ Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(we
Pr = Diagonal(weights)
prob = LinearProblem(A, b)
sol = solve(prob, KrylovJL_GMRES(), Pl = Pl, Pr = Pr)
sol = solve(prob, KrylovJL_GMRES(precs=Returns((Pl,Pr))))
```
6 changes: 3 additions & 3 deletions ext/LinearSolveIterativeSolversExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LinearSolveIterativeSolversExt

using LinearSolve, LinearAlgebra
using LinearSolve: LinearCache
using LinearSolve: LinearCache, DEFAULT_PRECS
import LinearSolve: IterativeSolversJL

if isdefined(Base, :get_extension)
Expand All @@ -12,9 +12,9 @@ end

function LinearSolve.IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.gmres_iterable!,
gmres_restart = 0, kwargs...)
gmres_restart = 0, precs = DEFAULT_PRECS, kwargs...)
return IterativeSolversJL(generate_iterator, gmres_restart,
args, kwargs)
precs, args, kwargs)
end

function LinearSolve.IterativeSolversJL_CG(args...; kwargs...)
Expand Down
5 changes: 3 additions & 2 deletions ext/LinearSolveKrylovKitExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
module LinearSolveKrylovKitExt

using LinearSolve, KrylovKit, LinearAlgebra
using LinearSolve: LinearCache
using LinearSolve: LinearCache, DEFAULT_PRECS

function LinearSolve.KrylovKitJL(args...;
KrylovAlg = KrylovKit.GMRES, gmres_restart = 0,
precs = DEFAULT_PRECS,
kwargs...)
return KrylovKitJL(KrylovAlg, gmres_restart, args, kwargs)
return KrylovKitJL(KrylovAlg, gmres_restart, precs, args, kwargs)
end

function LinearSolve.KrylovKitJL_CG(args...; kwargs...)
Expand Down
1 change: 1 addition & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import PrecompileTools
import Krylov
using SciMLBase
import Preferences

const CRC = ChainRulesCore

@static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686
Expand Down
76 changes: 73 additions & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,18 @@ end

function Base.setproperty!(cache::LinearCache, name::Symbol, x)
if name === :A
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
Pl, Pr = cache.alg.precs(x, cache.p)
setfield!(cache, :Pl, Pl)
setfield!(cache, :Pr, Pr)
end
setfield!(cache, :isfresh, true)
elseif name === :p
if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs)
Pl, Pr = cache.alg.precs(cache.A, x)
setfield!(cache, :Pl, Pl)
setfield!(cache, :Pr, Pr)
end
elseif name === :b
# In case there is something that needs to be done when b is updated
update_cacheval!(cache, :b, x)
Expand Down Expand Up @@ -121,6 +132,8 @@ default_alias_b(::Any, ::Any, ::Any) = false
default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true

DEFAULT_PRECS(A, p) = IdentityOperator(size(A)[1]), IdentityOperator(size(A)[2])

function __init_u0_from_Ab(A, b)
u0 = similar(b, size(A, 2))
fill!(u0, false)
Expand All @@ -136,12 +149,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
reltol = default_tol(real(eltype(prob.b))),
maxiters::Int = length(prob.b),
verbose::Bool = false,
Pl = IdentityOperator(size(prob.A)[1]),
Pr = IdentityOperator(size(prob.A)[2]),
Pl = nothing,
Pr = nothing,
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
kwargs...)
@unpack A, b, u0, p = prob
(;A, b, u0, p) = prob

A = if alias_A || A isa SMatrix
A
Expand All @@ -167,6 +180,24 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
reltol = real(eltype(prob.b))(reltol)
abstol = real(eltype(prob.b))(abstol)

precs = if hasproperty(alg, :precs)
isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs
else
DEFAULT_PRECS
end
_Pl, _Pr = precs(A, p)
if isnothing(Pl)
Pl = _Pl
else
# TODO: deprecate once all docs are updated to the new form
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
end
if isnothing(Pr)
Pr = _Pr
else
# TODO: deprecate once all docs are updated to the new form
#@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
end
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
isfresh = true
Expand All @@ -179,6 +210,45 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
return cache
end


function SciMLBase.reinit!(cache::LinearCache;
A = nothing,
b = cache.b,
u = cache.u,
p = nothing,
reinit_cache = false,)
(; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache

precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? alg.precs : DEFAULT_PRECS
Pl, Pr = if isnothing(A) || isnothing(p)
if isnothing(A)
A = cache.A
end
if isnothing(p)
p = cache.p
end
precs(A, p)
else
(cache.Pl, cache.Pr)
end
isfresh = true

if reinit_cache
return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval),
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
else
cache.A = A
cache.b = b
cache.u = u
cache.p = p
cache.Pl = Pl
cache.Pr = Pr
cache.isfresh = true
end
end

function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
return solve(prob, nothing, args...; kwargs...)
end
Expand Down
6 changes: 4 additions & 2 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,10 @@ solvers.
Using this solver requires adding the package KrylovKit.jl, i.e. `using KrylovKit`
"""
struct KrylovKitJL{F, A, I, K} <: LinearSolve.AbstractKrylovSubspaceMethod
struct KrylovKitJL{F, I, P, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod
KrylovAlg::F
gmres_restart::I
precs::P
args::A
kwargs::K
end
Expand Down Expand Up @@ -306,9 +307,10 @@ A generic wrapper over the IterativeSolvers.jl solvers.
Using this solver requires adding the package IterativeSolvers.jl, i.e. `using IterativeSolvers`
"""
struct IterativeSolversJL{F, I, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod
struct IterativeSolversJL{F, I, P, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod
generate_iterator::F
gmres_restart::I
precs::P
args::A
kwargs::K
end
Expand Down
23 changes: 10 additions & 13 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ KrylovJL(args...; KrylovAlg = Krylov.gmres!,
A generic wrapper over the Krylov.jl krylov-subspace iterative solvers.
"""
struct KrylovJL{F, I, A, K} <: AbstractKrylovSubspaceMethod
struct KrylovJL{F, I, P, A, K} <: AbstractKrylovSubspaceMethod
KrylovAlg::F
gmres_restart::I
window::I
precs::P
args::A
kwargs::K
end

function KrylovJL(args...; KrylovAlg = Krylov.gmres!,
gmres_restart = 0, window = 0,
precs = nothing,
kwargs...)
return KrylovJL(KrylovAlg, gmres_restart, window,
args, kwargs)
precs, args, kwargs)
end

default_alias_A(::KrylovJL, ::Any, ::Any) = true
Expand Down Expand Up @@ -231,8 +233,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
cache.isfresh = false
end

M = cache.Pl
N = cache.Pr
M, N = cache.Pl, cache.Pr

# use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
M = _isidentity_struct(M) ? I : M
Expand All @@ -258,25 +259,21 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
end

args = (cacheval, cache.A, cache.b)
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
kwargs = (atol = atol, rtol, itmax, verbose,
ldiv = true, history = true, alg.kwargs...)

if cache.cacheval isa Krylov.CgSolver
N !== I &&
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
Krylov.solve!(args...; M = M,
kwargs...)
Krylov.solve!(args...; M, kwargs...)
elseif cache.cacheval isa Krylov.GmresSolver
Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0,
kwargs...)
Krylov.solve!(args...; M, N, restart = alg.gmres_restart > 0, kwargs...)
elseif cache.cacheval isa Krylov.BicgstabSolver
Krylov.solve!(args...; M = M, N = N,
kwargs...)
Krylov.solve!(args...; M, N, kwargs...)
elseif cache.cacheval isa Krylov.MinresSolver
N !== I &&
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
Krylov.solve!(args...; M = M,
kwargs...)
Krylov.solve!(args...; M, kwargs...)
else
Krylov.solve!(args...; kwargs...)
end
Expand Down
4 changes: 3 additions & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
using SciMLOperators
using IterativeSolvers, KrylovKit, MKL_jll
using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
using Test
import Random

Expand Down Expand Up @@ -267,10 +267,12 @@ end

@testset "KrylovJL" begin
kwargs = (; gmres_restart = 5)
precs = (A,p=nothing) -> (BlockJacobiPreconditioner(A, 2), I)
algorithms = (
("Default", KrylovJL(kwargs...)),
("CG", KrylovJL_CG(kwargs...)),
("GMRES", KrylovJL_GMRES(kwargs...)),
("GMRES_prec", KrylovJL_GMRES(;precs, ldiv=false, kwargs...)),
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
("MINRES", KrylovJL_MINRES(kwargs...))
)
Expand Down

0 comments on commit 4d67571

Please sign in to comment.