Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Line Search with Negative Curvature Detection for MINRES Based on Liu et al. (2022) #969

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,14 @@ mutable struct MinresSolver{T,FC,S} <: KrylovSolver{T,FC,S}
x :: S
r1 :: S
r2 :: S
rk :: S
w1 :: S
w2 :: S
y :: S
v :: S
err_vec :: Vector{T}
warm_start :: Bool
stats :: SimpleStats{T}
stats :: conStats{T}
end

function MinresSolver(kc::KrylovConstructor; window :: Int=5)
Expand All @@ -131,13 +132,14 @@ function MinresSolver(kc::KrylovConstructor; window :: Int=5)
x = similar(kc.vn)
r1 = similar(kc.vn)
r2 = similar(kc.vn)
rk = similar(kc.vn_empty)
w1 = similar(kc.vn)
w2 = similar(kc.vn)
y = similar(kc.vn)
v = similar(kc.vn_empty)
err_vec = zeros(T, window)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, w1, w2, y, v, err_vec, false, stats)
stats = conStats(0, false, false, false, false, T[], T[], T[], 0.0, "unknown")
solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, rk, w1, w2, y, v, err_vec, false, stats)
return solver
end

Expand All @@ -148,13 +150,14 @@ function MinresSolver(m, n, S; window :: Int=5)
x = S(undef, n)
r1 = S(undef, n)
r2 = S(undef, n)
rk = S(undef, 0)
w1 = S(undef, n)
w2 = S(undef, n)
y = S(undef, n)
v = S(undef, 0)
err_vec = zeros(T, window)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, w1, w2, y, v, err_vec, false, stats)
stats = conStats(0, false, false, false, false, T[], T[], T[], 0.0, "unknown")
solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, rk, w1, w2, y, v, err_vec, false, stats)
return solver
end

Expand Down
52 changes: 52 additions & 0 deletions src/krylov_stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,58 @@ function copyto!(dest :: SimpleStats, src :: SimpleStats)
return dest
end

"""
Type for storing statistics returned by Conjugate Methods.
Methods icludes:
- CG (TODO)
- CR (TODO)
- MINRES
The fields are as follows:
- niter
- solved
- nonposi_curv: when a non-positive curvature is detected
- linesearch: when a line search is performed
- inconsistent
- residuals
- Aresiduals
- Acond
- timer
- status
"""
mutable struct conStats{T} <: KrylovStats{T}
niter :: Int
solved :: Bool
nonposi_curv :: Bool
linesearch :: Bool
inconsistent :: Bool
residuals :: Vector{T}
Aresiduals :: Vector{T}
Acond :: Vector{T}
timer :: Float64
status :: String
end

function reset!(stats :: conStats)
empty!(stats.residuals)
empty!(stats.Aresiduals)
empty!(stats.Acond)
end

function copyto!(dest :: conStats, src :: conStats)
dest.niter = src.niter
dest.solved = src.solved
dest.nonposi_curv = src.nonposi_curv
dest.linesearch = src.linesearch
dest.inconsistent = src.inconsistent
dest.residuals = copy(src.residuals)
dest.Aresiduals = copy(src.Aresiduals)
dest.Acond = copy(src.Acond)
dest.timer = src.timer
dest.status = src.status
return dest
end


"""
Type for storing statistics returned by LSMR.
The fields are as follows:
Expand Down
55 changes: 51 additions & 4 deletions src/minres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
# Dominique Orban, <[email protected]>
# Brussels, Belgium, June 2015.
# Montréal, August 2015.
#
# Liu, Yang, and Fred Roosta. "MINRES: from negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32, no. 4 (2022): 2636-2661.

export minres, minres!

"""
(x, stats) = minres(A, b::AbstractVector{FC};
M=I, ldiv::Bool=false, window::Int=5,
λ::T=zero(T), atol::T=√eps(T),
linesearch::Bool=false, λ::T=zero(T), atol::T=√eps(T),
rtol::T=√eps(T), etol::T=√eps(T),
conlim::T=1/√eps(T), itmax::Int=0,
timemax::Float64=Inf, verbose::Int=0, history::Bool=false,
Expand Down Expand Up @@ -68,6 +70,7 @@ MINRES produces monotonic residuals ‖r‖₂ and optimality residuals ‖Aᴴr
* `M`: linear operator that models a Hermitian positive-definite matrix of size `n` used for centered preconditioning;
* `ldiv`: define whether the preconditioner uses `ldiv!` or `mul!`;
* `window`: number of iterations used to accumulate a lower bound on the error;
* `linesearch`: if `true`, indicate that the solution is to be used in an inexact Newton method with linesearch. If negative curvature is detected at iteration k > 0, the solution of iteration k-1 is returned. If negative curvature is detected at iteration 0, the right-hand side is returned (i.e., the negative gradient);
* `λ`: regularization parameter;
* `atol`: absolute stopping tolerance based on the residual norm;
* `rtol`: relative stopping tolerance based on the residual norm;
Expand All @@ -88,6 +91,8 @@ MINRES produces monotonic residuals ‖r‖₂ and optimality residuals ‖Aᴴr
#### Reference

* C. C. Paige and M. A. Saunders, [*Solution of Sparse Indefinite Systems of Linear Equations*](https://doi.org/10.1137/0712047), SIAM Journal on Numerical Analysis, 12(4), pp. 617--629, 1975.

* Liu, Yang & Roosta, Fred. (2022). A Newton-MR algorithm with complexity guarantees for nonconvex smooth unconstrained optimization. 10.48550/arXiv.2208.07095.
"""
function minres end

Expand All @@ -108,6 +113,7 @@ def_optargs_minres = (:(x0::AbstractVector),)

def_kwargs_minres = (:(; M = I ),
:(; ldiv::Bool = false ),
:(; linesearch::Bool = false ),
:(; λ::T = zero(T) ),
:(; atol::T = √eps(T) ),
:(; rtol::T = √eps(T) ),
Expand All @@ -124,7 +130,7 @@ def_kwargs_minres = extract_parameters.(def_kwargs_minres)

args_minres = (:A, :b)
optargs_minres = (:x0,)
kwargs_minres = (:M, :ldiv, :λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax, :verbose, :history, :callback, :iostream)
kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function minres!(solver :: MinresSolver{T,FC,S}, $(def_args_minres...); $(def_kwargs_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}
Expand All @@ -148,12 +154,16 @@ kwargs_minres = (:M, :ldiv, :λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax,

# Set up workspace.
allocate_if(!MisI, solver, :v, S, solver.x) # The length of v is n
allocate_if(linesearch, solver, :rk, S, solver.x) # The length of rk is n
Δx, x, r1, r2, w1, w2, y = solver.Δx, solver.x, solver.r1, solver.r2, solver.w1, solver.w2, solver.y
err_vec, stats = solver.err_vec, solver.stats
warm_start = solver.warm_start
rNorms, ArNorms, Aconds = stats.residuals, stats.Aresiduals, stats.Acond
reset!(stats)
stats.linesearch = linesearch

v = MisI ? r2 : solver.v
rk = linesearch ? solver.rk : r2

ϵM = eps(T)
ctol = conlim > 0 ? 1 / conlim : zero(T)
Expand All @@ -169,22 +179,30 @@ kwargs_minres = (:M, :ldiv, :λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax,
kcopy!(n, r1, b) # r1 ← b
end

linesearch && kcopy!(n, rk, r1) # rk ← r1

# Initialize Lanczos process.
# β₁ M v₁ = b.
kcopy!(n, r2, r1) # r2 ← r1
MisI || mulorldiv!(v, M, r1, ldiv)

rNorm = knorm_elliptic(n, r2, r1) # = ‖r‖
history && push!(rNorms, rNorm)


β₁ = kdotr(m, r1, v)
β₁ < 0 && error("Preconditioner is not positive definite")
if β₁ == 0
stats.niter = 0
stats.solved, stats.inconsistent = true, false
stats.timer = start_time |> ktimer
stats.status = "x is a zero-residual solution"
stats.status = "b is a zero-curvature directions"
history && push!(rNorms, β₁)
history && push!(ArNorms, zero(T))
history && push!(Aconds, zero(T))
warm_start && kaxpy!(n, one(FC), Δx, x)
solver.warm_start = false
stats.nonposi_curv = true
return solver
end
β₁ = sqrt(β₁)
Expand Down Expand Up @@ -239,7 +257,7 @@ kwargs_minres = (:M, :ldiv, :λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax,
# Generate next Lanczos vector.
mul!(y, A, v)
λ ≠ 0 && kaxpy!(n, λ, v, y) # (y = y + λ * v)
kscal!(n, one(FC) / β, y)
kscal!(n, one(FC) / β, y) # (y = y / β)
iter ≥ 2 && kaxpy!(n, -β / oldβ, r1, y) # (y = y - β / oldβ * r1)

α = kdotr(n, v, y) / β
Expand Down Expand Up @@ -275,6 +293,24 @@ kwargs_minres = (:M, :ldiv, :λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax,
ArNorm = ϕbar * root # = ‖Aᴴrₖ₋₁‖
history && push!(ArNorms, ArNorm)

# Check for nonpositive curvature
if linesearch
cγ = cs * γbar
if cγ ≥ 0
(verbose > 0) && @printf(iostream, "nonpositive curvature detected: cs * γbar = %e\n", cγ)
stats.solved = true
stats.niter = iter
stats.inconsistent = false
stats.timer = start_time |> ktimer
stats.status = "nonpositive curvature"
iter == 1 && kcopy!(n, x, r1)
solver.warm_start = false
# when we use the linesearch and encounter negative curvature, we return the last residual xk but user has access to rk from solver.rk
stats.nonposi_curv = true
return solver
end
end

# Compute the next plane rotation.
γ = sqrt(γbar * γbar + β * β)
γ = max(γ, ϵM)
Expand All @@ -283,6 +319,16 @@ kwargs_minres = (:M, :ldiv, :λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax,
ϕ = cs * ϕbar
ϕbar = sn * ϕbar

if linesearch
# calculating the residual rk = sn*sn * rk - ϕbar * cs * v
sn2 = sn * sn
kscal!(n, sn2, rk ) # rk = sn2 * rk
ϕ_c = -ϕbar * cs
kaxpy!(n, ϕ_c, v, rk) # rk = rk + ϕ_c * v
rk = sn*sn * rk - ϕbar * cs * v
end


# Final update of w.
kscal!(n, one(FC) / γ, w)

Expand Down Expand Up @@ -372,6 +418,7 @@ kwargs_minres = (:M, :ldiv, :λ, :atol, :rtol, :etol, :conlim, :itmax, :timemax,
user_requested_exit && (status = "user-requested exit")
overtimed && (status = "time limit exceeded")

stats.nonposi_curv = zero_resid
# Update x
warm_start && kaxpy!(n, one(FC), Δx, x)
solver.warm_start = false
Expand Down
23 changes: 22 additions & 1 deletion test/test_minres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
A, b = zero_rhs(FC=FC)
(x, stats) = minres(A, b)
@test norm(x) == 0
@test stats.status == "x is a zero-residual solution"
@test stats.status == "b is a zero-curvature directions"

# Shifted system
A, b = symmetric_indefinite(FC=FC)
Expand Down Expand Up @@ -69,6 +69,27 @@
@test(resid ≤ minres_tol * norm(A) * norm(x))
@test(stats.solved)

# Test linesearch
A, b = symmetric_indefinite(FC=FC)
x, stats = minres(A, b, linesearch=true)
@test stats.status == "nonpositive curvature"

# Test Linesearch which would stop on the first call since A is negative definite
A, b = symmetric_indefinite(FC=FC; shift = 5)
x, stats = minres(A, b, linesearch=true)
@test stats.status == "nonpositive curvature"
@test stats.niter == 1 # in Minres they add 1 to the number of iterations first step
@test all(x .== b)
@test stats.solved == true

# Test when b^TAb=0 and linesearch is true
A, b = system_zero_quad(FC=FC)
x, stats = minres(A, b, linesearch=true)
@test stats.status == "nonpositive curvature"
@test all(x .== b)
@test stats.solved == true


# test callback function
solver = MinresSolver(A, b)
storage_vec = similar(b, size(A, 1))
Expand Down
Loading