diff --git a/stdlib/LinearAlgebra/src/transpose.jl b/stdlib/LinearAlgebra/src/transpose.jl index 9d70ac3add34b..dd3167c1ecddf 100644 --- a/stdlib/LinearAlgebra/src/transpose.jl +++ b/stdlib/LinearAlgebra/src/transpose.jl @@ -74,27 +74,32 @@ julia> A ``` """ adjoint!(B::AbstractMatrix, A::AbstractMatrix) = transpose_f!(adjoint, B, A) + +@noinline function check_transpose_axes(axesA, axesB) + axesB == reverse(axesA) || throw(DimensionMismatch("axes of the destination are incompatible with that of the source")) +end + function transpose!(B::AbstractVector, A::AbstractMatrix) - axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes((axes(B,1), axes(B,2)), axes(A)) copyto!(B, A) end function transpose!(B::AbstractMatrix, A::AbstractVector) - axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes(axes(B), (axes(A,1), axes(A,2))) copyto!(B, A) end function adjoint!(B::AbstractVector, A::AbstractMatrix) - axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes((axes(B,1), axes(B,2)), axes(A)) ccopy!(B, A) end function adjoint!(B::AbstractMatrix, A::AbstractVector) - axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose")) + check_transpose_axes(axes(B), (axes(A,1), axes(A,2))) ccopy!(B, A) end const transposebaselength=64 function transpose_f!(f, B::AbstractMatrix, A::AbstractMatrix) inds = axes(A) - axes(B,1) == inds[2] && axes(B,2) == inds[1] || throw(DimensionMismatch(string(f))) + check_transpose_axes(axes(B), inds) m, n = length(inds[1]), length(inds[2]) if m*n<=4*transposebaselength diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index 93d2f264b05fe..313427be22095 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -671,4 +671,14 @@ end @test sprint(Base.print_matrix, Adjoint(o)) == sprint(Base.print_matrix, OneHotVecOrMat((1,2), (1,4))) end +@testset "error message in transpose" begin + v = zeros(2) + A = zeros(1,1) + B = zeros(2,3) + for (t1, t2) in Any[(A, v), (v, A), (A, B)] + @test_throws "axes of the destination are incompatible with that of the source" transpose!(t1, t2) + @test_throws "axes of the destination are incompatible with that of the source" adjoint!(t1, t2) + end +end + end # module TestAdjointTranspose