Skip to content

Commit

Permalink
WIP: Make more libraries into an extension?
Browse files Browse the repository at this point in the history
I think we should be considering some kind of split to LinearSolveCore vs LinearSolve, a la OrdinaryDiffEq. The default algorithm requires RecursiveFactorization and MKL, but we could definitely make a version that doesn't have these deps. But at least sparse stuff can go into an extension since those always dispatch on the existance of sparse matrices.

This PR isn't quite complete yet, more needs to be removed with sparse, but it shows the right direction.
  • Loading branch information
ChrisRackauckas committed Sep 29, 2024
1 parent 272808b commit 03edc78
Show file tree
Hide file tree
Showing 8 changed files with 377 additions and 334 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -26,7 +25,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Expand All @@ -40,11 +38,13 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
LinearSolveBandedMatricesExt = "BandedMatrices"
Expand All @@ -55,11 +55,13 @@ LinearSolveEnzymeExt = "EnzymeCore"
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKLUExt = "KLU"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
LinearSolveSparseArraysExt = "SparseArrays"

[compat]
AllocCheck = "0.1"
Expand Down
66 changes: 66 additions & 0 deletions ext/LinearSolveKLUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
module LinearSolveKLUExt

using LinearSolve, LinearSolve.LinearAlgebra
using KLU, KLU.SparseArrays

const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
Float64[]))

function init_cacheval(alg::KLUFactorization,
A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function init_cacheval(alg::KLUFactorization, A::SparseMatrixCSC{Float64, Int}, b, u, Pl,
Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_KLU
end

function init_cacheval(alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
end

# TODO: guard this against errors
function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :KLUFactorization)
if alg.reuse_symbolic
if alg.check_pattern && pattern_changed(cacheval, A)
fact = KLU.klu(
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
check = false)
else
fact = KLU.klu!(cacheval, nonzeros(A), check = false)
end
else
# New fact each time since the sparsity pattern can change
# and thus it needs to reallocate
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
end
cache.cacheval = fact
cache.isfresh = false
end
F = @get_cacheval(cache, :KLUFactorization)
if F.common.status == KLU.KLU_OK
y = ldiv!(cache.u, F, cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
else
SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
end
end

end
178 changes: 178 additions & 0 deletions ext/LinearSolveSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
module LinearSolveSparseArraysExt

using LinearSolve
import LinearSolve: SciMLBase, LinearAlgebra, PrecompileTools, init_cacheval
using LinearSolve: DefaultLinearSolver, DefaultAlgorithmChoice
using SparseArrays
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr

# Specialize QR for the non-square case
# Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
function LinearSolve._ldiv!(x::Vector,
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
SparseArrays.SPQR.QRSparse,
SparseArrays.CHOLMOD.Factor}, b::Vector)
x .= A \ b
end

function LinearSolve._ldiv!(x::AbstractVector,
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
SparseArrays.SPQR.QRSparse,
SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
x .= A \ b
end

# Ambiguity removal
function LinearSolve._ldiv!(::SVector,
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
b::AbstractVector)
(A \ b)
end
function LinearSolve._ldiv!(::SVector,
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
b::SVector)
(A \ b)
end

function LinearSolve.pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
fact.rowval)
end

const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0, [1],
Int[], Float64[]))

function init_cacheval(alg::UMFPACKFactorization,
A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function init_cacheval(alg::UMFPACKFactorization, A::SparseMatrixCSC{Float64, Int}, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_UMFPACK
end

function init_cacheval(alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
zerobased = SparseArrays.getcolptr(A)[1] == 0
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
rowvals(A), nonzeros(A)))
end

function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :UMFPACKFactorization)
if alg.reuse_symbolic
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
if alg.check_pattern && pattern_changed(cacheval, A)
fact = lu(
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
check = false)
else
fact = lu!(cacheval,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)), check = false)
end
else
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
end
cache.cacheval = fact
cache.isfresh = false
end

F = @get_cacheval(cache, :UMFPACKFactorization)
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
y = ldiv!(cache.u, F, cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
else
SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
end
end

const PREALLOCATED_CHOLMOD = cholesky(SparseMatrixCSC(0, 0, [1], Int[], Float64[]))

function init_cacheval(alg::CHOLMODFactorization,
A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function init_cacheval(alg::CHOLMODFactorization,
A::Union{SparseMatrixCSC{T, Int}, Symmetric{T, SparseMatrixCSC{T, Int}}}, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions) where {T <:
Union{Float32, Float64}}
PREALLOCATED_CHOLMOD
end

function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)

if cache.isfresh
cacheval = @get_cacheval(cache, :CHOLMODFactorization)
fact = cholesky(A; check = false)
if !LinearAlgebra.issuccess(fact)
ldlt!(fact, A; check = false)
end
cache.cacheval = fact
cache.isfresh = false
end

cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

function LinearSolve.defaultalg(
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization)
end

function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
if assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
else
error("Generic number sparse factorization for non-square is not currently handled")
end
end

function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Ti}
if assump.issq
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
DefaultLinearSolver(DefaultAlgorithmChoice.KLUFactorization)
else
DefaultLinearSolver(DefaultAlgorithmChoice.UMFPACKFactorization)
end
else
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
end
end

PrecompileTools.@compile_workload begin
A = sprand(4, 4, 0.3) + I
b = rand(4)
prob = LinearProblem(A, b)
sol = solve(prob, KLUFactorization())
sol = solve(prob, UMFPACKFactorization())
end

end
Loading

0 comments on commit 03edc78

Please sign in to comment.