Skip to content

Commit

Permalink
Backport "Fix tr for Symmetric/Hermitian block matrices #55522" to v1…
Browse files Browse the repository at this point in the history
….11 (#55535)
  • Loading branch information
jishnub authored Aug 21, 2024
1 parent 8467772 commit 13d440d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
Base.copy(A::Transpose{<:Any,<:Hermitian}) =
Hermitian(copy(transpose(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U))

tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
tr(A::Hermitian) = real(tr(A.data))
tr(A::Symmetric{<:Number}) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
tr(A::Hermitian{<:Number}) = real(tr(A.data))

Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo))
Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo))
Expand Down
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -978,4 +978,15 @@ end
@test conj(H) == conj(Array(H))
end

@testset "tr for block matrices" begin
m = [1 2; 3 4]
for b in (m, m * (1 + im))
M = fill(b, 3, 3)
for ST in (Symmetric, Hermitian)
S = ST(M)
@test tr(S) == sum(diag(S))
end
end
end

end # module TestSymmetric

0 comments on commit 13d440d

Please sign in to comment.