Skip to content

Commit

Permalink
Merge pull request #422 from avik-pal/ap/banded
Browse files Browse the repository at this point in the history
Add dispatches for Transpose and Adjoint for Banded Matrices
  • Loading branch information
ChrisRackauckas authored Oct 16, 2023
2 parents e63b793 + 9e5a85a commit ae71cf2
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.4.11"
version = "7.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
24 changes: 17 additions & 7 deletions ext/ArrayInterfaceBandedMatricesExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
module ArrayInterfaceBandedMatricesExt


if isdefined(Base, :get_extension)
using ArrayInterface
using ArrayInterface: BandedMatrixIndex
using BandedMatrices
using LinearAlgebra
else
using ..ArrayInterface
using ..ArrayInterface: BandedMatrixIndex
using ..BandedMatrices
using ..LinearAlgebra
end

const TransOrAdjBandedMatrix = Union{
Adjoint{T, <:BandedMatrix{T}},
Transpose{T, <:BandedMatrix{T}},
} where {T}

const AllBandedMatrix = Union{
BandedMatrix{T},
TransOrAdjBandedMatrix{T},
} where {T}

Base.firstindex(i::BandedMatrixIndex) = 1
Base.lastindex(i::BandedMatrixIndex) = i.count
Expand Down Expand Up @@ -45,24 +55,24 @@ end

function BandedMatrixIndex(rowsize, colsize, lowerbandwidth, upperbandwidth, isrow)
upperbandwidth > -lowerbandwidth || throw(ErrorException("Invalid Bandwidths"))
bandinds = upperbandwidth:-1:-lowerbandwidth
bandinds = upperbandwidth:-1:(-lowerbandwidth)
bandsizes = [_bandsize(band, rowsize, colsize) for band in bandinds]
BandedMatrixIndex(sum(bandsizes), rowsize, colsize, bandinds, bandsizes, isrow)
end

function ArrayInterface.findstructralnz(x::BandedMatrices.BandedMatrix)
function ArrayInterface.findstructralnz(x::AllBandedMatrix)
l, u = BandedMatrices.bandwidths(x)
rowsize, colsize = Base.size(x)
rowind = BandedMatrixIndex(rowsize, colsize, l, u, true)
colind = BandedMatrixIndex(rowsize, colsize, l, u, false)
return (rowind, colind)
end

ArrayInterface.has_sparsestruct(::Type{<:BandedMatrices.BandedMatrix}) = true
ArrayInterface.isstructured(::Type{<:BandedMatrices.BandedMatrix}) = true
ArrayInterface.fast_matrix_colors(::Type{<:BandedMatrices.BandedMatrix}) = true
ArrayInterface.has_sparsestruct(::Type{<:AllBandedMatrix}) = true
ArrayInterface.isstructured(::Type{<:AllBandedMatrix}) = true
ArrayInterface.fast_matrix_colors(::Type{<:AllBandedMatrix}) = true

function ArrayInterface.matrix_colors(A::BandedMatrices.BandedMatrix)
function ArrayInterface.matrix_colors(A::AllBandedMatrix)
l, u = BandedMatrices.bandwidths(A)
width = u + l + 1
return ArrayInterface._cycle(1:width, Base.size(A, 2))
Expand Down
53 changes: 42 additions & 11 deletions test/bandedmatrices.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,50 @@

using ArrayInterface
using BandedMatrices
using Test

B=BandedMatrix(Ones(5,5), (-1,2))
B[band(1)].=[1,2,3,4]
B[band(2)].=[5,6,7]
function checkequal(idx1::ArrayInterface.BandedMatrixIndex,
idx2::ArrayInterface.BandedMatrixIndex)
return idx1.rowsize == idx2.rowsize && idx1.colsize == idx2.colsize &&
idx1.bandinds == idx2.bandinds && idx1.bandsizes == idx2.bandsizes &&
idx1.isrow == idx2.isrow && idx1.count == idx2.count
end

B = BandedMatrix(Ones(5, 5), (-1, 2))
B[band(1)] .= [1, 2, 3, 4]
B[band(2)] .= [5, 6, 7]
@test ArrayInterface.has_sparsestruct(B)
rowind,colind=ArrayInterface.findstructralnz(B)
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,1,2,3,4]
B=BandedMatrix(Ones(4,6), (-1,2))
B[band(1)].=[1,2,3,4]
B[band(2)].=[5,6,7,8]
rowind,colind=ArrayInterface.findstructralnz(B)
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,8,1,2,3,4]
rowind, colind = ArrayInterface.findstructralnz(B)
@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 1, 2, 3, 4]
B = BandedMatrix(Ones(4, 6), (-1, 2))
B[band(1)] .= [1, 2, 3, 4]
B[band(2)] .= [5, 6, 7, 8]
rowind, colind = ArrayInterface.findstructralnz(B)
@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 8, 1, 2, 3, 4]
@test ArrayInterface.isstructured(typeof(B))
@test ArrayInterface.fast_matrix_colors(typeof(B))

for op in (adjoint, transpose)
B = BandedMatrix(Ones(5, 5), (-1, 2))
B[band(1)] .= [1, 2, 3, 4]
B[band(2)] .= [5, 6, 7]
B′ = op(B)
@test ArrayInterface.has_sparsestruct(B′)
rowind′, colind′ = ArrayInterface.findstructralnz(B′)
rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′))
@test checkequal(rowind′, rowind′′)
@test checkequal(colind′, colind′′)

B = BandedMatrix(Ones(4, 6), (-1, 2))
B[band(1)] .= [1, 2, 3, 4]
B[band(2)] .= [5, 6, 7, 8]
B′ = op(B)
rowind′, colind′ = ArrayInterface.findstructralnz(B′)
rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′))
@test checkequal(rowind′, rowind′′)
@test checkequal(colind′, colind′′)

@test ArrayInterface.isstructured(typeof(B′))
@test ArrayInterface.fast_matrix_colors(typeof(B′))

@test ArrayInterface.matrix_colors(B′) == ArrayInterface.matrix_colors(BandedMatrix(B′))
end

0 comments on commit ae71cf2

Please sign in to comment.