Skip to content

Commit

Permalink
Make sparsearrays an ext
Browse files Browse the repository at this point in the history
Fixes #447
  • Loading branch information
ChrisRackauckas committed Jul 27, 2024
1 parent ac0b31a commit cb614b5
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 35 deletions.
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ version = "7.12.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
Expand All @@ -16,6 +14,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

Expand All @@ -27,6 +26,7 @@ ArrayInterfaceCUDSSExt = "CUDSS"
ArrayInterfaceChainRulesExt = "ChainRules"
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
ArrayInterfaceReverseDiffExt = "ReverseDiff"
ArrayInterfaceSparseArraysExt = "SparseArrays"
ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
ArrayInterfaceTrackerExt = "Tracker"

Expand All @@ -42,7 +42,6 @@ LinearAlgebra = "1.10"
ReverseDiff = "1"
SparseArrays = "1.10"
StaticArraysCore = "1"
SuiteSparse = "1.10"
Tracker = "0.2"
julia = "1.10"

Expand Down
38 changes: 38 additions & 0 deletions ext/ArrayInterfaceSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module ArrayInterfaceSparseArraysExt

import ArrayInterface: buffer, has_sparsestruct, issingular, findstructralnz, bunchkaufman_instance, DEFAULT_CHOLESKY_PIVOT, cholesky_instance, ldlt_instance, lu_instance, qr_instance
using ArrayInterface.LinearAlgebra
using SparseArrays

buffer(x::SparseMatrixCSC) = getfield(x, :nzval)
buffer(x::SparseVector) = getfield(x, :nzval)
has_sparsestruct(::Type{<:SparseMatrixCSC}) = true
issingular(A::AbstractSparseMatrix) = !issuccess(lu(A, check = false))

function findstructralnz(x::SparseMatrixCSC)
rowind, colind, _ = findnz(x)
(rowind, colind)
end

function bunchkaufman_instance(A::SparseMatrixCSC)
bunchkaufman(sparse(similar(A, 1, 1)), check = false)
end

function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT)
cholesky(sparse(similar(A, 1, 1)), check = false)
end

function ldlt_instance(A::SparseMatrixCSC)
ldlt(sparse(similar(A, 1, 1)), check=false)
end

# Could be optimized but this should work for any real case.
function lu_instance(jac_prototype::SparseMatrixCSC, pivot = DEFAULT_CHOLESKY_PIVOT)
lu(sparse(rand(1,1)))
end

function qr_instance(jac_prototype::SparseMatrixCSC, pivot = DEFAULT_CHOLESKY_PIVOT)
qr(sparse(rand(1,1)))
end

end
34 changes: 2 additions & 32 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module ArrayInterface

using LinearAlgebra
using SparseArrays
using SuiteSparse

@static if isdefined(Base, Symbol("@assume_effects"))
using Base: @assume_effects
Expand Down Expand Up @@ -121,8 +119,6 @@ Return the buffer data that `x` points to. Unlike `parent(x::AbstractArray)`, `b
may not return another array type.
"""
buffer(x) = parent(x)
buffer(x::SparseMatrixCSC) = getfield(x, :nzval)
buffer(x::SparseVector) = getfield(x, :nzval)
buffer(@nospecialize x::Union{Base.Slice, Base.IdentityUnitRange}) = getfield(x, :indices)

"""
Expand Down Expand Up @@ -308,7 +304,6 @@ Determine whether `findstructralnz` accepts the parameter `x`.
has_sparsestruct(x) = has_sparsestruct(typeof(x))
has_sparsestruct(::Type) = false
has_sparsestruct(::Type{<:AbstractArray}) = false
has_sparsestruct(::Type{<:SparseMatrixCSC}) = true
has_sparsestruct(::Type{<:Diagonal}) = true
has_sparsestruct(::Type{<:Bidiagonal}) = true
has_sparsestruct(::Type{<:Tridiagonal}) = true
Expand All @@ -320,7 +315,6 @@ has_sparsestruct(::Type{<:SymTridiagonal}) = true
Determine whether a given abstract matrix is singular.
"""
issingular(A::AbstractMatrix) = issingular(Matrix(A))
issingular(A::AbstractSparseMatrix) = !issuccess(lu(A, check = false))
issingular(A::Matrix) = !issuccess(lu(A, check = false))
issingular(A::UniformScaling) = A.λ == 0
issingular(A::Diagonal) = any(iszero, A.diag)
Expand Down Expand Up @@ -359,11 +353,6 @@ function findstructralnz(x::Union{Tridiagonal, SymTridiagonal})
(rowind, colind)
end

function findstructralnz(x::SparseMatrixCSC)
rowind, colind, _ = findnz(x)
(rowind, colind)
end

abstract type ColoringAlgorithm end

"""
Expand Down Expand Up @@ -403,9 +392,6 @@ cheaply.
function bunchkaufman_instance(A::Matrix{T}) where T
return bunchkaufman(similar(A, 0, 0), check = false)
end
function bunchkaufman_instance(A::SparseMatrixCSC)
bunchkaufman(sparse(similar(A, 1, 1)), check = false)
end

"""
bunchkaufman_instance(a::Number) -> a
Expand All @@ -429,14 +415,10 @@ cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorizati
Returns an instance of the Cholesky factorization object with the correct type
cheaply.
"""
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
return cholesky(similar(A, 0, 0), pivot, check = false)
end

function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT)
cholesky(sparse(similar(A, 1, 1)), check = false)
end

"""
cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) -> a
Expand All @@ -458,14 +440,10 @@ ldlt_instance(A) -> ldlt_factorization_instance
Returns an instance of the LDLT factorization object with the correct type
cheaply.
"""
function ldlt_instance(A::Matrix{T}) where {T}
function ldlt_instance(A::Matrix{T}) where {T}
return ldlt_instance(SymTridiagonal(similar(A, 0, 0)))
end

function ldlt_instance(A::SparseMatrixCSC)
ldlt(sparse(similar(A, 1, 1)), check=false)
end

function ldlt_instance(A::SymTridiagonal{T,V}) where {T,V}
return LinearAlgebra.LDLt{T,SymTridiagonal{T,V}}(A)
end
Expand Down Expand Up @@ -498,9 +476,6 @@ function lu_instance(A::Matrix{T}) where {T}
info = zero(LinearAlgebra.BlasInt)
return LU{luT}(similar(A, 0, 0), ipiv, info)
end
function lu_instance(jac_prototype::SparseMatrixCSC)
SuiteSparse.UMFPACK.UmfpackLU(similar(jac_prototype, 1, 1))
end

function lu_instance(A::Symmetric{T}) where {T}
noUnitT = typeof(zero(T))
Expand Down Expand Up @@ -557,11 +532,6 @@ function qr_instance(A::Matrix{BigFloat},pivot = DEFAULT_CHOLESKY_PIVOT)
LinearAlgebra.QR(zeros(BigFloat,0,0),zeros(BigFloat,0))
end

# Could be optimized but this should work for any real case.
function qr_instance(jac_prototype::SparseMatrixCSC, pivot = DEFAULT_CHOLESKY_PIVOT)
qr(sparse(rand(1,1)))
end

"""
qr_instance(a::Number) -> a
Expand Down

0 comments on commit cb614b5

Please sign in to comment.