Skip to content

Commit

Permalink
Fix evolve! function (#179)
Browse files Browse the repository at this point in the history
* Fix evolve_2site! function

* Add tests for evolve! and expect

* Fix test

* Format tests

* Chnage testset names
  • Loading branch information
jofrevalles authored and mofeing committed Sep 4, 2024
1 parent eb36c18 commit d31a184
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
11 changes: 10 additions & 1 deletion src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,21 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical=f
end,
)
replace!(
TensorNetwork(gate),
Quantum(gate),
map(zip(inputs(gate), contracting_inds)) do (site, contracting_index)
inds(gate; at=site) => contracting_index
end,
)

# replace output indices of the gate for gensym indices
output_inds = [gensym(:out) for _ in outputs(gate)]
replace!(
Quantum(gate),
map(zip(outputs(gate), output_inds)) do (site, out)
inds(gate; at=site) => out
end,
)

# reindex output of gate to match TN sitemap
for site in outputs(gate)
if inds(qtn; at=site) != inds(gate; at=site)
Expand Down
60 changes: 59 additions & 1 deletion test/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,5 +386,63 @@
@test isapprox(contract(TensorNetwork(qtn)), contract(TensorNetwork(adjoint_qtn)))
end

# TODO test `evolve!` methods
@testset "evolve!" begin
@testset "one site" begin
i = 2
mat = reshape(LinearAlgebra.I(2), 2, 2)
gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(i; dual=true)])

qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)])

@testset "canonical form" begin
canonized = canonize(qtn)

evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14)
@test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(canonized)))
@test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)])
@test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(qtn)))
end

@testset "arbitrary chain" begin
evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false)
@test length(tensors(evolved)) == 5
@test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)])
@test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(qtn)))
end
end

@testset "two sites" begin
i, j = 2, 3
mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2)
gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)])

qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)])

@testset "canonical form" begin
canonized = canonize(qtn)

evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14)
@test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(canonized)))
@test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)])
@test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(qtn)))
end

@testset "arbitrary chain" begin
evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false)
@test length(tensors(evolved)) == 5
@test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)])
@test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(qtn)))
end
end
end

@testset "expect" begin
i, j = 2, 3
mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2)
gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)])

qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)])

@test isapprox(expect(qtn, [gate]), norm(qtn)^2)
end
end

0 comments on commit d31a184

Please sign in to comment.