Skip to content

Commit

Permalink
Test conj and contract with complex numbers on Reactant integration
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 21, 2025
1 parent 27079f9 commit 7b1be3c
Showing 1 changed file with 27 additions and 29 deletions.
56 changes: 27 additions & 29 deletions test/integration/Reactant_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,20 @@ using Reactant

# TODO test unary einsum
# TODO test scalar × tensor
@testset "conj" begin
A = Tensor(rand(ComplexF64, 2, 3), (:i, :j))
Are = adapt(ConcreteRArray, A)

C = conj(A)
Cre = @jit conj(Are)

@test Cre C
end

@testset "contract" begin
@testset "matrix multiplication" begin
A = Tensor(rand(2, 3), (:i, :j))
B = Tensor(rand(3, 4), (:j, :k))
@testset "matrix multiplication - eltype=$T" for T in [Float64, ComplexF64]
A = Tensor(rand(T, 2, 3), (:i, :j))
B = Tensor(rand(T, 3, 4), (:j, :k))
Are = adapt(ConcreteRArray, A)
Bre = adapt(ConcreteRArray, B)

Expand Down Expand Up @@ -39,9 +49,9 @@ using Reactant
end
end

@testset "inner product" begin
A = Tensor(rand(3, 4), (:i, :j))
B = Tensor(rand(4, 3), (:j, :i))
@testset "inner product - eltype=$T" for T in [Float64, ComplexF64]
A = Tensor(rand(T, 3, 4), (:i, :j))
B = Tensor(rand(T, 4, 3), (:j, :i))
C = contract(A, B)

Are = adapt(ConcreteRArray, A)
Expand All @@ -51,9 +61,9 @@ using Reactant
@test Cre C
end

@testset "outer product" begin
A = Tensor(rand(2, 2), (:i, :j))
B = Tensor(rand(2, 2), (:k, :l))
@testset "outer product - eltype=$T" for T in [Float64, ComplexF64]
A = Tensor(rand(T, 2, 2), (:i, :j))
B = Tensor(rand(T, 2, 2), (:k, :l))
C = contract(A, B)

Are = adapt(ConcreteRArray, A)
Expand All @@ -63,9 +73,9 @@ using Reactant
@test Cre C
end

@testset "manual" begin
A = Tensor(rand(2, 3, 4), (:i, :j, :k))
B = Tensor(rand(4, 5, 3), (:k, :l, :j))
@testset "manual - eltype=$T" for T in [Float64, ComplexF64]
A = Tensor(rand(T, 2, 3, 4), (:i, :j, :k))
B = Tensor(rand(T, 4, 5, 3), (:k, :l, :j))
Are = adapt(ConcreteRArray, A)
Bre = adapt(ConcreteRArray, B)

Expand All @@ -82,25 +92,13 @@ using Reactant
Cre = @jit f2(Are, Bre)

@test Cre C

@testset "Complex numbers" begin
A = Tensor(rand(Complex{Float64}, 2, 3, 4), (:i, :j, :k))
B = Tensor(rand(Complex{Float64}, 4, 5, 3), (:k, :l, :j))
Are = adapt(ConcreteRArray, A)
Bre = adapt(ConcreteRArray, B)

C = f1(A, B)
Cre = @jit f1(Are, Bre)

@test Cre C
end
end

@testset "multiple tensors" begin
A = Tensor(rand(2, 3, 4), (:i, :j, :k))
B = Tensor(rand(4, 5, 3), (:k, :l, :j))
C = Tensor(rand(5, 6, 2), (:l, :m, :i))
D = Tensor(rand(6, 7, 2), (:m, :n, :i))
@testset "multiple tensors - eltype=$T" for T in [Float64, ComplexF64]
A = Tensor(rand(T, 2, 3, 4), (:i, :j, :k))
B = Tensor(rand(T, 4, 5, 3), (:k, :l, :j))
C = Tensor(rand(T, 5, 6, 2), (:l, :m, :i))
D = Tensor(rand(T, 6, 7, 2), (:m, :n, :i))

Are = adapt(ConcreteRArray, A)
Bre = adapt(ConcreteRArray, B)
Expand Down

0 comments on commit 7b1be3c

Please sign in to comment.