From 9e5a85aa373ddff10124f376456bf256539686eb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Oct 2023 16:09:34 -0400 Subject: [PATCH] Add dispatches for Transpose and Adjoint for Banded Matrices --- Project.toml | 2 +- ext/ArrayInterfaceBandedMatricesExt.jl | 24 ++++++++---- test/bandedmatrices.jl | 53 ++++++++++++++++++++------ 3 files changed, 60 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 5c84a8b24..32c6f9bef 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.4.11" +version = "7.5.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/ArrayInterfaceBandedMatricesExt.jl b/ext/ArrayInterfaceBandedMatricesExt.jl index 358434f15..a3b8ca671 100644 --- a/ext/ArrayInterfaceBandedMatricesExt.jl +++ b/ext/ArrayInterfaceBandedMatricesExt.jl @@ -1,16 +1,26 @@ module ArrayInterfaceBandedMatricesExt - if isdefined(Base, :get_extension) using ArrayInterface using ArrayInterface: BandedMatrixIndex using BandedMatrices + using LinearAlgebra else using ..ArrayInterface using ..ArrayInterface: BandedMatrixIndex using ..BandedMatrices + using ..LinearAlgebra end +const TransOrAdjBandedMatrix = Union{ + Adjoint{T, <:BandedMatrix{T}}, + Transpose{T, <:BandedMatrix{T}}, +} where {T} + +const AllBandedMatrix = Union{ + BandedMatrix{T}, + TransOrAdjBandedMatrix{T}, +} where {T} Base.firstindex(i::BandedMatrixIndex) = 1 Base.lastindex(i::BandedMatrixIndex) = i.count @@ -45,12 +55,12 @@ end function BandedMatrixIndex(rowsize, colsize, lowerbandwidth, upperbandwidth, isrow) upperbandwidth > -lowerbandwidth || throw(ErrorException("Invalid Bandwidths")) - bandinds = upperbandwidth:-1:-lowerbandwidth + bandinds = upperbandwidth:-1:(-lowerbandwidth) bandsizes = [_bandsize(band, rowsize, colsize) for band in bandinds] BandedMatrixIndex(sum(bandsizes), rowsize, colsize, bandinds, bandsizes, isrow) end -function ArrayInterface.findstructralnz(x::BandedMatrices.BandedMatrix) +function ArrayInterface.findstructralnz(x::AllBandedMatrix) l, u = BandedMatrices.bandwidths(x) rowsize, colsize = Base.size(x) rowind = BandedMatrixIndex(rowsize, colsize, l, u, true) @@ -58,11 +68,11 @@ function ArrayInterface.findstructralnz(x::BandedMatrices.BandedMatrix) return (rowind, colind) end -ArrayInterface.has_sparsestruct(::Type{<:BandedMatrices.BandedMatrix}) = true -ArrayInterface.isstructured(::Type{<:BandedMatrices.BandedMatrix}) = true -ArrayInterface.fast_matrix_colors(::Type{<:BandedMatrices.BandedMatrix}) = true +ArrayInterface.has_sparsestruct(::Type{<:AllBandedMatrix}) = true +ArrayInterface.isstructured(::Type{<:AllBandedMatrix}) = true +ArrayInterface.fast_matrix_colors(::Type{<:AllBandedMatrix}) = true -function ArrayInterface.matrix_colors(A::BandedMatrices.BandedMatrix) +function ArrayInterface.matrix_colors(A::AllBandedMatrix) l, u = BandedMatrices.bandwidths(A) width = u + l + 1 return ArrayInterface._cycle(1:width, Base.size(A, 2)) diff --git a/test/bandedmatrices.jl b/test/bandedmatrices.jl index a6142efd0..94626aac3 100644 --- a/test/bandedmatrices.jl +++ b/test/bandedmatrices.jl @@ -1,19 +1,50 @@ - using ArrayInterface using BandedMatrices using Test -B=BandedMatrix(Ones(5,5), (-1,2)) -B[band(1)].=[1,2,3,4] -B[band(2)].=[5,6,7] +function checkequal(idx1::ArrayInterface.BandedMatrixIndex, + idx2::ArrayInterface.BandedMatrixIndex) + return idx1.rowsize == idx2.rowsize && idx1.colsize == idx2.colsize && + idx1.bandinds == idx2.bandinds && idx1.bandsizes == idx2.bandsizes && + idx1.isrow == idx2.isrow && idx1.count == idx2.count +end + +B = BandedMatrix(Ones(5, 5), (-1, 2)) +B[band(1)] .= [1, 2, 3, 4] +B[band(2)] .= [5, 6, 7] @test ArrayInterface.has_sparsestruct(B) -rowind,colind=ArrayInterface.findstructralnz(B) -@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,1,2,3,4] -B=BandedMatrix(Ones(4,6), (-1,2)) -B[band(1)].=[1,2,3,4] -B[band(2)].=[5,6,7,8] -rowind,colind=ArrayInterface.findstructralnz(B) -@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,8,1,2,3,4] +rowind, colind = ArrayInterface.findstructralnz(B) +@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 1, 2, 3, 4] +B = BandedMatrix(Ones(4, 6), (-1, 2)) +B[band(1)] .= [1, 2, 3, 4] +B[band(2)] .= [5, 6, 7, 8] +rowind, colind = ArrayInterface.findstructralnz(B) +@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 8, 1, 2, 3, 4] @test ArrayInterface.isstructured(typeof(B)) @test ArrayInterface.fast_matrix_colors(typeof(B)) +for op in (adjoint, transpose) + B = BandedMatrix(Ones(5, 5), (-1, 2)) + B[band(1)] .= [1, 2, 3, 4] + B[band(2)] .= [5, 6, 7] + B′ = op(B) + @test ArrayInterface.has_sparsestruct(B′) + rowind′, colind′ = ArrayInterface.findstructralnz(B′) + rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′)) + @test checkequal(rowind′, rowind′′) + @test checkequal(colind′, colind′′) + + B = BandedMatrix(Ones(4, 6), (-1, 2)) + B[band(1)] .= [1, 2, 3, 4] + B[band(2)] .= [5, 6, 7, 8] + B′ = op(B) + rowind′, colind′ = ArrayInterface.findstructralnz(B′) + rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′)) + @test checkequal(rowind′, rowind′′) + @test checkequal(colind′, colind′′) + + @test ArrayInterface.isstructured(typeof(B′)) + @test ArrayInterface.fast_matrix_colors(typeof(B′)) + + @test ArrayInterface.matrix_colors(B′) == ArrayInterface.matrix_colors(BandedMatrix(B′)) +end