Skip to content

Commit

Permalink
Reland "Reroute Symmetric/Hermitian + Diagonal through triangular"
Browse files Browse the repository at this point in the history
This backports the following commits:
commit 9690961c426ce2640d7db6c89952e69f87873a93
Author: Jishnu Bhattacharya <[email protected]>
Date:   Mon Apr 29 21:43:31 2024 +0530

    Add upper/lowertriangular functions and use in applytri (#53573)

    We may use the fact that a `Diagonal` is already triangular to avoid
    adding a wrapper.

    Fixes the specific example in
    #53564, although not the
    broader issue. This is because it changes the operation from a
    `UpperTriangular + UpperTriangular` to a `UpperTriangular + Diagonal`,
    which uses broadcasting. The latter operation may also allow one to
    define more efficient methods.

commit 77821cdddb968eeabf31ccb6b214ccf59a604c68
Author: Jishnu Bhattacharya <[email protected]>
Date:   Wed Aug 28 00:53:31 2024 +0530

    Remove Diagonal-triangular specialization

commit 621fb2e739a04207df63857700aca3562b41b5eb
Author: Jishnu Bhattacharya <[email protected]>
Date:   Wed Aug 28 00:50:49 2024 +0530

    Restrict broadcasting to strided-diag Diagonal

commit 58eb2045ddb5dbbfdb759c06239ca54751e73d71
Author: Jishnu Bhattacharya <[email protected]>
Date:   Wed Aug 28 00:44:47 2024 +0530

    Add tests for partly filled parent

commit 5aa6080a580bfbc9453e94a06f3e379e4517b316
Author: Jishnu Bhattacharya <[email protected]>
Date:   Tue Aug 27 20:42:07 2024 +0530

    Reroute Symmetric/Hermitian + Diagonal through triangular
  • Loading branch information
jishnub committed Sep 12, 2024
1 parent f09de94 commit d40fa57
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 21 deletions.
18 changes: 3 additions & 15 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,21 +250,6 @@ end
(+)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag + Db.diag)
(-)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag - Db.diag)

for f in (:+, :-)
@eval function $f(D::Diagonal{<:Number}, S::Symmetric)
return Symmetric($f(D, S.data), sym_uplo(S.uplo))
end
@eval function $f(S::Symmetric, D::Diagonal{<:Number})
return Symmetric($f(S.data, D), sym_uplo(S.uplo))
end
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
return Hermitian($f(D, H.data), sym_uplo(H.uplo))
end
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
return Hermitian($f(H.data, D), sym_uplo(H.uplo))
end
end

(*)(x::Number, D::Diagonal) = Diagonal(x * D.diag)
(*)(D::Diagonal, x::Number) = Diagonal(D.diag * x)
(/)(D::Diagonal, x::Number) = Diagonal(D.diag / x)
Expand Down Expand Up @@ -991,3 +976,6 @@ end
function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal)
Diagonal(A.diag .* B.diag .+ z.diag)
end

uppertriangular(D::Diagonal) = D
lowertriangular(D::Diagonal) = D
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,25 @@ function (-)(A::UniformScaling, B::Diagonal)
Diagonal(Ref(A) .- B.diag)
end

for f in (:+, :-)
@eval function $f(D::Diagonal{<:Number}, S::Symmetric)
uplo = sym_uplo(S.uplo)
return Symmetric(parentof_applytri($f, Symmetric(D, uplo), S), uplo)
end
@eval function $f(S::Symmetric, D::Diagonal{<:Number})
uplo = sym_uplo(S.uplo)
return Symmetric(parentof_applytri($f, S, Symmetric(D, uplo)), uplo)
end
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
uplo = sym_uplo(H.uplo)
return Hermitian(parentof_applytri($f, Hermitian(D, uplo), H), uplo)
end
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
uplo = sym_uplo(H.uplo)
return Hermitian(parentof_applytri($f, H, Hermitian(D, uplo)), uplo)
end
end

## Diagonal construction from UniformScaling
Diagonal{T}(s::UniformScaling, m::Integer) where {T} = Diagonal{T}(fill(T(s.λ), m))
Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m)
Expand Down
12 changes: 6 additions & 6 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,21 +277,21 @@ diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))

function applytri(f, A::HermOrSym)
if A.uplo == 'U'
f(UpperTriangular(A.data))
f(uppertriangular(A.data))
else
f(LowerTriangular(A.data))
f(lowertriangular(A.data))
end
end

function applytri(f, A::HermOrSym, B::HermOrSym)
if A.uplo == B.uplo == 'U'
f(UpperTriangular(A.data), UpperTriangular(B.data))
f(uppertriangular(A.data), uppertriangular(B.data))
elseif A.uplo == B.uplo == 'L'
f(LowerTriangular(A.data), LowerTriangular(B.data))
f(lowertriangular(A.data), lowertriangular(B.data))
elseif A.uplo == 'U'
f(UpperTriangular(A.data), UpperTriangular(_conjugation(B)(B.data)))
f(uppertriangular(A.data), uppertriangular(_conjugation(B)(B.data)))
else # A.uplo == 'L'
f(UpperTriangular(_conjugation(A)(A.data)), UpperTriangular(B.data))
f(uppertriangular(_conjugation(A)(A.data)), uppertriangular(B.data))
end
end
parentof_applytri(f, args...) = applytri(parent f, args...)
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ const UpperOrUnitUpperTriangular{T,S} = Union{UpperTriangular{T,S}, UnitUpperTri
const LowerOrUnitLowerTriangular{T,S} = Union{LowerTriangular{T,S}, UnitLowerTriangular{T,S}}
const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}}

uppertriangular(M) = UpperTriangular(M)
lowertriangular(M) = LowerTriangular(M)

uppertriangular(U::UpperOrUnitUpperTriangular) = U
lowertriangular(U::LowerOrUnitLowerTriangular) = U

Base.dataids(A::UpperOrLowerTriangular) = Base.dataids(A.data)

imag(A::UpperTriangular) = UpperTriangular(imag(A.data))
imag(A::LowerTriangular) = LowerTriangular(imag(A.data))
imag(A::UpperTriangular{<:Any,<:StridedMaybeAdjOrTransMat}) = imag.(A)
Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,12 @@ end
@test c == Diagonal([2,2,2,2])
end

@testset "uppertriangular/lowertriangular" begin
D = Diagonal([1,2])
@test LinearAlgebra.uppertriangular(D) === D
@test LinearAlgebra.lowertriangular(D) === D
end

@testset "mul/div with an adjoint vector" begin
A = [1.0;;]
x = [1.0]
Expand Down
251 changes: 251 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -536,4 +536,255 @@ end
@test v * S isa Matrix
end

@testset "copyto! between matrix types" begin
dl, d, du = zeros(Int,4), [1:5;], zeros(Int,4)
d_ones = ones(Int,size(du))

@testset "from Diagonal" begin
D = Diagonal(d)
@testset "to Bidiagonal" begin
BU = Bidiagonal(similar(d, BigInt), similar(du, BigInt), :U)
BL = Bidiagonal(similar(d, BigInt), similar(dl, BigInt), :L)
for B in (BL, BU)
copyto!(B, D)
@test B == D
end

@testset "mismatched size" begin
for B in (BU, BL)
B .= 0
copyto!(B, Diagonal(Int[1]))
@test B[1,1] == 1
B[1,1] = 0
@test iszero(B)
end
end
end
@testset "to Tridiagonal" begin
T = Tridiagonal(similar(dl, BigInt), similar(d, BigInt), similar(du, BigInt))
copyto!(T, D)
@test T == D

@testset "mismatched size" begin
T .= 0
copyto!(T, Diagonal([1]))
@test T[1,1] == 1
T[1,1] = 0
@test iszero(T)
end
end
@testset "to SymTridiagonal" begin
for du2 in (similar(du, BigInt), similar(d, BigInt))
S = SymTridiagonal(similar(d), du2)
copyto!(S, D)
@test S == D
end

@testset "mismatched size" begin
S = SymTridiagonal(zero(d), zero(du))
copyto!(S, Diagonal([1]))
@test S[1,1] == 1
S[1,1] = 0
@test iszero(S)
end
end
end

@testset "from Bidiagonal" begin
BU = Bidiagonal(d, du, :U)
BUones = Bidiagonal(d, oneunit.(du), :U)
BL = Bidiagonal(d, dl, :L)
BLones = Bidiagonal(d, oneunit.(dl), :L)
@testset "to Diagonal" begin
D = Diagonal(zero(d))
for B in (BL, BU)
@test copyto!(D, B) == B
D .= 0
end
for B in (BLones, BUones)
errmsg = "cannot copy a Bidiagonal with a non-zero off-diagonal band to a Diagonal"
@test_throws errmsg copyto!(D, B)
@test iszero(D)
end

@testset "mismatched size" begin
for uplo in (:L, :U)
D .= 0
copyto!(D, Bidiagonal(Int[1], Int[], uplo))
@test D[1,1] == 1
D[1,1] = 0
@test iszero(D)
end
end
end
@testset "to Tridiagonal" begin
T = Tridiagonal(similar(dl, BigInt), similar(d, BigInt), similar(du, BigInt))
for B in (BL, BU, BLones, BUones)
copyto!(T, B)
@test T == B
end

@testset "mismatched size" begin
T = Tridiagonal(oneunit.(dl), zero(d), oneunit.(du))
for uplo in (:L, :U)
T .= 0
copyto!(T, Bidiagonal([1], Int[], uplo))
@test T[1,1] == 1
T[1,1] = 0
@test iszero(T)
end
end
end
@testset "to SymTridiagonal" begin
for du2 in (similar(du, BigInt), similar(d, BigInt))
S = SymTridiagonal(similar(d, BigInt), du2)
for B in (BL, BU)
copyto!(S, B)
@test S == B
end
errmsg = "cannot copy a non-symmetric Bidiagonal matrix to a SymTridiagonal"
@test_throws errmsg copyto!(S, BUones)
@test_throws errmsg copyto!(S, BLones)
end

@testset "mismatched size" begin
S = SymTridiagonal(zero(d), zero(du))
for uplo in (:L, :U)
copyto!(S, Bidiagonal([1], Int[], uplo))
@test S[1,1] == 1
S[1,1] = 0
@test iszero(S)
end
end
end
end

@testset "from Tridiagonal" begin
T = Tridiagonal(dl, d, du)
TU = Tridiagonal(dl, d, d_ones)
TL = Tridiagonal(d_ones, d, dl)
@testset "to Diagonal" begin
D = Diagonal(zero(d))
@test copyto!(D, T) == Diagonal(d)
errmsg = "cannot copy a Tridiagonal with a non-zero off-diagonal band to a Diagonal"
D .= 0
@test_throws errmsg copyto!(D, TU)
@test iszero(D)
errmsg = "cannot copy a Tridiagonal with a non-zero off-diagonal band to a Diagonal"
@test_throws errmsg copyto!(D, TL)
@test iszero(D)

@testset "mismatched size" begin
D .= 0
copyto!(D, Tridiagonal(Int[], Int[1], Int[]))
@test D[1,1] == 1
D[1,1] = 0
@test iszero(D)
end
end
@testset "to Bidiagonal" begin
BU = Bidiagonal(zero(d), zero(du), :U)
BL = Bidiagonal(zero(d), zero(du), :L)
@test copyto!(BU, T) == Bidiagonal(d, du, :U)
@test copyto!(BL, T) == Bidiagonal(d, du, :L)

BU .= 0
BL .= 0
errmsg = "cannot copy a Tridiagonal with a non-zero superdiagonal to a Bidiagonal with uplo=:L"
@test_throws errmsg copyto!(BL, TU)
@test iszero(BL)
@test copyto!(BU, TU) == Bidiagonal(d, d_ones, :U)

BU .= 0
BL .= 0
@test copyto!(BL, TL) == Bidiagonal(d, d_ones, :L)
errmsg = "cannot copy a Tridiagonal with a non-zero subdiagonal to a Bidiagonal with uplo=:U"
@test_throws errmsg copyto!(BU, TL)
@test iszero(BU)

@testset "mismatched size" begin
for B in (BU, BL)
B .= 0
copyto!(B, Tridiagonal(Int[], Int[1], Int[]))
@test B[1,1] == 1
B[1,1] = 0
@test iszero(B)
end
end
end
end

@testset "from SymTridiagonal" begin
S2 = SymTridiagonal(d, ones(Int,size(d)))
for S in (SymTridiagonal(d, du), SymTridiagonal(d, zero(d)))
@testset "to Diagonal" begin
D = Diagonal(zero(d))
@test copyto!(D, S) == Diagonal(d)
D .= 0
errmsg = "cannot copy a SymTridiagonal with a non-zero off-diagonal band to a Diagonal"
@test_throws errmsg copyto!(D, S2)
@test iszero(D)

@testset "mismatched size" begin
D .= 0
copyto!(D, SymTridiagonal(Int[1], Int[]))
@test D[1,1] == 1
D[1,1] = 0
@test iszero(D)
end
end
@testset "to Bidiagonal" begin
BU = Bidiagonal(zero(d), zero(du), :U)
BL = Bidiagonal(zero(d), zero(du), :L)
@test copyto!(BU, S) == Bidiagonal(d, du, :U)
@test copyto!(BL, S) == Bidiagonal(d, du, :L)

BU .= 0
BL .= 0
errmsg = "cannot copy a SymTridiagonal with a non-zero off-diagonal band to a Bidiagonal"
@test_throws errmsg copyto!(BU, S2)
@test iszero(BU)
@test_throws errmsg copyto!(BL, S2)
@test iszero(BL)

@testset "mismatched size" begin
for B in (BU, BL)
B .= 0
copyto!(B, SymTridiagonal(Int[1], Int[]))
@test B[1,1] == 1
B[1,1] = 0
@test iszero(B)
end
end
end
end
end
end

@testset "BandIndex indexing" begin
for D in (Diagonal(1:3), Bidiagonal(1:3, 2:3, :U), Bidiagonal(1:3, 2:3, :L),
Tridiagonal(2:3, 1:3, 1:2), SymTridiagonal(1:3, 2:3))
M = Matrix(D)
for band in -size(D,1)+1:size(D,1)-1
for idx in 1:size(D,1)-abs(band)
@test D[BandIndex(band, idx)] == M[BandIndex(band, idx)]
end
end
@test_throws BoundsError D[BandIndex(size(D,1),1)]
end
end

@testset "Partly filled Hermitian and Diagonal algebra" begin
D = Diagonal([1,2])
for S in (Symmetric, Hermitian), uplo in (:U, :L)
M = Matrix{BigInt}(undef, 2, 2)
M[1,1] = M[2,2] = M[1+(uplo == :L), 1 + (uplo == :U)] = 3
H = S(M, uplo)
HM = Matrix(H)
@test H + D == D + H == HM + D
@test H - D == HM - D
@test D - H == D - HM
end
end

end # module TestSpecial
25 changes: 25 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,31 @@ end
@test Su - Sl == -(Sl - Su) == MSu - MSl
end
end
@testset "non-strided" begin
@testset "diagonal" begin
for ST1 in (Symmetric, Hermitian), uplo1 in (:L, :U)
m = ST1(Matrix{BigFloat}(undef,2,2), uplo1)
m.data[1,1] = 1
m.data[2,2] = 3
m.data[1+(uplo1==:L), 1+(uplo1==:U)] = 2
A = Array(m)
for ST2 in (Symmetric, Hermitian), uplo2 in (:L, :U)
id = ST2(I(2), uplo2)
@test m + id == id + m == A + id
end
end
end
@testset "unit triangular" begin
for ST1 in (Symmetric, Hermitian), uplo1 in (:L, :U)
H1 = ST1(UnitUpperTriangular(big.(rand(Int8,4,4))), uplo1)
M1 = Matrix(H1)
for ST2 in (Symmetric, Hermitian), uplo2 in (:L, :U)
H2 = ST2(UnitUpperTriangular(big.(rand(Int8,4,4))), uplo2)
@test H1 + H2 == M1 + Matrix(H2)
end
end
end
end
end

# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,
Expand Down
Loading

0 comments on commit d40fa57

Please sign in to comment.