Skip to content

Commit

Permalink
Merge pull request #494 from SciML/ap/fact_sym
Browse files Browse the repository at this point in the history
Handle LU failure and Symmetric QR
  • Loading branch information
ChrisRackauckas authored Apr 25, 2024
2 parents 4afec5a + de29257 commit f5282e5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.29.0"
version = "2.29.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
24 changes: 23 additions & 1 deletion src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::LUFactorization; kwargs...)
fact = lu(A, check = false)
end
cache.cacheval = fact

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
end

cache.isfresh = false
end

Expand Down Expand Up @@ -187,7 +193,11 @@ function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if ArrayInterface.can_setindex(typeof(A))
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AnyGPUArray)
fact = qr!(A, alg.pivot)
if A isa Symmetric
fact = qr(A, alg.pivot)
else
fact = qr!(A, alg.pivot)
end
else
fact = qr(A) # CUDA.jl does not allow other args!
end
Expand All @@ -203,6 +213,12 @@ function init_cacheval(alg::QRFactorization, A, b, u, Pl, Pr,
ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
end

function init_cacheval(alg::QRFactorization, A::Symmetric, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
return qr(convert(AbstractMatrix, A), alg.pivot)
end

const PREALLOCATED_QR_ColumnNorm = ArrayInterface.qr_instance(rand(1, 1), ColumnNorm())

function init_cacheval(alg::QRFactorization{ColumnNorm}, A::Matrix{Float64}, b, u, Pl, Pr,
Expand Down Expand Up @@ -1023,6 +1039,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::RFLUFactorization{P, T};
end
fact = RecursiveFactorization.lu!(A, ipiv, Val(P), Val(T), check = false)
cache.cacheval = (fact, ipiv)

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
end

cache.isfresh = false
end
y = ldiv!(cache.u, @get_cacheval(cache, :RFLUFactorization)[1], cache.b)
Expand Down

0 comments on commit f5282e5

Please sign in to comment.