From d31a1847600d3ed334b1ce7994dcc5bbc60535ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Wed, 7 Aug 2024 13:16:58 +0200 Subject: [PATCH] Fix `evolve!` function (#179) * Fix evolve_2site! function * Add tests for evolve! and expect * Fix test * Format tests * Chnage testset names --- src/Ansatz/Chain.jl | 11 ++++++++- test/Chain_test.jl | 60 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 14364e387..3a9cf74e7 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -598,12 +598,21 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical=f end, ) replace!( - TensorNetwork(gate), + Quantum(gate), map(zip(inputs(gate), contracting_inds)) do (site, contracting_index) inds(gate; at=site) => contracting_index end, ) + # replace output indices of the gate for gensym indices + output_inds = [gensym(:out) for _ in outputs(gate)] + replace!( + Quantum(gate), + map(zip(outputs(gate), output_inds)) do (site, out) + inds(gate; at=site) => out + end, + ) + # reindex output of gate to match TN sitemap for site in outputs(gate) if inds(qtn; at=site) != inds(gate; at=site) diff --git a/test/Chain_test.jl b/test/Chain_test.jl index a3b3d2949..ca18d71be 100644 --- a/test/Chain_test.jl +++ b/test/Chain_test.jl @@ -386,5 +386,63 @@ @test isapprox(contract(TensorNetwork(qtn)), contract(TensorNetwork(adjoint_qtn))) end - # TODO test `evolve!` methods + @testset "evolve!" begin + @testset "one site" begin + i = 2 + mat = reshape(LinearAlgebra.I(2), 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(i; dual=true)]) + + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @testset "canonical form" begin + canonized = canonize(qtn) + + evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) + @test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(canonized))) + @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) + @test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(qtn))) + end + + @testset "arbitrary chain" begin + evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false) + @test length(tensors(evolved)) == 5 + @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)]) + @test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(qtn))) + end + end + + @testset "two sites" begin + i, j = 2, 3 + mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) + + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @testset "canonical form" begin + canonized = canonize(qtn) + + evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) + @test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(canonized))) + @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) + @test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(qtn))) + end + + @testset "arbitrary chain" begin + evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false) + @test length(tensors(evolved)) == 5 + @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)]) + @test isapprox(contract(TensorNetwork(evolved)), contract(TensorNetwork(qtn))) + end + end + end + + @testset "expect" begin + i, j = 2, 3 + mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) + + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test isapprox(expect(qtn, [gate]), norm(qtn)^2) + end end