Skip to content

Commit

Permalink
Improve error message in inplace transpose (#54669)
Browse files Browse the repository at this point in the history
(cherry picked from commit 9eb7a0c)
  • Loading branch information
jishnub authored and KristofferC committed Oct 9, 2024
1 parent 2d6e88f commit 36ff239
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
15 changes: 10 additions & 5 deletions stdlib/LinearAlgebra/src/transpose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,32 @@ julia> A
```
"""
adjoint!(B::AbstractMatrix, A::AbstractMatrix) = transpose_f!(adjoint, B, A)

@noinline function check_transpose_axes(axesA, axesB)
axesB == reverse(axesA) || throw(DimensionMismatch("axes of the destination are incompatible with that of the source"))
end

function transpose!(B::AbstractVector, A::AbstractMatrix)
axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose"))
check_transpose_axes((axes(B,1), axes(B,2)), axes(A))
copyto!(B, A)
end
function transpose!(B::AbstractMatrix, A::AbstractVector)
axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose"))
check_transpose_axes(axes(B), (axes(A,1), axes(A,2)))
copyto!(B, A)
end
function adjoint!(B::AbstractVector, A::AbstractMatrix)
axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose"))
check_transpose_axes((axes(B,1), axes(B,2)), axes(A))
ccopy!(B, A)
end
function adjoint!(B::AbstractMatrix, A::AbstractVector)
axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose"))
check_transpose_axes(axes(B), (axes(A,1), axes(A,2)))
ccopy!(B, A)
end

const transposebaselength=64
function transpose_f!(f, B::AbstractMatrix, A::AbstractMatrix)
inds = axes(A)
axes(B,1) == inds[2] && axes(B,2) == inds[1] || throw(DimensionMismatch(string(f)))
check_transpose_axes(axes(B), inds)

m, n = length(inds[1]), length(inds[2])
if m*n<=4*transposebaselength
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -671,4 +671,14 @@ end
@test sprint(Base.print_matrix, Adjoint(o)) == sprint(Base.print_matrix, OneHotVecOrMat((1,2), (1,4)))
end

@testset "error message in transpose" begin
v = zeros(2)
A = zeros(1,1)
B = zeros(2,3)
for (t1, t2) in Any[(A, v), (v, A), (A, B)]
@test_throws "axes of the destination are incompatible with that of the source" transpose!(t1, t2)
@test_throws "axes of the destination are incompatible with that of the source" adjoint!(t1, t2)
end
end

end # module TestAdjointTranspose

0 comments on commit 36ff239

Please sign in to comment.