diff --git a/Project.toml b/Project.toml index f083c4a79..d9edebbd8 100644 --- a/Project.toml +++ b/Project.toml @@ -90,6 +90,7 @@ JET = "0.8.28" KLU = "0.6" KernelAbstractions = "0.9.16" Krylov = "0.9" +KrylovPreconditioners = "0.2" KrylovKit = "0.8" LazyArrays = "1.8, 2" Libdl = "1.10" @@ -135,6 +136,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +KrylovPreconditioners = "45d422c2-293f-44ce-8315-2cb988662dec" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5" @@ -148,4 +150,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"] +test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"] diff --git a/docs/src/basics/FAQ.md b/docs/src/basics/FAQ.md index 108f748b6..8e07a9957 100644 --- a/docs/src/basics/FAQ.md +++ b/docs/src/basics/FAQ.md @@ -45,8 +45,8 @@ a few ways: ## How do I use IterativeSolvers solvers with a weighted tolerance vector? -IterativeSolvers.jl computes the norm after the application of the left preconditioner -`Pl`. 
Thus, in order to use a vector tolerance `weights`, one can mathematically +IterativeSolvers.jl computes the norm after the application of the left preconditioner. +Thus, in order to use a vector tolerance `weights`, one can mathematically hack the system via the following formulation: ```@example FAQPrec @@ -57,11 +57,10 @@ A = rand(n, n) b = rand(n) weights = [1e-1, 1] -Pl = LinearSolve.InvPreconditioner(Diagonal(weights)) -Pr = Diagonal(weights) +precs = Returns((LinearSolve.InvPreconditioner(Diagonal(weights)), Diagonal(weights))) prob = LinearProblem(A, b) -sol = solve(prob, KrylovJL_GMRES(), Pl = Pl, Pr = Pr) +sol = solve(prob, KrylovJL_GMRES(precs)) sol.u ``` @@ -84,5 +83,5 @@ Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(we Pr = Diagonal(weights) prob = LinearProblem(A, b) -sol = solve(prob, KrylovJL_GMRES(), Pl = Pl, Pr = Pr) +sol = solve(prob, KrylovJL_GMRES(precs=Returns((Pl,Pr)))) ``` diff --git a/ext/LinearSolveIterativeSolversExt.jl b/ext/LinearSolveIterativeSolversExt.jl index 507e75ad2..cb4589642 100644 --- a/ext/LinearSolveIterativeSolversExt.jl +++ b/ext/LinearSolveIterativeSolversExt.jl @@ -1,7 +1,7 @@ module LinearSolveIterativeSolversExt using LinearSolve, LinearAlgebra -using LinearSolve: LinearCache +using LinearSolve: LinearCache, DEFAULT_PRECS import LinearSolve: IterativeSolversJL if isdefined(Base, :get_extension) @@ -12,9 +12,9 @@ end function LinearSolve.IterativeSolversJL(args...; generate_iterator = IterativeSolvers.gmres_iterable!, - gmres_restart = 0, kwargs...) + gmres_restart = 0, precs = DEFAULT_PRECS, kwargs...) return IterativeSolversJL(generate_iterator, gmres_restart, - args, kwargs) + precs, args, kwargs) end function LinearSolve.IterativeSolversJL_CG(args...; kwargs...) 
diff --git a/ext/LinearSolveKrylovKitExt.jl b/ext/LinearSolveKrylovKitExt.jl index a26e26688..1aa1e5d52 100644 --- a/ext/LinearSolveKrylovKitExt.jl +++ b/ext/LinearSolveKrylovKitExt.jl @@ -1,12 +1,13 @@ module LinearSolveKrylovKitExt using LinearSolve, KrylovKit, LinearAlgebra -using LinearSolve: LinearCache +using LinearSolve: LinearCache, DEFAULT_PRECS function LinearSolve.KrylovKitJL(args...; KrylovAlg = KrylovKit.GMRES, gmres_restart = 0, + precs = DEFAULT_PRECS, kwargs...) - return KrylovKitJL(KrylovAlg, gmres_restart, args, kwargs) + return KrylovKitJL(KrylovAlg, gmres_restart, precs, args, kwargs) end function LinearSolve.KrylovKitJL_CG(args...; kwargs...) diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index d8490e86a..c2fb6067c 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -41,6 +41,7 @@ import PrecompileTools import Krylov using SciMLBase import Preferences + const CRC = ChainRulesCore @static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686 diff --git a/src/common.jl b/src/common.jl index a53741411..3c222eca5 100644 --- a/src/common.jl +++ b/src/common.jl @@ -85,7 +85,18 @@ end function Base.setproperty!(cache::LinearCache, name::Symbol, x) if name === :A + if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs) + Pl, Pr = cache.alg.precs(x, cache.p) + setfield!(cache, :Pl, Pl) + setfield!(cache, :Pr, Pr) + end setfield!(cache, :isfresh, true) + elseif name === :p + if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs) + Pl, Pr = cache.alg.precs(cache.A, x) + setfield!(cache, :Pl, Pl) + setfield!(cache, :Pr, Pr) + end elseif name === :b # In case there is something that needs to be done when b is updated update_cacheval!(cache, :b, x) @@ -121,6 +132,8 @@ default_alias_b(::Any, ::Any, ::Any) = false default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true +DEFAULT_PRECS(A, p) = IdentityOperator(size(A)[1]), IdentityOperator(size(A)[2]) + 
function __init_u0_from_Ab(A, b) u0 = similar(b, size(A, 2)) fill!(u0, false) @@ -136,12 +149,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, reltol = default_tol(real(eltype(prob.b))), maxiters::Int = length(prob.b), verbose::Bool = false, - Pl = IdentityOperator(size(prob.A)[1]), - Pr = IdentityOperator(size(prob.A)[2]), + Pl = nothing, + Pr = nothing, assumptions = OperatorAssumptions(issquare(prob.A)), sensealg = LinearSolveAdjoint(), kwargs...) - @unpack A, b, u0, p = prob + (;A, b, u0, p) = prob A = if alias_A || A isa SMatrix A @@ -167,6 +180,24 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, reltol = real(eltype(prob.b))(reltol) abstol = real(eltype(prob.b))(abstol) + precs = if hasproperty(alg, :precs) + isnothing(alg.precs) ? DEFAULT_PRECS : alg.precs + else + DEFAULT_PRECS + end + _Pl, _Pr = precs(A, p) + if isnothing(Pl) + Pl = _Pl + else + # TODO: deprecate once all docs are updated to the new form + #@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm." + end + if isnothing(Pr) + Pr = _Pr + else + # TODO: deprecate once all docs are updated to the new form + #@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm." + end cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose, assumptions) isfresh = true @@ -179,6 +210,45 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, return cache end + +function SciMLBase.reinit!(cache::LinearCache; + A = nothing, + b = cache.b, + u = cache.u, + p = nothing, + reinit_cache = false,) + (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache + + precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? 
alg.precs : DEFAULT_PRECS + # Fall back to the cached operator/parameters for whatever was not supplied. + if isnothing(A) + A = cache.A + end + if isnothing(p) + p = cache.p + end + # Always rebuild from the effective (A, p): reusing cache.Pl/cache.Pr when + # both A and p are replaced would keep preconditioners built for the old + # system. + Pl, Pr = precs(A, p) + isfresh = true + + if reinit_cache + return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval), + typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), + typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + maxiters, verbose, assumptions, sensealg) + else + cache.A = A + cache.b = b + cache.u = u + cache.p = p + cache.Pl = Pl + cache.Pr = Pr + cache.isfresh = true + end +end + function SciMLBase.solve(prob::LinearProblem, args...; kwargs...) return solve(prob, nothing, args...; kwargs...) end diff --git a/src/extension_algs.jl b/src/extension_algs.jl index c5e2db955..7534d2fa1 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -259,9 +259,10 @@ solvers. Using this solver requires adding the package KrylovKit.jl, i.e. `using KrylovKit` """ -struct KrylovKitJL{F, A, I, K} <: LinearSolve.AbstractKrylovSubspaceMethod +struct KrylovKitJL{F, I, P, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod KrylovAlg::F gmres_restart::I + precs::P args::A kwargs::K end @@ -306,9 +307,10 @@ A generic wrapper over the IterativeSolvers.jl solvers. Using this solver requires adding the package IterativeSolvers.jl, i.e. `using IterativeSolvers` """ -struct IterativeSolversJL{F, I, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod +struct IterativeSolversJL{F, I, P, A, K} <: LinearSolve.AbstractKrylovSubspaceMethod generate_iterator::F gmres_restart::I + precs::P args::A kwargs::K end diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index bb93ba632..16a50a27c 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -10,19 +10,21 @@ KrylovJL(args...; KrylovAlg = Krylov.gmres!, A generic wrapper over the Krylov.jl krylov-subspace iterative solvers. 
""" -struct KrylovJL{F, I, A, K} <: AbstractKrylovSubspaceMethod +struct KrylovJL{F, I, P, A, K} <: AbstractKrylovSubspaceMethod KrylovAlg::F gmres_restart::I window::I + precs::P args::A kwargs::K end function KrylovJL(args...; KrylovAlg = Krylov.gmres!, gmres_restart = 0, window = 0, + precs = nothing, kwargs...) return KrylovJL(KrylovAlg, gmres_restart, window, - args, kwargs) + precs, args, kwargs) end default_alias_A(::KrylovJL, ::Any, ::Any) = true @@ -231,8 +233,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) cache.isfresh = false end - M = cache.Pl - N = cache.Pr + M, N = cache.Pl, cache.Pr # use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity M = _isidentity_struct(M) ? I : M @@ -258,25 +259,21 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) end args = (cacheval, cache.A, cache.b) - kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose, + kwargs = (atol = atol, rtol, itmax, verbose, ldiv = true, history = true, alg.kwargs...) if cache.cacheval isa Krylov.CgSolver N !== I && @warn "$(alg.KrylovAlg) doesn't support right preconditioning." - Krylov.solve!(args...; M = M, - kwargs...) + Krylov.solve!(args...; M, kwargs...) elseif cache.cacheval isa Krylov.GmresSolver - Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0, - kwargs...) + Krylov.solve!(args...; M, N, restart = alg.gmres_restart > 0, kwargs...) elseif cache.cacheval isa Krylov.BicgstabSolver - Krylov.solve!(args...; M = M, N = N, - kwargs...) + Krylov.solve!(args...; M, N, kwargs...) elseif cache.cacheval isa Krylov.MinresSolver N !== I && @warn "$(alg.KrylovAlg) doesn't support right preconditioning." - Krylov.solve!(args...; M = M, - kwargs...) + Krylov.solve!(args...; M, kwargs...) else Krylov.solve!(args...; kwargs...) 
end diff --git a/test/basictests.jl b/test/basictests.jl index cb64a1246..d27c0a8dc 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,6 +1,6 @@ using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff using SciMLOperators -using IterativeSolvers, KrylovKit, MKL_jll +using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners using Test import Random @@ -267,10 +267,12 @@ end @testset "KrylovJL" begin kwargs = (; gmres_restart = 5) + precs = (A,p=nothing) -> (BlockJacobiPreconditioner(A, 2), I) algorithms = ( ("Default", KrylovJL(kwargs...)), ("CG", KrylovJL_CG(kwargs...)), ("GMRES", KrylovJL_GMRES(kwargs...)), + ("GMRES_prec", KrylovJL_GMRES(;precs, ldiv=false, kwargs...)), # ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)), ("MINRES", KrylovJL_MINRES(kwargs...)) )