From 9d8f4ef6b8b4c1a2a913213d501556ad4460e3db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 9 Aug 2024 12:05:06 +0200 Subject: [PATCH 01/75] Prototype `MPS`, `MPO` --- src/Ansatz/MPO.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++++ src/Ansatz/MPS.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ src/Tenet.jl | 9 ++++++--- 3 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 src/Ansatz/MPO.jl create mode 100644 src/Ansatz/MPS.jl diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl new file mode 100644 index 000000000..8055c098d --- /dev/null +++ b/src/Ansatz/MPO.jl @@ -0,0 +1,47 @@ +struct MPO <: Ansatz + super::Quantum +end + +function MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) + @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" + @assert all(==(4) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 4 dimensions" + @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) + + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(3n - 1)] + + _tensors = map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i + n] + elseif dir == :l + symbols[2n + mod1(i - 1, n)] + elseif dir == :r + symbols[2n + mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) + + return MPO(Quantum(TensorNetwork(_tensors), sitemap)) +end + +boundary(::MPO) = Open() diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl new file mode 100644 index 000000000..ade4f7099 --- /dev/null +++ b/src/Ansatz/MPS.jl @@ -0,0 +1,44 @@ +struct MPS <: Ansatz + super::Quantum +end + +function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) + @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" + @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" + @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) + + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(2n)] + + _tensors = map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + + return MPS(Quantum(TensorNetwork(_tensors), sitemap)) +end + +boundary(::MPS) = Open() diff --git a/src/Tenet.jl b/src/Tenet.jl index 46e13c920..d2b4af140 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -33,9 +33,12 @@ export Product include("Ansatz/Dense.jl") export Dense -include("Ansatz/Chain.jl") -export Chain -export MPS, pMPS, MPO, pMPO +include("Ansatz/MPS.jl") +export MPS + +include("Ansatz/MPO.jl") +export MPO + export leftindex, rightindex, isleftcanonical, isrightcanonical export canonize_site, canonize_site!, truncate! export canonize, canonize!, mixed_canonize, mixed_canonize! From e48e2c6c0d253eb82513305bbaa0485d39c58bd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 3 Sep 2024 12:38:11 +0200 Subject: [PATCH 02/75] Introduce Canonical `Form` trait --- src/Ansatz/Ansatz.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 84fe9fffb..e6c0f123a 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -76,3 +76,12 @@ struct Open <: Boundary end struct Periodic <: Boundary end function boundary end + +abstract type Form end +struct NonCanonical <: Form end +struct MixedCanonical <: Form + orthogonality_center::Union{Site,Vector{Site}} +end +struct Canonical <: Form end + +function form end From 0382f5764c0982c7d8f3b2ea46fdaf0654db68da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Sep 2024 06:00:48 -0400 Subject: [PATCH 03/75] Implement `rand`, `adjoint`, `defaultorder`, `boundary`, `form` for `MPS`, `MPO` --- src/Ansatz/Chain.jl | 177 -------------------------------------------- src/Ansatz/MPO.jl | 90 +++++++++++++++------- src/Ansatz/MPS.jl | 90 +++++++++++++++------- 3 files changed, 129 insertions(+), 228 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index bc7332728..a68b7ef53 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -11,26 +11,6 @@ Base.copy(tn::Chain) = Chain(copy(Quantum(tn)), boundary(tn)) Base.similar(tn::Chain) = Chain(similar(Quantum(tn)), boundary(tn)) Base.zero(tn::Chain) = Chain(zero(Quantum(tn)), boundary(tn)) -boundary(tn::Chain) = tn.boundary - -MPS(arrays) = Chain(State(), Open(), arrays) -pMPS(arrays) = Chain(State(), Periodic(), arrays) -MPO(arrays) = Chain(Operator(), Open(), arrays) -pMPO(arrays) = Chain(Operator(), Periodic(), arrays) - -alias(tn::Chain) = alias(socket(tn), boundary(tn), tn) -alias(::State, ::Open, ::Chain) = "MPS" -alias(::State, ::Periodic, ::Chain) = "pMPS" -alias(::Operator, ::Open, ::Chain) = "MPO" -alias(::Operator, ::Periodic, ::Chain) = "pMPO" - -function Chain(tn::TensorNetwork, sites, args...; kwargs...) - return Chain(Quantum(tn, sites), args...; kwargs...) -end - -defaultorder(::Type{Chain}, ::State) = (:o, :l, :r) -defaultorder(::Type{Chain}, ::Operator) = (:o, :i, :l, :r) - function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, State())) @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" issetequal(order, defaultorder(Chain, State())) || @@ -60,45 +40,6 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, State())) - @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" - @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" - @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" - issetequal(order, defaultorder(Chain, State())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(2n)] - - _tensors = map(enumerate(arrays)) do (i, array) - _order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end - - inds = map(_order) do dir - if dir == :o - symbols[i] - elseif dir == :r - symbols[n + mod1(i, n)] - elseif dir == :l - symbols[n + mod1(i - 1, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end - - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - - return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) -end - function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, Operator())) @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" issetequal(order, defaultorder(Chain, Operator())) || @@ -131,48 +72,6 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, Operator())) - @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" - @assert all(==(4) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 4 dimensions" - @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" - issetequal(order, defaultorder(Chain, Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(3n - 1)] - - _tensors = map(enumerate(arrays)) do (i, array) - _order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end - - inds = map(_order) do dir - if dir == :o - symbols[i] - elseif dir == :i - symbols[i + n] - elseif dir == :l - symbols[2n + mod1(i - 1, n)] - elseif dir == :r - symbols[2n + mod1(i, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end - - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) - - return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) -end - function Base.convert(::Type{Chain}, qtn::Product) arrs::Vector{Array} = arrays(qtn) arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) @@ -206,82 +105,6 @@ function rightindex(::Open, tn::Chain, site::Site) end rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond=(site, rightsite(tn, site))) -Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain)) - -struct ChainSampler{B<:Boundary,S<:Socket,NT<:NamedTuple} <: Random.Sampler{Chain} - parameters::NT - - ChainSampler{B,S}(; kwargs...) where {B,S} = new{B,S,typeof(values(kwargs))}(values(kwargs)) -end - -function Base.rand(A::Type{<:Chain}, B::Type{<:Boundary}, S::Type{<:Socket}; kwargs...) - return rand(Random.default_rng(), A, B, S; kwargs...) -end - -function Base.rand(rng::AbstractRNG, ::Type{A}, ::Type{B}, ::Type{S}; kwargs...) where {A<:Chain,B<:Boundary,S<:Socket} - return rand(rng, ChainSampler{B,S}(; kwargs...), B, S) -end - -# TODO let choose the orthogonality center -function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open}, ::Type{State}) - n = sampler.parameters.n - χ = sampler.parameters.χ - p = get(sampler.parameters, :p, 2) - T = get(sampler.parameters, :eltype, Float64) - - arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i - χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 - χl = min(χ, p^(i - 1)) - χr = min(χ, p^i) - - # swap bond dims after mid and handle midpoint for odd-length MPS - (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) - end - - # orthogonalize by QR factorization - F = lq!(rand(rng, T, χl, p * χr)) - - reshape(Matrix(F.Q), χl, p, χr) - end - - # reshape boundary sites - arrays[1] = reshape(arrays[1], p, p) - arrays[n] = reshape(arrays[n], p, p) - - return Chain(State(), Open(), arrays; order=(:l, :o, :r)) -end - -# TODO different input/output physical dims -function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open}, ::Type{Operator}) - n = sampler.parameters.n - χ = sampler.parameters.χ - p = get(sampler.parameters, :p, 2) - T = get(sampler.parameters, :eltype, Float64) - - ip = op = p - - arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i - χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 - χl = min(χ, ip^(i - 1) * op^(i - 1)) - χr = min(χ, ip^i * op^i) - - # swap bond dims after mid and handle midpoint for odd-length MPS - (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) - end - - # orthogonalize by QR factorization - F = lq!(rand(rng, T, χl, ip * op * χr)) - reshape(Matrix(F.Q), χl, ip, op, χr) - end - - # reshape boundary sites - arrays[1] = reshape(arrays[1], p, p, min(χ, ip * op)) - arrays[n] = reshape(arrays[n], min(χ, ip * op), p, p) - - # TODO order might not be the best for performance - return Chain(Operator(), Open(), arrays; order=(:l, :i, :o, :r)) -end - # """ # Tenet.contract!(tn::Chain; between=(site1, site2), direction::Symbol = :left, delete_Λ = true) diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl index 8055c098d..d943df67d 100644 --- a/src/Ansatz/MPO.jl +++ b/src/Ansatz/MPO.jl @@ -1,47 +1,85 @@ +using Random + struct MPO <: Ansatz super::Quantum + form::Form end +defaultorder(::Type{MPO}) = (:o, :i, :l, :r) +boundary(::MPO) = Open() +form(tn::MPO) = tn.form + function MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" @assert all(==(4) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 4 dimensions" @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" - issetequal(order, defaultorder(Chain, Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) + issetequal(order, defaultorder(MPO)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPO)))")) n = length(arrays) gen = IndexCounter() symbols = [nextindex!(gen) for _ in 1:(3n - 1)] - _tensors = map(enumerate(arrays)) do (i, array) - _order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end - - inds = map(_order) do dir - if dir == :o - symbols[i] - elseif dir == :i - symbols[i + n] - elseif dir == :l - symbols[2n + mod1(i - 1, n)] - elseif dir == :r - symbols[2n + mod1(i, n)] + tn = TensorNetwork( + map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) else - throw(ArgumentError("Invalid direction: $dir")) + order end - end - Tensor(array, inds) - end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i + n] + elseif dir == :l + symbols[2n + mod1(i - 1, n)] + elseif dir == :r + symbols[2n + mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end, + ) sitemap = Dict(Site(i) => symbols[i] for i in 1:n) merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) + qtn = Quantum(tn, sitemap) - return MPO(Quantum(TensorNetwork(_tensors), sitemap)) + return MPO(qtn, NonCanonical()) end -boundary(::MPO) = Open() +Base.adjoint(tn::MPO) = MPO(adjoint(Quantum(tn)), form(tn)) + +# TODO different input/output physical dims +# TODO let choose the orthogonality center +function Base.rand(rng::Random.AbstractRNG, ::Type{MPO}, n, χ; eltype=Float64, physical_dim=2) + T = eltype + ip = op = physical_dim + + arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 + χl = min(χ, ip^(i - 1) * op^(i - 1)) + χr = min(χ, ip^i * op^i) + + # swap bond dims after mid and handle midpoint for odd-length MPS + (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) + end + + # orthogonalize by QR factorization + F = lq!(rand(rng, T, χl, ip * op * χr)) + reshape(Matrix(F.Q), χl, ip, op, χr) + end + + # reshape boundary sites + arrays[1] = reshape(arrays[1], p, p, min(χ, ip * op)) + arrays[n] = reshape(arrays[n], min(χ, ip * op), p, p) + + # TODO order might not be the best for performance + return MPO(arrays; order=(:l, :i, :o, :r)) +end diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index ade4f7099..66f35f645 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -1,44 +1,84 @@ -struct MPS <: Ansatz +using Random + +abstract type AbstractMPS <: Ansatz end + +struct MPS <: AbstractMPS super::Quantum + form::Form end +defaultorder(::Type{MPS}) = (:o, :l, :r) +boundary(::MPS) = Open() +form(tn::MPS) = tn.form + function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" - issetequal(order, defaultorder(Chain, State())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) + issetequal(order, defaultorder(MPS)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))")) n = length(arrays) gen = IndexCounter() symbols = [nextindex!(gen) for _ in 1:(2n)] - _tensors = map(enumerate(arrays)) do (i, array) - _order = if i == 1 - filter(x -> x != :l, order) - elseif i == n - filter(x -> x != :r, order) - else - order - end - - inds = map(_order) do dir - if dir == :o - symbols[i] - elseif dir == :r - symbols[n + mod1(i, n)] - elseif dir == :l - symbols[n + mod1(i - 1, n)] + tn = TensorNetwork( + map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) else - throw(ArgumentError("Invalid direction: $dir")) + order end - end - Tensor(array, inds) - end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end, + ) sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + qtn = Quantum(tn, sitemap) - return MPS(Quantum(TensorNetwork(_tensors), sitemap)) + return MPS(qtn, NonCanonical()) end -boundary(::MPS) = Open() +Base.adjoint(tn::MPS) = MPS(adjoint(Quantum(tn)), form(tn)) + +# TODO different input/output physical dims +# TODO let choose the orthogonality center +function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}, n, χ; eltype=Float64, physical_dim=2) + p = physical_dim + T = eltype + + arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 + χl = min(χ, p^(i - 1)) + χr = min(χ, p^i) + + # swap bond dims after mid and handle midpoint for odd-length MPS + (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) + end + + # orthogonalize by QR factorization + F = lq!(rand(rng, T, χl, p * χr)) + + reshape(Matrix(F.Q), χl, p, χr) + end + + # reshape boundary sites + arrays[1] = reshape(arrays[1], p, p) + arrays[n] = reshape(arrays[n], p, p) + + return MPS(arrays; order=(:l, :o, :r)) +end From d29ff7d345e6167df4be1a4dbcde5b543fe2634e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 12 Sep 2024 06:14:37 -0400 Subject: [PATCH 04/75] Implement conversion from `Product` to `MPS`, `MPO` --- src/Ansatz/Chain.jl | 11 ----------- src/Ansatz/MPO.jl | 13 +++++++++++++ src/Ansatz/MPS.jl | 13 +++++++++++++ 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index a68b7ef53..05af06dd4 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -72,17 +72,6 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Base.convert(::Type{Chain}, qtn::Product) - arrs::Vector{Array} = arrays(qtn) - arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) - arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) - map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr - reshape(arr, size(arr)..., 1, 1) - end - - return Chain(socket(qtn), Open(), arrs) -end - leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) function leftsite(::Open, tn::Chain, site::Site) return id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1; dual=isdual(site)) : nothing diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl index d943df67d..629f59d51 100644 --- a/src/Ansatz/MPO.jl +++ b/src/Ansatz/MPO.jl @@ -54,6 +54,19 @@ function MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) return MPO(qtn, NonCanonical()) end +function Base.convert(::Type{MPO}, tn::Product) + @assert socket(tn) == Operator() + + arrs::Vector{Array} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return MPO(arrs) +end + Base.adjoint(tn::MPO) = MPO(adjoint(Quantum(tn)), form(tn)) # TODO different input/output physical dims diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 66f35f645..8ceb5860c 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -53,6 +53,19 @@ function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) return MPS(qtn, NonCanonical()) end +function Base.convert(::Type{MPS}, tn::Product) + @assert socket(tn) == State() + + arrs::Vector{Array} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return MPS(arrs) +end + Base.adjoint(tn::MPS) = MPS(adjoint(Quantum(tn)), form(tn)) # TODO different input/output physical dims From 0f0e5dd746ffe075b642f9b294f71e18fac054ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 14 Sep 2024 16:58:23 -0400 Subject: [PATCH 05/75] Implement `hassite` method and alias to `in` --- src/Quantum.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Quantum.jl b/src/Quantum.jl index c5298ef64..f3f459d65 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -291,6 +291,9 @@ function rmsite!(tn::AbstractQuantum, site) return delete!(tn.sites, site) end +hassite(tn::AbstractQuantum, site) = haskey(Quantum(tn).sites, site) +Base.in(site::Site, tn::AbstractQuantum) = hassite(tn, site) + @kwmethod function sites(tn::AbstractQuantum; set) tn = Quantum(tn) if set === :all From 555440aa1eb4ad6a8758f8a710a3a651af0c10dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 15 Sep 2024 17:53:47 -0400 Subject: [PATCH 06/75] Refactor `Ansatz` into a concrete type --- Project.toml | 2 ++ src/Ansatz/Ansatz.jl | 61 +++++++++++++++++++++++++------------------- src/Tenet.jl | 1 + 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index c3f02be2d..16910bef2 100644 --- a/Project.toml +++ b/Project.toml @@ -8,8 +8,10 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KeywordDispatch = "5888135b-5456-5c80-a1b6-c91ef8180460" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index e6c0f123a..6d915d2b3 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -1,33 +1,51 @@ +using KeywordDispatch using LinearAlgebra +using Graphs +using MetaGraphsNext + +abstract type AbstractAnsatz <: AbstractQuantum end """ Ansatz -[`AbstractQuantum`](@ref) Tensor Network with a predefined structure. - -# Notes - - - Any subtype must define `super::Quantum` field or specialize the `Quantum` method. +[`AbstractQuantum`](@ref) Tensor Network with a preserving structure. """ -abstract type Ansatz <: AbstractQuantum end +struct Ansatz <: AbstractAnsatz + tn::Quantum + lattice::MetaGraph + + function Ansatz(tn, lattice) + if !issetequal(site(tn), labels(lattice)) + throw(ArgumentError("Sites of the tensor network and the lattice must be equal")) + end + return new(tn, lattice) + end +end + +Ansatz(tn::Ansatz) = tn +Quantum(tn::AbstractAnsatz) = Ansatz(tn).tn +lattice(tn::AbstractAnsatz) = Ansatz(tn).lattice -# TODO maybe we need to change this? -Quantum(@nospecialize tn::Ansatz) = tn.super +function Base.isapprox(a::AbstractAnsatz, b::AbstractAnsatz; kwargs...) + return ==(latice.((a, b))...) && isapprox(Quantum(a), Quantum(b); kwargs...) +end -Base.:(==)(a::Ansatz, b::Ansatz) = Quantum(a) == Quantum(b) -Base.isapprox(a::Ansatz, b::Ansatz; kwargs...) = isapprox(Quantum(a), Quantum(b); kwargs...) +function neighbors(tn::AbstractAnsatz, site::Site) + # TODO + # return neighbors(lattice(tn), site) +end -alias(::A) where {A} = string(A) -function Base.summary(io::IO, tn::A) where {A<:Ansatz} - return print(io, "$(alias(tn)) (inputs=$(nsites(tn; set=:inputs)), outputs=$(nsites(tn; set=:outputs)))") +function isneighbor(tn::AbstractAnsatz, a::Site, b::Site) + # TODO + # return isneighbor(lattice(tn), a, b) end -Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn) -@kwmethod function inds(tn::Ansatz; bond) +@kwmethod function inds(tn::AbstractAnsatz; bond) (site1, site2) = bond @assert site1 ∈ sites(tn) "Site $site1 not found" @assert site2 ∈ sites(tn) "Site $site2 not found" @assert site1 != site2 "Sites must be different" + @assert isneighbor(tn, site1, site2) "Sites must be neighbors" tensor1 = tensors(tn; at=site1) tensor2 = tensors(tn; at=site2) @@ -36,11 +54,12 @@ Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn) return only(inds(tensor1) ∩ inds(tensor2)) end -@kwmethod function Tenet.tensors(tn::Ansatz; between) +@kwmethod function Tenet.tensors(tn::AbstractAnsatz; between) (site1, site2) = between @assert site1 ∈ sites(tn) "Site $site1 not found" @assert site2 ∈ sites(tn) "Site $site2 not found" @assert site1 != site2 "Sites must be different" + @assert isneighbor(tn, site1, site2) "Sites must be neighbors" tensor1 = tensors(tn; at=site1) tensor2 = tensors(tn; at=site2) @@ -60,16 +79,6 @@ function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) return print(io, "Can't access the spectrum on bond $(e.bond)") end -function LinearAlgebra.norm(ψ::Ansatz, p::Real=2; kwargs...) - p == 2 || throw(ArgumentError("only L2-norm is implemented yet")) - - return LinearAlgebra.norm2(ψ; kwargs...) -end - -function LinearAlgebra.norm2(ψ::Ansatz; kwargs...) - return abs(sqrt(only(contract(merge(TensorNetwork(ψ), TensorNetwork(ψ')); kwargs...)))) -end - # Traits abstract type Boundary end struct Open <: Boundary end diff --git a/src/Tenet.jl b/src/Tenet.jl index d2b4af140..244ed981f 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -26,6 +26,7 @@ include("Ansatz/Ansatz.jl") export Ansatz export socket, Scalar, State, Operator export boundary, Open, Periodic +export form include("Ansatz/Product.jl") export Product From 425df9efe2ee5daf1b9d6085ccca72f9797fea7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 00:46:35 -0400 Subject: [PATCH 07/75] Fix typos in `Ansatz` --- src/Ansatz/Ansatz.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 6d915d2b3..9a7657dae 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -15,7 +15,7 @@ struct Ansatz <: AbstractAnsatz lattice::MetaGraph function Ansatz(tn, lattice) - if !issetequal(site(tn), labels(lattice)) + if !issetequal(sites(tn), labels(lattice)) throw(ArgumentError("Sites of the tensor network and the lattice must be equal")) end return new(tn, lattice) @@ -30,7 +30,7 @@ function Base.isapprox(a::AbstractAnsatz, b::AbstractAnsatz; kwargs...) return ==(latice.((a, b))...) && isapprox(Quantum(a), Quantum(b); kwargs...) end -function neighbors(tn::AbstractAnsatz, site::Site) +function Graphs.neighbors(tn::AbstractAnsatz, site::Site) # TODO # return neighbors(lattice(tn), site) end From d2c8ce4a65254c433d23f95bc8a68655e491c7f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 00:47:38 -0400 Subject: [PATCH 08/75] Use `Graphs.neighbors` --- src/TensorNetwork.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index fbdddc81e..8778de50b 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -6,6 +6,7 @@ using LinearAlgebra using ScopedValues using Serialization using KeywordDispatch +using Graphs mutable struct CachedField{T} isvalid::Bool @@ -475,7 +476,7 @@ Base.merge!(self::TensorNetwork, other::TensorNetwork) = append!(self, tensors(o Base.merge!(self::TensorNetwork, others::TensorNetwork...) = foldl(merge!, others; init=self) Base.merge(self::AbstractTensorNetwork, others::AbstractTensorNetwork...) = merge!(copy(self), others...) -function neighbors(tn::AbstractTensorNetwork, tensor::Tensor; open::Bool=true) +function Graphs.neighbors(tn::AbstractTensorNetwork, tensor::Tensor; open::Bool=true) @assert tensor ∈ tn "Tensor not found in TensorNetwork" tensors = mapreduce(∪, inds(tensor)) do index Tenet.tensors(tn; intersects=index) @@ -484,7 +485,7 @@ function neighbors(tn::AbstractTensorNetwork, tensor::Tensor; open::Bool=true) return tensors end -function neighbors(tn::AbstractTensorNetwork, i::Symbol; open::Bool=true) +function Graphs.neighbors(tn::AbstractTensorNetwork, i::Symbol; open::Bool=true) @assert i ∈ tn "Index $i not found in TensorNetwork" tensors = mapreduce(inds, ∪, Tenet.tensors(tn; intersects=i)) # open && filter!(x -> x !== i, tensors) From e4928462942e58da7a3ba19df4707f3af4571d26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 00:47:55 -0400 Subject: [PATCH 09/75] Refactor `Product` on top of new `Ansatz` type --- src/Ansatz/Product.jl | 49 +++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index ab77d3be6..fd6c07d9a 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -1,55 +1,54 @@ using LinearAlgebra +using Graphs +using MetaGraphsNext -struct Product <: Ansatz - super::Quantum +struct Product <: AbstractAnsatz + tn::Ansatz end -Base.copy(x::Product) = Product(copy(Quantum(x))) +Ansatz(tn::Product) = tn.tn -Base.similar(x::Product) = Product(similar(Quantum(x))) -Base.zero(x::Product) = Product(zero(Quantum(x))) +Base.copy(x::Product) = Product(copy(Ansatz(x))) -function Product(tn::TensorNetwork, sites) - @assert isempty(inds(tn; set=:inner)) "Product ansatz must not have inner indices" - return Product(Quantum(tn, sites)) -end - -Product(arrays::Vector{<:AbstractVector}) = Product(State(), Open(), arrays) -Product(arrays::Vector{<:AbstractMatrix}) = Product(Operator(), Open(), arrays) +Base.similar(x::Product) = Product(similar(Ansatz(x))) +Base.zero(x::Product) = Product(zero(Ansatz(x))) -function Product(::State, ::Open, arrays) +function Product(arrays::Vector{<:AbstractVector}) + n = length(arrays) gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:length(arrays)] + symbols = [nextindex!(gen) for _ in 1:n] _tensors = map(enumerate(arrays)) do (i, array) Tensor(array, [symbols[i]]) end - sitemap = Dict(Site(i) => symbols[i] for i in 1:length(arrays)) - - return Product(TensorNetwork(_tensors), sitemap) + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + qtn = Quantum(TensorNetwork(_tensors), sitemap) + lattice = MetaGraph(Graph(n), Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], Pair{Tuple{Site,Site},Nothing}[]) + ansatz = Ansatz(qtn, lattice) + return Product(ansatz) end -function Product(::Operator, ::Open, arrays) +function Product(arrays::Vector{<:AbstractMatrix}) n = length(arrays) gen = IndexCounter() symbols = [nextindex!(gen) for _ in 1:(2 * length(arrays))] _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i + n], symbols[i]]) + Tensor(array, [symbols[i + n], symbols[i]], []) end sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) - - return Product(TensorNetwork(_tensors), sitemap) + qtn = Quantum(TensorNetwork(_tensors), sitemap) + lattice = MetaGraph(Graph(n), Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], Pair{Tuple{Site,Site},Nothing}[]) + ansatz = Ansatz(qtn, lattice) + return Product(ansatz) end function Base.zeros(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) - return Product(State(), Open(), fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) + return Product(fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) end function Base.ones(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) - return Product( - State(), Open(), fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n) - ) + return Product(fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n)) end LinearAlgebra.norm(tn::Product, p::Real=2) = LinearAlgebra.norm(socket(tn), tn, p) From d334d13c7e946e20cff3b1619366abb21d3cb47b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 01:02:25 -0400 Subject: [PATCH 10/75] Implement `copy`, `similar`, `zero` for `Ansatz` --- src/Ansatz/Ansatz.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 9a7657dae..73f7dfd70 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -24,6 +24,11 @@ end Ansatz(tn::Ansatz) = tn Quantum(tn::AbstractAnsatz) = Ansatz(tn).tn + +Base.copy(tn::Ansatz) = Ansatz(copy(Quantum(tn)), copy(lattice(tn))) +Base.similar(tn::Ansatz) = Ansatz(similar(Quantum(tn)), copy(lattice(tn))) +Base.zero(tn::Ansatz) = Ansatz(zero(Quantum(tn)), copy(lattice(tn))) + lattice(tn::AbstractAnsatz) = Ansatz(tn).lattice function Base.isapprox(a::AbstractAnsatz, b::AbstractAnsatz; kwargs...) From 86fdd7a709d127049d595f28e188a616462a8b8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 01:02:44 -0400 Subject: [PATCH 11/75] Format code --- src/Ansatz/Product.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index fd6c07d9a..8c86f570c 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -9,7 +9,6 @@ end Ansatz(tn::Product) = tn.tn Base.copy(x::Product) = Product(copy(Ansatz(x))) - Base.similar(x::Product) = Product(similar(Ansatz(x))) Base.zero(x::Product) = Product(zero(Ansatz(x))) From 1937f804048f6ad52a7651d3d6a4de205c2f635a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 01:03:29 -0400 Subject: [PATCH 12/75] Refactor `Dense` on top of new `Ansatz` type --- src/Ansatz/Dense.jl | 45 ++++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/src/Ansatz/Dense.jl b/src/Ansatz/Dense.jl index a0fa5355f..0d8176d48 100644 --- a/src/Ansatz/Dense.jl +++ b/src/Ansatz/Dense.jl @@ -1,15 +1,24 @@ -struct Dense <: Ansatz - super::Quantum +using Combinatorics + +struct Dense <: AbstractAnsatz + tn::Ansatz end +Ansatz(tn::Dense) = tn.tn + +Base.copy(qtn::Dense) = Dense(copy(Ansatz(qtn))) +Base.similar(qtn::Dense) = Dense(similar(Ansatz(qtn))) +Base.zero(qtn::Dense) = Dense(zero(Ansatz(qtn))) + function Dense(::State, array::AbstractArray; sites=Site.(1:ndims(array))) - @assert ndims(array) > 0 + n = ndims(array) + @assert n > 0 @assert all(>(1), size(array)) gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:ndims(array)] + symbols = [nextindex!(gen) for _ in 1:n] sitemap = Dict{Site,Symbol}( - map(sites, 1:ndims(array)) do site, i + map(sites, 1:n) do site, i site => symbols[i] end, ) @@ -18,23 +27,33 @@ function Dense(::State, array::AbstractArray; sites=Site.(1:ndims(array))) tn = TensorNetwork([tensor]) qtn = Quantum(tn, sitemap) - return Dense(qtn) + lattice = MetaGraph( + complete_graph(n), + Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], + Pair{Tuple{Site,Site},Nothing}[(Site(i), Site(j)) => nothing for (i, j) in combinations(1:n, 2)], + ) + ansatz = Ansatz(qtn, lattice) + return Dense(ansatz) end function Dense(::Operator, array::AbstractArray; sites) - @assert ndims(array) > 0 + n = ndims(array) + @assert n > 0 @assert all(>(1), size(array)) - @assert length(sites) == ndims(array) + @assert length(sites) == n gen = IndexCounter() - tensor_inds = [nextindex!(gen) for _ in 1:ndims(array)] + tensor_inds = [nextindex!(gen) for _ in 1:n] tensor = Tensor(array, tensor_inds) tn = TensorNetwork([tensor]) sitemap = Dict{Site,Symbol}(map(splat(Pair), zip(sites, tensor_inds))) qtn = Quantum(tn, sitemap) - - return Dense(qtn) + lattice = MetaGraph( + complete_graph(n), + Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], + Pair{Tuple{Site,Site},Nothing}[(Site(i), Site(j)) => nothing for (i, j) in combinations(1:n, 2)], + ) + ansatz = Ansatz(qtn, lattice) + return Dense(ansatz) end - -Base.copy(qtn::Dense) = Dense(copy(Quantum(qtn))) From ba817a27c6e752448ca7b6f1940da39186d8692c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 15:28:25 -0400 Subject: [PATCH 13/75] Refactor `MPS`, `MPO` on top of new `Ansatz` type --- src/Ansatz/MPO.jl | 20 +++++++++++++++----- src/Ansatz/MPS.jl | 18 +++++++++++++----- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl index 629f59d51..04709d146 100644 --- a/src/Ansatz/MPO.jl +++ b/src/Ansatz/MPO.jl @@ -1,10 +1,18 @@ using Random -struct MPO <: Ansatz - super::Quantum +abstract type AbstractMPO <: AbstractAnsatz end + +struct MPO <: AbstractAnsatz + tn::Ansatz form::Form end +Ansatz(tn::MPO) = tn.tn + +Base.copy(x::MPO) = MPO(copy(Ansatz(x)), form(x)) +Base.similar(x::MPO) = MPO(similar(Ansatz(x)), form(x)) +Base.zero(x::MPO) = MPO(zero(Ansatz(x)), form(x)) + defaultorder(::Type{MPO}) = (:o, :i, :l, :r) boundary(::MPO) = Open() form(tn::MPO) = tn.form @@ -50,8 +58,10 @@ function MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) sitemap = Dict(Site(i) => symbols[i] for i in 1:n) merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) qtn = Quantum(tn, sitemap) - - return MPO(qtn, NonCanonical()) + graph = path_graph(n) + lattice = MetaGraph(graph, Site.(vertices(graph)) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return MPO(ansatz, NonCanonical()) end function Base.convert(::Type{MPO}, tn::Product) @@ -67,7 +77,7 @@ function Base.convert(::Type{MPO}, tn::Product) return MPO(arrs) end -Base.adjoint(tn::MPO) = MPO(adjoint(Quantum(tn)), form(tn)) +Base.adjoint(tn::MPO) = MPO(adjoint(Ansatz(tn)), form(tn)) # TODO different input/output physical dims # TODO let choose the orthogonality center diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 8ceb5860c..d370581e7 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -1,12 +1,18 @@ using Random -abstract type AbstractMPS <: Ansatz end +abstract type AbstractMPS <: AbstractAnsatz end struct MPS <: AbstractMPS - super::Quantum + tn::Ansatz form::Form end +Ansatz(tn::MPS) = tn.tn + +Base.copy(x::MPS) = MPS(copy(Ansatz(x)), form(x)) +Base.similar(x::MPS) = MPS(similar(Ansatz(x)), form(x)) +Base.zero(x::MPS) = MPS(zero(Ansatz(x)), form(x)) + defaultorder(::Type{MPS}) = (:o, :l, :r) boundary(::MPS) = Open() form(tn::MPS) = tn.form @@ -49,8 +55,10 @@ function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) sitemap = Dict(Site(i) => symbols[i] for i in 1:n) qtn = Quantum(tn, sitemap) - - return MPS(qtn, NonCanonical()) + graph = path_graph(n) + lattice = MetaGraph(graph, Site.(vertices(graph)) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return MPS(ansatz, NonCanonical()) end function Base.convert(::Type{MPS}, tn::Product) @@ -66,7 +74,7 @@ function Base.convert(::Type{MPS}, tn::Product) return MPS(arrs) end -Base.adjoint(tn::MPS) = MPS(adjoint(Quantum(tn)), form(tn)) +Base.adjoint(tn::MPS) = MPS(adjoint(Ansatz(tn)), form(tn)) # TODO different input/output physical dims # TODO let choose the orthogonality center From 3370244b937c3d583b6c115631ee4d22c4e353e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 15:28:51 -0400 Subject: [PATCH 14/75] Relax `sites` condition on `Ansatz` construction for `lanes` --- src/Ansatz/Ansatz.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 73f7dfd70..6bfbca555 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -15,7 +15,7 @@ struct Ansatz <: AbstractAnsatz lattice::MetaGraph function Ansatz(tn, lattice) - if !issetequal(sites(tn), labels(lattice)) + if !issetequal(lanes(tn), labels(lattice)) throw(ArgumentError("Sites of the tensor network and the lattice must be equal")) end return new(tn, lattice) From 8c05563ac8aaf2e3a87d917e428d410deab8627a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 15:48:19 -0400 Subject: [PATCH 15/75] Implement `PEPS`, `PEPO` types --- src/Ansatz/Grid.jl | 180 --------------------------------------------- src/Ansatz/PEPO.jl | 102 +++++++++++++++++++++++++ src/Ansatz/PEPS.jl | 94 +++++++++++++++++++++++ src/Tenet.jl | 10 ++- 4 files changed, 202 insertions(+), 184 deletions(-) delete mode 100644 src/Ansatz/Grid.jl create mode 100644 src/Ansatz/PEPO.jl create mode 100644 src/Ansatz/PEPS.jl diff --git a/src/Ansatz/Grid.jl b/src/Ansatz/Grid.jl deleted file mode 100644 index ef59f3e84..000000000 --- a/src/Ansatz/Grid.jl +++ /dev/null @@ -1,180 +0,0 @@ -struct Grid <: Ansatz - super::Quantum - boundary::Boundary -end - -Base.copy(tn::Grid) = Grid(copy(Quantum(tn)), boundary(tn)) - -boundary(tn::Grid) = tn.boundary - -PEPS(arrays) = Grid(State(), Open(), arrays) -pPEPS(arrays) = Grid(State(), Periodic(), arrays) -PEPO(arrays) = Grid(Operator(), Open(), arrays) -pPEPO(arrays) = Grid(Operator(), Periodic(), arrays) - -alias(tn::Grid) = alias(socket(tn), boundary(tn), tn) -alias(::State, ::Open, ::Grid) = "PEPS" -alias(::State, ::Periodic, ::Grid) = "pPEPS" -alias(::Operator, ::Open, ::Grid) = "PEPO" -alias(::Operator, ::Periodic, ::Grid) = "pPEPO" - -function Grid(::State, ::Periodic, arrays::Matrix{<:AbstractArray}) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - - m, n = size(arrays) - gen = IndexCounter() - pinds = map(_ -> nextindex!(gen), arrays) - hvinds = map(_ -> nextindex!(gen), arrays) - vvinds = map(_ -> nextindex!(gen), arrays) - - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - pind = pinds[i, j] - up, down = hvinds[i, j], hvinds[mod1(i + 1, m), j] - left, right = vvinds[i, j], vvinds[i, mod1(j + 1, n)] - - # TODO customize order - Tensor(array, [pind, up, down, left, right]) - end - - sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) - - return Grid(Quantum(TensorNetwork(_tensors), sitemap), Periodic()) -end - -function Grid(::State, ::Open, arrays::Matrix{<:AbstractArray}) - m, n = size(arrays) - - predicate = all(eachindex(arrays)) do I - i, j = Tuple(I) - array = arrays[i, j] - - N = ndims(array) - 1 - (i == 1 || i == m) && (N -= 1) - (j == 1 || j == n) && (N -= 1) - - N > 0 - end - - if !predicate - throw(DimensionMismatch()) - end - - gen = IndexCounter() - pinds = map(_ -> nextindex!(gen), arrays) - vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] - hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] - - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - pind = pinds[i, j] - up = i == 1 ? missing : vvinds[i - 1, j] - down = i == m ? missing : vvinds[i, j] - left = j == 1 ? missing : hvinds[i, j - 1] - right = j == n ? missing : hvinds[i, j] - - # TODO customize order - Tensor(array, collect(skipmissing([pind, up, down, left, right]))) - end - - sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) - - return Grid(Quantum(TensorNetwork(_tensors), sitemap), Open()) -end - -function Grid(::Operator, ::Periodic, arrays::Matrix{<:AbstractArray}) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - - m, n = size(arrays) - gen = IndexCounter() - ipinds = map(_ -> nextindex!(gen), arrays) - opinds = map(_ -> nextindex!(gen), arrays) - hvinds = map(_ -> nextindex!(gen), arrays) - vvinds = map(_ -> nextindex!(gen), arrays) - - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - ipind, opind = ipinds[i, j], opinds[i, j] - up, down = hvinds[i, j], hvinds[mod1(i + 1, m), j] - left, right = vvinds[i, j], vvinds[i, mod1(j + 1, n)] - - # TODO customize order - Tensor(array, [ipind, opind, up, down, left, right]) - end - - sitemap = Dict( - flatten([ - (Site(i, j; dual=true) => ipinds[i, j] for i in 1:m, j in 1:n), - (Site(i, j) => opinds[i, j] for i in 1:m, j in 1:n), - ]), - ) - - return Grid(Quantum(TensorNetwork(_tensors), sitemap), Periodic()) -end - -function Grid(::Operator, ::Open, arrays::Matrix{<:AbstractArray}) - m, n = size(arrays) - - predicate = all(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - array = arrays[i, j] - - N = ndims(array) - 2 - (i == 1 || i == m) && (N -= 1) - (j == 1 || j == n) && (N -= 1) - - N > 0 - end - - if !predicate - throw(DimensionMismatch()) - end - - gen = IndexCounter() - ipinds = map(_ -> nextindex!(gen), arrays) - opinds = map(_ -> nextindex!(gen), arrays) - vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] - hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] - - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - ipind = ipinds[i, j] - opind = opinds[i, j] - up = i == 1 ? missing : vvinds[i - 1, j] - down = i == m ? missing : vvinds[i, j] - left = j == 1 ? missing : hvinds[i, j - 1] - right = j == n ? missing : hvinds[i, j] - - # TODO customize order - Tensor(array, collect(skipmissing([ipind, opind, up, down, left, right]))) - end - - sitemap = Dict( - flatten([ - (Site(i, j; dual=true) => ipinds[i, j] for i in 1:m, j in 1:n), - (Site(i, j) => opinds[i, j] for i in 1:m, j in 1:n), - ]), - ) - - return Grid(Quantum(TensorNetwork(_tensors), sitemap), Open()) -end - -function LinearAlgebra.transpose!(qtn::Grid) - old = Quantum(qtn).sites - new = Dict(Site(reverse(id(site)); dual=isdual(site)) => ind for (site, ind) in old) - - empty!(old) - merge!(old, new) - - return qtn -end - -Base.transpose(qtn::Grid) = LinearAlgebra.transpose!(copy(qtn)) diff --git a/src/Ansatz/PEPO.jl b/src/Ansatz/PEPO.jl new file mode 100644 index 000000000..313a1851c --- /dev/null +++ b/src/Ansatz/PEPO.jl @@ -0,0 +1,102 @@ +abstract type AbstractPEPO <: AbstractAnsatz end + +struct PEPO <: AbstractPEPO + tn::Ansatz + form::Form +end + +Ansatz(tn::PEPO) = tn.tn + +Base.copy(x::PEPO) = PEPO(copy(Ansatz(x)), form(x)) +Base.similar(x::PEPO) = PEPO(similar(Ansatz(x)), form(x)) +Base.zero(x::PEPO) = PEPO(zero(Ansatz(x)), form(x)) + +defaultorder(::Type{PEPO}) = (:o, :i, :l, :r, :u, :d) +boundary(::PEPO) = Open() +form(tn::PEPO) = tn.form + +# TODO periodic boundary conditions +# TODO non-square lattice +function PEPO(arrays::Matrix{<:AbstractArray}; order=defaultorder(PEPO)) + @assert ndims(arrays[1, 1]) == 4 "Array at (1,1) must have 4 dimensions" + @assert ndims(arrays[1, end]) == 4 "Array at (1,end) must have 4 dimensions" + @assert ndims(arrays[end, 1]) == 4 "Array at (end,1) must have 4 dimensions" + @assert ndims(arrays[end, end]) == 4 "Array at (end,end) must have 4 dimensions" + @assert all( + ==(5) ∘ ndims, + Iterators.flatten([ + arrays[1, 2:(end - 1)], arrays[end, 2:(end - 1)], arrays[2:(end - 1), 1], arrays[2:(end - 1), end] + ]), + ) "Arrays at boundaries must have 5 dimensions" + @assert all(==(6) ∘ ndims, arrays[2:(end - 1), 2:(end - 1)]) "Inner arrays must have 6 dimensions" + issetequal(order, defaultorder(PEPO)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(PEPO)))")) + + m, n = size(arrays) + + predicate = all(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + array = arrays[i, j] + + N = ndims(array) - 2 + (i == 1 || i == m) && (N -= 1) + (j == 1 || j == n) && (N -= 1) + + N > 0 + end + + if !predicate + throw(DimensionMismatch()) + end + + gen = IndexCounter() + ipinds = map(_ -> nextindex!(gen), arrays) + opinds = map(_ -> nextindex!(gen), arrays) + vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] + hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] + + _tensors = map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + ipind = ipinds[i, j] + opind = opinds[i, j] + up = i == 1 ? missing : vvinds[i - 1, j] + down = i == m ? missing : vvinds[i, j] + left = j == 1 ? missing : hvinds[i, j - 1] + right = j == n ? missing : hvinds[i, j] + + # TODO customize order + Tensor(array, collect(skipmissing([ipind, opind, up, down, left, right]))) + end + + sitemap = Dict( + flatten([ + (Site(i, j; dual=true) => ipinds[i, j] for i in 1:m, j in 1:n), + (Site(i, j) => opinds[i, j] for i in 1:m, j in 1:n), + ]), + ) + + qtn = Quatum(tn, sitemap) + graph = grid((m, n)) + # TODO fix this + lattice = MetaGraph(graph, Site.(vertices(graph)) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return PEPO(ansatz, NonCanonical()) +end + +function Base.convert(::Type{PEPO}, tn::Product) + @assert socket(tn) == State() + + # TODO fix this + arrs::Matrix{<:AbstractArray} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return PEPO(arrs) +end + +Base.adjoint(tn::PEPO) = PEPO(adjoint(Ansatz(tn)), form(tn)) diff --git a/src/Ansatz/PEPS.jl b/src/Ansatz/PEPS.jl new file mode 100644 index 000000000..1cb061d4d --- /dev/null +++ b/src/Ansatz/PEPS.jl @@ -0,0 +1,94 @@ +abstract type AbstractPEPS <: AbstractAnsatz end + +struct PEPS <: AbstractPEPS + tn::Ansatz + form::Form +end + +Ansatz(tn::PEPS) = tn.tn + +Base.copy(x::PEPS) = PEPS(copy(Ansatz(x)), form(x)) +Base.similar(x::PEPS) = PEPS(similar(Ansatz(x)), form(x)) +Base.zero(x::PEPS) = PEPS(zero(Ansatz(x)), form(x)) + +defaultorder(::Type{PEPS}) = (:o, :l, :r, :u, :d) +boundary(::PEPS) = Open() +form(tn::PEPS) = tn.form + +# TODO periodic boundary conditions +# TODO non-square lattice +function PEPS(arrays::Matrix{<:AbstractArray}; order=defaultorder(PEPS)) + @assert ndims(arrays[1, 1]) == 3 "Array at (1,1) must have 3 dimensions" + @assert ndims(arrays[1, end]) == 3 "Array at (1,end) must have 3 dimensions" + @assert ndims(arrays[end, 1]) == 3 "Array at (end,1) must have 3 dimensions" + @assert ndims(arrays[end, end]) == 3 "Array at (end,end) must have 3 dimensions" + @assert all( + ==(4) ∘ ndims, + Iterators.flatten([ + arrays[1, 2:(end - 1)], arrays[end, 2:(end - 1)], arrays[2:(end - 1), 1], arrays[2:(end - 1), end] + ]), + ) "Arrays at boundaries must have 4 dimensions" + @assert all(==(5) ∘ ndims, arrays[2:(end - 1), 2:(end - 1)]) "Inner arrays must have 5 dimensions" + issetequal(order, defaultorder(PEPS)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(PEPS)))")) + + m, n = size(arrays) + + predicate = all(eachindex(arrays)) do I + i, j = Tuple(I) + array = arrays[i, j] + + N = ndims(array) - 1 + (i == 1 || i == m) && (N -= 1) + (j == 1 || j == n) && (N -= 1) + + N > 0 + end + + if !predicate + throw(DimensionMismatch()) + end + + gen = IndexCounter() + pinds = map(_ -> nextindex!(gen), arrays) + vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] + hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] + + _tensors = map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + pind = pinds[i, j] + up = i == 1 ? missing : vvinds[i - 1, j] + down = i == m ? missing : vvinds[i, j] + left = j == 1 ? missing : hvinds[i, j - 1] + right = j == n ? missing : hvinds[i, j] + + # TODO customize order + Tensor(array, collect(skipmissing([pind, up, down, left, right]))) + end + + sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) + qtn = Quatum(tn, sitemap) + graph = grid((m, n)) + # TODO fix this + lattice = MetaGraph(graph, Site.(vertices(graph)) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + ansatz = Ansatz(qtn, lattice) + return PEPS(ansatz, NonCanonical()) +end + +function Base.convert(::Type{PEPS}, tn::Product) + @assert socket(tn) == State() + + # TODO fix this + arrs::Matrix{<:AbstractArray} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return PEPS(arrs) +end + +Base.adjoint(tn::PEPS) = PEPS(adjoint(Ansatz(tn)), form(tn)) diff --git a/src/Tenet.jl b/src/Tenet.jl index 244ed981f..9f67a3c7c 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -40,14 +40,16 @@ export MPS include("Ansatz/MPO.jl") export MPO +include("Ansatz/PEPS.jl") +export PEPS + +include("Ansatz/PEPO.jl") +export PEPO + export leftindex, rightindex, isleftcanonical, isrightcanonical export canonize_site, canonize_site!, truncate! export canonize, canonize!, mixed_canonize, mixed_canonize! -include("Ansatz/Grid.jl") -export Grid -export PEPS, pPEPS, PEPO, pPEPO - export evolve!, expect, overlap # reexports from EinExprs From c21db5ef84b0caf529b420f148f91819e7d8ab0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 16 Sep 2024 18:03:02 -0400 Subject: [PATCH 16/75] Remove some exports --- src/Tenet.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/Tenet.jl b/src/Tenet.jl index 9f67a3c7c..79d13304b 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -46,10 +46,7 @@ export PEPS include("Ansatz/PEPO.jl") export PEPO -export leftindex, rightindex, isleftcanonical, isrightcanonical -export canonize_site, canonize_site!, truncate! -export canonize, canonize!, mixed_canonize, mixed_canonize! - +export canonize, canonize!, mixed_canonize, mixed_canonize!, truncate! export evolve!, expect, overlap # reexports from EinExprs From c9d351f5fef144244e4587eeba13a5b464092b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Sep 2024 14:33:09 -0400 Subject: [PATCH 17/75] Move `Chain` code to `AbstractAnsatz` and `MPS` --- src/Ansatz/Ansatz.jl | 290 ++++++++++++++++++++-- src/Ansatz/Chain.jl | 561 ------------------------------------------- src/Ansatz/MPO.jl | 2 + src/Ansatz/MPS.jl | 205 +++++++++++++++- 4 files changed, 469 insertions(+), 589 deletions(-) delete mode 100644 src/Ansatz/Chain.jl diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 6bfbca555..f18cd2ff4 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -3,6 +3,32 @@ using LinearAlgebra using Graphs using MetaGraphsNext +# Traits +abstract type Boundary end +struct Open <: Boundary end +struct Periodic <: Boundary end + +function boundary end + +abstract type Form end +struct NonCanonical <: Form end +struct MixedCanonical <: Form + orthogonality_center::Union{Site,Vector{Site}} +end +struct Canonical <: Form end + +function form end + +struct MissingSchmidtCoefficientsException <: Base.Exception + bond::NTuple{2,Site} +end + +MissingSchmidtCoefficientsException(bond::Vector{<:Site}) = MissingSchmidtCoefficientsException(tuple(bond...)) + +function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) + return print(io, "Can't access the spectrum on bond $(e.bond)") +end + abstract type AbstractAnsatz <: AbstractQuantum end """ @@ -59,43 +85,255 @@ end return only(inds(tensor1) ∩ inds(tensor2)) end -@kwmethod function Tenet.tensors(tn::AbstractAnsatz; between) - (site1, site2) = between - @assert site1 ∈ sites(tn) "Site $site1 not found" - @assert site2 ∈ sites(tn) "Site $site2 not found" - @assert site1 != site2 "Sites must be different" - @assert isneighbor(tn, site1, site2) "Sites must be neighbors" +@kwmethod tensors(tn::AbstractAnsatz; bond) = tn[inds(tn; bond)] +@kwmethod function tensors(tn::AbstractAnsatz; between) + Base.depwarn( + "`tensors(tn; between)` is deprecated, use `tensors(tn; bond)` instead.", + ((Base.Core).Typeof(tensors)).name.mt.name, + ) + return tensors(tn; bond=between) +end - tensor1 = tensors(tn; at=site1) - tensor2 = tensors(tn; at=site2) +@kwmethod contract!(tn::AbstractAnsatz; bond) = contract!(tn, inds(tn; bond)) - isdisjoint(inds(tensor1), inds(tensor2)) && return nothing +canonize(tn::AbstractAnsatz, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...) +canonize_site(tn::AbstractAnsatz, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...) + +""" + truncate(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing) + +Like [`truncate!`](@ref), but returns a new tensor network instead of modifying the original one. +""" +truncate(tn::AbstractAnsatz, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) + +""" + truncate!(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing) + +Truncate the dimension of the virtual `bond`` of an [`Ansatz`](@ref) Tensor Network by keeping only the `maxdim` largest Schmidt coefficients or those larger than`threshold`. + +# Notes + + - Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used. + - The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`. +""" +function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing) + @assert isnothing(maxdim) ⊻ isnothing(threshold) "Either `threshold` or `maxdim` must be provided" + + spectrum = parent(tensors(tn; bond)) + vind = inds(tn; bond) + + maxdim = isnothing(maxdim) ? size(tn, vind) : maxdim + threshold = isnothing(threshold) ? 1e-16 : threshold - return tn[only(inds(tensor1) ∩ inds(tensor2))] + extent = findfirst(1:maxdim) do i + abs(spectrum[i]) < threshold + end + + slice!(tn, vind, extent) + + return tn end -struct MissingSchmidtCoefficientsException <: Base.Exception - bond::NTuple{2,Site} +function expect(ψ::AbstractAnsatz, observables; bra=copy(ψ)) + ϕ = bra + + # TODO is this ok? + for observable in observables + evolve!(ϕ, observable) + end + + return overlap(ϕ, ψ) end -MissingSchmidtCoefficientsException(bond::Vector{<:Site}) = MissingSchmidtCoefficientsException(tuple(bond...)) +overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)')) -function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) - return print(io, "Can't access the spectrum on bond $(e.bond)") +function evolve!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false) + return simple_update!(ψ, gate; threshold, maxdim, renormalize) end -# Traits -abstract type Boundary end -struct Open <: Boundary end -struct Periodic <: Boundary end +# by popular demand (Stefano, I'm looking at you), I aliased `apply!` to `evolve!` +const apply! = evolve! -function boundary end +function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, kwargs...) + @assert issetequal(adjoint.(sites(gate; set=:inputs)), sites(gate; set=:outputs)) "Inputs of the gate must match outputs" + @assert isneighbor(ψ, lanes(gate)...) "Gate must act on neighboring sites" -abstract type Form end -struct NonCanonical <: Form end -struct MixedCanonical <: Form - orthogonality_center::Union{Site,Vector{Site}} + if nlanes(gate) == 1 + return simple_update_1site!(ψ, gate) + end + + return simple_update!(form(ψ), ψ, gate; kwargs...) end -struct Canonical <: Form end -function form end +function simple_update_1site!(ψ::AbstractAnsatz, gate) + @assert nlanes(gate) == 1 "Gate must act only on one lane" + @assert ninputs(gate) == 1 "Gate must have only one input" + @assert noutputs(gate) == 1 "Gate must have only one output" + + # shallow copy to avoid problems if errors in mid execution + gate = copy(gate) + + contracting_index = gensym(:tmp) + targetsite = only(sites(gate; set=:inputs))' + + # reindex contracting index + replace!(ψ, inds(ψ; at=targetsite) => contracting_index) + replace!(gate, inds(gate; at=targetsite') => contracting_index) + + # reindex output of gate to match TN sitemap + replace!(gate, inds(gate; at=only(sites(gate; set=:outputs))) => inds(ψ; at=targetsite)) + + # contract gate with TN + merge!(ψ, gate) + return contract!(ψ, contracting_index) +end + +function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing) + @assert nlanes(gate) == 2 "Only 2-site gates are supported currently" + + # shallow copy to avoid problems if errors in mid execution + gate = copy(gate) + + merge!(ψ, gate) + contract!(ψ, inds(gate; set=:inputs)) + + # TODO split + + return ψ +end + +# TODO move non-canonical code to method above +# TODO remove `renormalize` argument +# TODO refactor code +function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false, iscanonical=false) + # shallow copy to avoid problems if errors in mid execution + gate = copy(gate) + + bond = sitel, siter = minmax(sites(gate; set=:outputs)...) + left_inds::Vector{Symbol} = !isnothing(leftindex(ψ, sitel)) ? [leftindex(ψ, sitel)] : Symbol[] + right_inds::Vector{Symbol} = !isnothing(rightindex(ψ, siter)) ? [rightindex(ψ, siter)] : Symbol[] + + virtualind::Symbol = inds(ψ; bond=bond) + + iscanonical ? contract_2sitewf!(ψ, bond) : contract!(TensorNetwork(ψ), virtualind) + + # reindex contracting index + contracting_inds = [gensym(:tmp) for _ in sites(gate; set=:inputs)] + replace!( + ψ, + map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) + inds(ψ; at=site') => contracting_index + end, + ) + replace!( + gate, + map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) + inds(gate; at=site) => contracting_index + end, + ) + + # replace output indices of the gate for gensym indices + output_inds = [gensym(:out) for _ in sites(gate; set=:outputs)] + replace!( + gate, + map(zip(sites(gate; set=:outputs), output_inds)) do (site, out) + inds(gate; at=site) => out + end, + ) + + # reindex output of gate to match TN sitemap + for site in sites(gate; set=:outputs) + if inds(ψ; at=site) != inds(gate; at=site) + replace!(gate, inds(gate; at=site) => inds(ψ; at=site)) + end + end + + # contract physical inds + merge!(ψ, gate) + contract!(ψ, contracting_inds) + + # decompose using SVD + push!(left_inds, inds(ψ; at=sitel)) + push!(right_inds, inds(ψ; at=siter)) + + if iscanonical + unpack_2sitewf!(ψ, bond, left_inds, right_inds, virtualind) + else + svd!(ψ; left_inds, right_inds, virtualind) + end + # truncate virtual index + if any(!isnothing, [threshold, maxdim]) + truncate!(ψ, bond; threshold, maxdim) + + # renormalize the bond + if renormalize && iscanonical + λ = tensors(ψ; between=bond) + replace!(ψ, λ => normalize(λ)) # TODO this can be replaced by `normalize!(λ)` + elseif renormalize && !iscanonical + normalize!(ψ, bond[1]) + end + end + + return ψ +end + +# TODO refactor code +""" + contract_2sitewf!(ψ::AbstractAnsatz, bond) + +For a given [`AbstractAnsatz`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁, +where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ. +""" +function contract_2sitewf!(ψ::AbstractAnsatz, bond) + @assert form(ψ) == Canonical() "The tensor network must be in canonical form" + + sitel, siter = bond # TODO Check if bond is valid + (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) || + throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) + + Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) + Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) + + !isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false) + !isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false) + + contract!(ψ, inds(ψ; bond=bond)) + + return ψ +end + +# TODO refactor code +""" + unpack_2sitewf!(ψ::AbstractAnsatz, bond) + +For a given [`AbstractAnsatz`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical +form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`. +""" +function unpack_2sitewf!(ψ::AbstractAnsatz, bond, left_inds, right_inds, virtualind) + @assert form(ψ) == Canonical() "The tensor network must be in canonical form" + + sitel, siter = bond # TODO Check if bond is valid + (0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) || + throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) + + Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) + Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) + + # do svd of the θ tensor + θ = tensors(ψ; at=sitel) + U, s, Vt = svd(θ; left_inds, right_inds, virtualind) + + # contract with the inverse of Λᵢ and Λᵢ₊₂ + Γᵢ₋₁ = + isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=()) + Γᵢ = + isnothing(Λᵢ₊₁) ? Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=()) + + delete!(ψ, θ) + + push!(ψ, Γᵢ₋₁) + push!(ψ, s) + push!(ψ, Γᵢ) + + return ψ +end diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl deleted file mode 100644 index 05af06dd4..000000000 --- a/src/Ansatz/Chain.jl +++ /dev/null @@ -1,561 +0,0 @@ -using LinearAlgebra -using Random - -struct Chain <: Ansatz - super::Quantum - boundary::Boundary -end - -Base.copy(tn::Chain) = Chain(copy(Quantum(tn)), boundary(tn)) - -Base.similar(tn::Chain) = Chain(similar(Quantum(tn)), boundary(tn)) -Base.zero(tn::Chain) = Chain(zero(Quantum(tn)), boundary(tn)) - -function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, State())) - @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" - issetequal(order, defaultorder(Chain, State())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(2n)] - - _tensors = map(enumerate(arrays)) do (i, array) - inds = map(order) do dir - if dir == :o - symbols[i] - elseif dir == :r - symbols[n + mod1(i, n)] - elseif dir == :l - symbols[n + mod1(i - 1, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end - - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - - return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) -end - -function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, Operator())) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" - issetequal(order, defaultorder(Chain, Operator())) || - throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) - - n = length(arrays) - gen = IndexCounter() - symbols = [nextindex!(gen) for _ in 1:(3n)] - - _tensors = map(enumerate(arrays)) do (i, array) - inds = map(order) do dir - if dir == :o - symbols[i] - elseif dir == :i - symbols[i + n] - elseif dir == :l - symbols[2n + mod1(i - 1, n)] - elseif dir == :r - symbols[2n + mod1(i, n)] - else - throw(ArgumentError("Invalid direction: $dir")) - end - end - Tensor(array, inds) - end - - sitemap = Dict(Site(i) => symbols[i] for i in 1:n) - merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) - - return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) -end - -leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) -function leftsite(::Open, tn::Chain, site::Site) - return id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1; dual=isdual(site)) : nothing -end -leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)); dual=isdual(site)) - -rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) -function rightsite(::Open, tn::Chain, site::Site) - return id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1; dual=isdual(site)) : nothing -end -rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)); dual=isdual(site)) - -leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) -leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site) -leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond=(site, leftsite(tn, site))) - -rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) -function rightindex(::Open, tn::Chain, site::Site) - return site == Site(nlanes(tn); dual=isdual(site)) ? nothing : rightindex(Periodic(), tn, site) -end -rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond=(site, rightsite(tn, site))) - -# """ -# Tenet.contract!(tn::Chain; between=(site1, site2), direction::Symbol = :left, delete_Λ = true) - -# For a given [`Chain`](@ref) tensor network, contracts the singular values Λ between two sites `site1` and `site2`. -# The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument -# specifies whether to delete the singular values tensor after the contraction. -# """ -@kwmethod contract(tn::Chain; between, direction, delete_Λ) = contract!(copy(tn); between, direction, delete_Λ) -@kwmethod function contract!(tn::Chain; between, direction, delete_Λ) - site1, site2 = between - Λᵢ = tensors(tn; between) - Λᵢ === nothing && return tn - - if direction === :right - Γᵢ₊₁ = tensors(tn; at=site2) - replace!(tn, Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ; dims=())) - elseif direction === :left - Γᵢ = tensors(tn; at=site1) - replace!(tn, Γᵢ => contract(Λᵢ, Γᵢ; dims=())) - else - throw(ArgumentError("Unknown direction=:$direction")) - end - - delete_Λ && delete!(TensorNetwork(tn), Λᵢ) - - return tn -end -@kwmethod contract(tn::Chain; between) = contract(tn; between, direction=:left, delete_Λ=true) -@kwmethod contract!(tn::Chain; between) = contract!(tn; between, direction=:left, delete_Λ=true) -@kwmethod contract(tn::Chain; between, direction) = contract(tn; between, direction, delete_Λ=true) -@kwmethod contract!(tn::Chain; between, direction) = contract!(tn; between, direction, delete_Λ=true) - -canonize_site(tn::Chain, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...) -canonize_site!(tn::Chain, args...; kwargs...) = canonize_site!(boundary(tn), tn, args...; kwargs...) - -# NOTE: in method == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! -function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, method=:qr) - left_inds = Symbol[] - right_inds = Symbol[] - - virtualind = if direction === :left - site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor")) - push!(right_inds, leftindex(tn, site)) - - site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site)) - push!(left_inds, inds(tn; at=site)) - - only(right_inds) - elseif direction === :right - site == Site(nsites(tn)) && throw(ArgumentError("Cannot left-canonize right-most tensor")) - push!(right_inds, rightindex(tn, site)) - - site == Site(1) || push!(left_inds, leftindex(tn, site)) - push!(left_inds, inds(tn; at=site)) - - only(right_inds) - else - throw(ArgumentError("Unknown direction=:$direction")) - end - - tmpind = gensym(:tmp) - if method === :svd - svd!(TensorNetwork(tn); left_inds, right_inds, virtualind=tmpind) - elseif method === :qr - qr!(TensorNetwork(tn); left_inds, right_inds, virtualind=tmpind) - else - throw(ArgumentError("Unknown factorization method=:$method")) - end - - contract!(tn, virtualind) - replace!(tn, tmpind => virtualind) - - return tn -end - -truncate(tn::Chain, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) - -""" - truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) - -Truncate the dimension of the virtual `bond`` of the [`Chain`](@ref) Tensor Network by keeping only the `maxdim` largest Schmidt coefficients or those larger than`threshold`. - -# Notes - - - Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used. - - The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`. -""" -function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real}=nothing, maxdim::Union{Nothing,Int}=nothing) - # TODO replace for tensors(; between) - vind = rightindex(qtn, bond[1]) - if vind != leftindex(qtn, bond[2]) - throw(ArgumentError("Invalid bond $bond")) - end - - if vind ∉ inds(qtn; set=:hyper) - throw(MissingSchmidtCoefficientsException(bond)) - end - - tensor = TensorNetwork(qtn)[vind] - spectrum = parent(tensor) - - extent = collect( - if !isnothing(maxdim) - 1:min(size(qtn, vind), maxdim) - else - 1:size(qtn, vind) - end, - ) - - # remove 0s from spectrum - if isnothing(threshold) - threshold = 1e-16 - end - - filter!(extent) do i - abs(spectrum[i]) > threshold - end - - slice!(qtn, vind, extent) - - return qtn -end - -function isleftcanonical(qtn::Chain, site; atol::Real=1e-12) - right_ind = rightindex(qtn, site) - tensor = tensors(qtn; at=site) - - # we are at right-most site, we need to add an extra dummy dimension to the tensor - if isnothing(right_ind) - right_ind = gensym(:dummy) - tensor = Tensor(reshape(parent(tensor), size(tensor)..., 1), (inds(tensor)..., right_ind)) - end - - # TODO is replace(conj(A)...) copying too much? - contracted = contract(tensor, replace(conj(tensor), right_ind => gensym(:new_ind))) - n = size(tensor, right_ind) - identity_matrix = Matrix(I, n, n) - - return isapprox(contracted, identity_matrix; atol) -end - -function isrightcanonical(qtn::Chain, site; atol::Real=1e-12) - left_ind = leftindex(qtn, site) - tensor = tensors(qtn; at=site) - - # we are at left-most site, we need to add an extra dummy dimension to the tensor - if isnothing(left_ind) - left_ind = gensym(:dummy) - tensor = Tensor(reshape(parent(tensor), 1, size(tensor)...), (left_ind, inds(tensor)...)) - end - - #TODO is replace(conj(A)...) copying too much? - contracted = contract(tensor, replace(conj(tensor), left_ind => gensym(:new_ind))) - n = size(tensor, left_ind) - identity_matrix = Matrix(I, n, n) - - return isapprox(contracted, identity_matrix; atol) -end - -canonize(tn::Chain, args...; kwargs...) = canonize!(copy(tn), args...; kwargs...) -canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...) - -""" -canonize(boundary::Boundary, tn::Chain) - -Transform a `Chain` tensor network into the canonical form (Vidal form), that is, -we have the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ. -""" -function canonize!(::Open, tn::Chain) - Λ = Tensor[] - - # right-to-left QR sweep, get right-canonical tensors - for i in nsites(tn):-1:2 - canonize_site!(tn, Site(i); direction=:left, method=:qr) - end - - # left-to-right SVD sweep, get left-canonical tensors and singular values without reversing - for i in 1:(nsites(tn) - 1) - canonize_site!(tn, Site(i); direction=:right, method=:svd) - - # extract the singular values and contract them with the next tensor - Λᵢ = pop!(TensorNetwork(tn), tensors(tn; between=(Site(i), Site(i + 1)))) - Aᵢ₊₁ = tensors(tn; at=Site(i + 1)) - replace!(tn, Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=())) - push!(Λ, Λᵢ) - end - - for i in 2:nsites(tn) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ - Λᵢ = Λ[i - 1] # singular values start between site 1 and 2 - A = tensors(tn; at=Site(i)) - Γᵢ = contract(A, Tensor(diag(pinv(Diagonal(parent(Λᵢ)); atol=1e-64)), inds(Λᵢ)); dims=()) - replace!(tn, A => Γᵢ) - push!(TensorNetwork(tn), Λᵢ) - end - - return tn -end - -mixed_canonize(tn::Chain, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...) -mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), tn, args...; kwargs...) - -""" - mixed_canonize!(boundary::Boundary, tn::Chain, center::Site) - -Transform a `Chain` tensor network into the mixed-canonical form, that is, -for i < center the tensors are left-canonical and for i >= center the tensors are right-canonical, -and in the center there is a matrix with singular values. -""" -function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites - # left-to-right QR sweep (left-canonical tensors) - for i in 1:(id(center) - 1) - canonize_site!(tn, Site(i); direction=:right, method=:qr) - end - - # right-to-left QR sweep (right-canonical tensors) - for i in nsites(tn):-1:(id(center) + 1) - canonize_site!(tn, Site(i); direction=:left, method=:qr) - end - - # center SVD sweep to get singular values - canonize_site!(tn, center; direction=:left, method=:svd) - - return tn -end - -""" - LinearAlgebra.normalize!(tn::Chain, center::Site) - -Normalizes the input [`Chain`](@ref) tensor network by transforming it -to mixed-canonized form with the given center site. -""" -function LinearAlgebra.normalize!(tn::Chain, root::Site; p::Real=2) - mixed_canonize!(tn, root) - normalize!(tensors(tn; between=(Site(id(root) - 1), root)), p) - return tn -end - -""" - evolve!(qtn::Chain, gate) - -Applies a local operator `gate` to the [`Chain`](@ref) tensor network. -""" -function evolve!(qtn::Chain, gate::Dense; threshold=nothing, maxdim=nothing, iscanonical=false, renormalize=false) - # check gate is a valid operator - if !(socket(gate) isa Operator) - throw(ArgumentError("Gate must be an operator, but got $(socket(gate))")) - end - - # TODO refactor out to `islane`? - if !issetequal(adjoint.(sites(gate; set=:inputs)), sites(gate; set=:outputs)) - throw( - ArgumentError( - "Gate inputs ($(sites(gate; set=:inputs))) and outputs ($(sites(gate; set=:outputs))) must be the same" - ), - ) - end - - # TODO refactor out to `canconnect`? - if adjoint.(sites(gate; set=:inputs)) ⊈ sites(qtn; set=:outputs) - throw( - ArgumentError("Gate inputs ($(sites(gate; set=:inputs))) must be a subset of the TN sites ($(sites(qtn)))") - ) - end - - if nlanes(gate) == 1 - evolve_1site!(qtn, gate) - elseif nlanes(gate) == 2 - # check gate sites are contiguous - # TODO refactor this out? - gate_inputs = sort!(id.(sites(gate; set=:inputs))) - range = UnitRange(extrema(gate_inputs)...) - - range != gate_inputs && throw(ArgumentError("Gate lanes must be contiguous")) - - # TODO check correctly for periodic boundary conditions - evolve_2site!(qtn, gate; threshold, maxdim, iscanonical, renormalize) - else - # TODO generalize for more than 2 lanes - throw(ArgumentError("Invalid number of lanes $(nlanes(gate)), maximum is 2")) - end - - return qtn -end - -function evolve_1site!(qtn::Chain, gate::Dense) - # shallow copy to avoid problems if errors in mid execution - gate = copy(gate) - resetindex!(gate; init=ninds(qtn)) - - contracting_index = gensym(:tmp) - targetsite = only(sites(gate; set=:inputs))' - - # reindex output of gate to match TN sitemap - replace!(gate, inds(gate; at=only(sites(gate; set=:outputs))) => inds(qtn; at=targetsite)) - - # reindex contracting index - replace!(qtn, inds(qtn; at=targetsite) => contracting_index) - replace!(gate, inds(gate; at=targetsite') => contracting_index) - - # contract gate with TN - merge!(qtn, gate; reset=false) - return contract!(qtn, contracting_index) -end - -# TODO: Maybe rename iscanonical kwarg ? -function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical=false, renormalize=false) - # shallow copy to avoid problems if errors in mid execution - gate = copy(gate) - - bond = sitel, siter = minmax(sites(gate; set=:outputs)...) - left_inds::Vector{Symbol} = !isnothing(leftindex(qtn, sitel)) ? [leftindex(qtn, sitel)] : Symbol[] - right_inds::Vector{Symbol} = !isnothing(rightindex(qtn, siter)) ? [rightindex(qtn, siter)] : Symbol[] - - virtualind::Symbol = inds(qtn; bond=bond) - - iscanonical ? contract_2sitewf!(qtn, bond) : contract!(TensorNetwork(qtn), virtualind) - - # reindex contracting index - contracting_inds = [gensym(:tmp) for _ in sites(gate; set=:inputs)] - replace!( - TensorNetwork(qtn), - map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) - inds(qtn; at=site') => contracting_index - end, - ) - replace!( - Quantum(gate), - map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) - inds(gate; at=site) => contracting_index - end, - ) - - # replace output indices of the gate for gensym indices - output_inds = [gensym(:out) for _ in sites(gate; set=:outputs)] - replace!( - Quantum(gate), - map(zip(sites(gate; set=:outputs), output_inds)) do (site, out) - inds(gate; at=site) => out - end, - ) - - # reindex output of gate to match TN sitemap - for site in sites(gate; set=:outputs) - if inds(qtn; at=site) != inds(gate; at=site) - replace!(TensorNetwork(gate), inds(gate; at=site) => inds(qtn; at=site)) - end - end - - # contract physical inds - merge!(TensorNetwork(qtn), TensorNetwork(gate)) - contract!(qtn, contracting_inds) - - # decompose using SVD - push!(left_inds, inds(qtn; at=sitel)) - push!(right_inds, inds(qtn; at=siter)) - - if iscanonical - unpack_2sitewf!(qtn, bond, left_inds, right_inds, virtualind) - else - svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind) - end - # truncate virtual index - if any(!isnothing, [threshold, maxdim]) - truncate!(qtn, bond; threshold, maxdim) - - # renormalize the bond - if renormalize && iscanonical - λ = tensors(qtn; between=bond) - replace!(qtn, λ => normalize(λ)) # TODO this can be replaced by `normalize!(λ)` - elseif renormalize && !iscanonical - normalize!(qtn, bond[1]) - end - end - - return qtn -end - -""" - contract_2sitewf!(ψ::Chain, bond) - -For a given [`Chain`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁, -where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ. -""" -function contract_2sitewf!(ψ::Chain, bond) - # TODO Check if ψ is in canonical form - - sitel, siter = bond # TODO Check if bond is valid - (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) || - throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - - Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) - Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) - - !isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false) - !isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false) - - contract!(ψ, inds(ψ; bond=bond)) - - return ψ -end - -""" - unpack_2sitewf!(ψ::Chain, bond) - -For a given [`Chain`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical -form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`. -""" -function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) - # TODO Check if ψ is in canonical form - - sitel, siter = bond # TODO Check if bond is valid - (0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) || - throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - - Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) - Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) - - # do svd of the θ tensor - θ = tensors(ψ; at=sitel) - U, s, Vt = svd(θ; left_inds, right_inds, virtualind) - - # contract with the inverse of Λᵢ and Λᵢ₊₂ - Γᵢ₋₁ = - isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=()) - Γᵢ = - isnothing(Λᵢ₊₁) ? Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=()) - - delete!(TensorNetwork(ψ), θ) - - push!(TensorNetwork(ψ), Γᵢ₋₁) - push!(TensorNetwork(ψ), s) - push!(TensorNetwork(ψ), Γᵢ) - - return ψ -end - -function expect(ψ::Chain, observables) - # contract observable with TN - ϕ = copy(ψ) - for observable in observables - evolve!(ϕ, observable) - end - - # contract evolved TN with adjoint of original TN - tn = merge!(TensorNetwork(ϕ), TensorNetwork(ψ')) - - return contract(tn) -end - -overlap(a::Chain, b::Chain) = overlap(socket(a), a, socket(b), b) - -# TODO fix optimal path -function overlap(::State, a::Chain, ::State, b::Chain) - @assert issetequal(sites(a), sites(b)) "Ansatzes must have the same sites" - - b = copy(b) - b = @reindex! outputs(a) => outputs(b) - - tn = merge(TensorNetwork(a), TensorNetwork(b')) - return contract(tn) -end - -# TODO optimize -overlap(a::Product, b::Chain) = contract(merge(Quantum(a), Quantum(b)')) -overlap(a::Chain, b::Product) = contract(merge(Quantum(a), Quantum(b)')) diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl index 04709d146..b98f2d534 100644 --- a/src/Ansatz/MPO.jl +++ b/src/Ansatz/MPO.jl @@ -106,3 +106,5 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{MPO}, n, χ; eltype=Float64, # TODO order might not be the best for performance return MPO(arrays; order=(:l, :i, :o, :r)) end + +function evolve!(ψ::MPS, op::MPO; threshold=nothing, maxdim=nothing, renormalize=false) end diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index d370581e7..0b11ca18a 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -1,9 +1,10 @@ using Random +using LinearAlgebra abstract type AbstractMPS <: AbstractAnsatz end -struct MPS <: AbstractMPS - tn::Ansatz +mutable struct MPS <: AbstractMPS + const tn::Ansatz form::Form end @@ -103,3 +104,203 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}, n, χ; eltype=Float64, return MPS(arrays; order=(:l, :o, :r)) end + +# TODO deprecate contract(; between) and generalize it to AbstractAnsatz +""" + Tenet.contract!(tn::MPS; between=(site1, site2), direction::Symbol = :left, delete_Λ = true) + +For a given [`MPS`](@ref) tensor network, contracts the singular values Λ between two sites `site1` and `site2`. +The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument +specifies whether to delete the singular values tensor after the contraction. +""" +@kwmethod contract(tn::MPS; between, direction, delete_Λ) = contract!(copy(tn); between, direction, delete_Λ) +@kwmethod function contract!(tn::MPS; between, direction, delete_Λ) + site1, site2 = between + Λᵢ = tensors(tn; between) + Λᵢ === nothing && return tn + + if direction === :right + Γᵢ₊₁ = tensors(tn; at=site2) + replace!(tn, Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ; dims=())) + elseif direction === :left + Γᵢ = tensors(tn; at=site1) + replace!(tn, Γᵢ => contract(Λᵢ, Γᵢ; dims=())) + else + throw(ArgumentError("Unknown direction=:$direction")) + end + + delete_Λ && delete!(TensorNetwork(tn), Λᵢ) + + return tn +end +@kwmethod contract(tn::MPS; between) = contract(tn; between, direction=:left, delete_Λ=true) +@kwmethod contract!(tn::MPS; between) = contract!(tn; between, direction=:left, delete_Λ=true) +@kwmethod contract(tn::MPS; between, direction) = contract(tn; between, direction, delete_Λ=true) +@kwmethod contract!(tn::MPS; between, direction) = contract!(tn; between, direction, delete_Λ=true) + +@kwmethod function sites(ψ::MPS, site::Site; dir) + if dir === :left + return site == site"1" ? nothing : Site(id(site) - 1) + elseif dir === :right + return site == Site(nsites(ψ)) ? nothing : Site(id(site) + 1) + else + throw(ArgumentError("Unknown direction for MPS = :$dir")) + end +end + +@kwmethod function inds(ψ::MPS; at, dir) + if dir === :left && site == site"1" + return nothing + elseif dir === :right && site == Site(nlanes(tn); dual=isdual(site)) + return nothing + elseif dir ∈ (:left, :right) + return inds(tn; bond=(site, sites(tn, site; dir))) + else + throw(ArgumentError("Unknown direction for MPS = :$dir")) + end +end + +function isleftcanonical(ψ::MPS, site; atol::Real=1e-12) + right_ind = inds(ψ; at=site, dir=:right) + tensor = tensors(ψ; at=site) + + # we are at right-most site, we need to add an extra dummy dimension to the tensor + if isnothing(right_ind) + right_ind = gensym(:dummy) + tensor = Tensor(reshape(parent(tensor), size(tensor)..., 1), (inds(tensor)..., right_ind)) + end + + # TODO is replace(conj(A)...) copying too much? + contracted = contract(tensor, replace(conj(tensor), right_ind => gensym(:new_ind))) + n = size(tensor, right_ind) + identity_matrix = Matrix(I, n, n) + + return isapprox(contracted, identity_matrix; atol) +end + +function isrightcanonical(ψ::MPS, site; atol::Real=1e-12) + left_ind = inds(ψ; at=site, dir=:left) + tensor = tensors(ψ; at=site) + + # we are at left-most site, we need to add an extra dummy dimension to the tensor + if isnothing(left_ind) + left_ind = gensym(:dummy) + tensor = Tensor(reshape(parent(tensor), 1, size(tensor)...), (left_ind, inds(tensor)...)) + end + + #TODO is replace(conj(A)...) copying too much? + contracted = contract(tensor, replace(conj(tensor), left_ind => gensym(:new_ind))) + n = size(tensor, left_ind) + identity_matrix = Matrix(I, n, n) + + return isapprox(contracted, identity_matrix; atol) +end + +# NOTE: in method == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! +function canonize_site!(ψ::MPS, site::Site; direction::Symbol, method=:qr) + left_inds = Symbol[] + right_inds = Symbol[] + + virtualind = if direction === :left + site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor")) + push!(right_inds, inds(ψ; at=site, dir=:left)) + + site == Site(nsites(ψ)) || push!(left_inds, inds(ψ; at=site, dir=:right)) + push!(left_inds, inds(ψ; at=site)) + + only(right_inds) + elseif direction === :right + site == Site(nsites(ψ)) && throw(ArgumentError("Cannot left-canonize right-most tensor")) + push!(right_inds, inds(ψ; at=site, dir=:right)) + + site == Site(1) || push!(left_inds, inds(ψ; at=site, dir=:left)) + push!(left_inds, inds(ψ; at=site)) + + only(right_inds) + else + throw(ArgumentError("Unknown direction=:$direction")) + end + + tmpind = gensym(:tmp) + if method === :svd + svd!(ψ; left_inds, right_inds, virtualind=tmpind) + elseif method === :qr + qr!(ψ; left_inds, right_inds, virtualind=tmpind) + else + throw(ArgumentError("Unknown factorization method=:$method")) + end + + contract!(ψ, virtualind) + replace!(ψ, tmpind => virtualind) + + return ψ +end + +""" + canonize!(tn::MPS) + +Transform a [`MPS`](@ref) tensor network into the canonical form (Vidal form); i.e. the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ. +""" +function canonize!(ψ::MPS) + Λ = Tensor[] + + # right-to-left QR sweep, get right-canonical tensors + for i in nsites(ψ):-1:2 + canonize_site!(ψ, Site(i); direction=:left, method=:qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values without reversing + for i in 1:(nsites(ψ) - 1) + canonize_site!(ψ, Site(i); direction=:right, method=:svd) + + # extract the singular values and contract them with the next tensor + Λᵢ = pop!(ψ, tensors(ψ; between=(Site(i), Site(i + 1)))) + Aᵢ₊₁ = tensors(ψ; at=Site(i + 1)) + replace!(ψ, Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=())) + push!(Λ, Λᵢ) + end + + for i in 2:nsites(ψ) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ + Λᵢ = Λ[i - 1] # singular values start between site 1 and 2 + A = tensors(ψ; at=Site(i)) + Γᵢ = contract(A, Tensor(diag(pinv(Diagonal(parent(Λᵢ)); atol=1e-64)), inds(Λᵢ)); dims=()) + replace!(ψ, A => Γᵢ) + push!(ψ, Λᵢ) + end + + return ψ +end + +mixed_canonize(tn::MPS, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...) + +# TODO mixed_canonize! at bond +""" + mixed_canonize!(tn::MPS, orthog_center) + +Transform a [`MPS`](@ref) tensor network into the mixed-canonical form, that is, +for `i < orthog_center` the tensors are left-canonical and for `i >= orthog_center` the tensors are right-canonical, +and in the `orthog_center` there is a matrix with singular values. +""" +function mixed_canonize!(tn::MPS, orthog_center) + # left-to-right QR sweep (left-canonical tensors) + for i in 1:(id(center) - 1) + canonize_site!(tn, Site(i); direction=:right, method=:qr) + end + + # right-to-left QR sweep (right-canonical tensors) + for i in nsites(tn):-1:(id(center) + 1) + canonize_site!(tn, Site(i); direction=:left, method=:qr) + end + + # center SVD sweep to get singular values + canonize_site!(tn, center; direction=:left, method=:svd) + + return tn +end + +# TODO normalize! methods +function LinearAlgebra.normalize!(ψ::MPS, orthog_center=site"1") + mixed_canonize!(ψ, orthog_center) + normalize!(tensors(tn; between=(Site(id(root) - 1), root)), 2) + return ψ +end From c76fdb31395b72c5eea416d90a5f764530c8eb3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 17 Sep 2024 16:23:32 -0400 Subject: [PATCH 18/75] Force Graphs to be a strong dependency --- Project.toml | 1 - ext/TenetGraphMakieExt.jl | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 16910bef2..b9ae892c3 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,6 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" diff --git a/ext/TenetGraphMakieExt.jl b/ext/TenetGraphMakieExt.jl index e8e121d7e..c013d4df2 100644 --- a/ext/TenetGraphMakieExt.jl +++ b/ext/TenetGraphMakieExt.jl @@ -1,9 +1,9 @@ module TenetGraphMakieExt +using Tenet using GraphMakie +using Graphs using Makie -const Graphs = GraphMakie.Graphs -using Tenet using Combinatorics: combinations """ From 636b03f7d7a42a65b44e0e6c36563379edb5a3cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 14:41:38 -0400 Subject: [PATCH 19/75] Implement `Graphs.neighbors`, `isneighbor` methods --- src/Ansatz/Ansatz.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index f18cd2ff4..ef882ddef 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -61,14 +61,10 @@ function Base.isapprox(a::AbstractAnsatz, b::AbstractAnsatz; kwargs...) return ==(latice.((a, b))...) && isapprox(Quantum(a), Quantum(b); kwargs...) end -function Graphs.neighbors(tn::AbstractAnsatz, site::Site) - # TODO - # return neighbors(lattice(tn), site) -end - +Graphs.neighbors(tn::AbstractAnsatz, site::Site) = neighbor_labels(lattice(tn), site) function isneighbor(tn::AbstractAnsatz, a::Site, b::Site) - # TODO - # return isneighbor(lattice(tn), a, b) + lt = lattice(tn) + return has_edge(lt, MetaGraphsNext.code_for(lt, a), MetaGraphsNext.code_for(lt, b)) end @kwmethod function inds(tn::AbstractAnsatz; bond) From ab497b5a384ca4e83ee5caca016294eefd1f6af7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 14:42:13 -0400 Subject: [PATCH 20/75] Fix `sites` method for `MPS` --- src/Ansatz/MPS.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 0b11ca18a..542a1e176 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -138,7 +138,7 @@ end @kwmethod contract(tn::MPS; between, direction) = contract(tn; between, direction, delete_Λ=true) @kwmethod contract!(tn::MPS; between, direction) = contract!(tn; between, direction, delete_Λ=true) -@kwmethod function sites(ψ::MPS, site::Site; dir) +function sites(ψ::MPS, site::Site; dir) if dir === :left return site == site"1" ? nothing : Site(id(site) - 1) elseif dir === :right From 0079bfd36326c34ebf8ea33a9e7e22b100f70840 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 14:42:32 -0400 Subject: [PATCH 21/75] Fix `inds` method for `MPS` --- src/Ansatz/MPS.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 542a1e176..c6262c3a6 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -149,12 +149,12 @@ function sites(ψ::MPS, site::Site; dir) end @kwmethod function inds(ψ::MPS; at, dir) - if dir === :left && site == site"1" + if dir === :left && at == site"1" return nothing - elseif dir === :right && site == Site(nlanes(tn); dual=isdual(site)) + elseif dir === :right && at == Site(nlanes(ψ); dual=isdual(at)) return nothing elseif dir ∈ (:left, :right) - return inds(tn; bond=(site, sites(tn, site; dir))) + return inds(ψ; bond=(at, sites(ψ, at; dir))) else throw(ArgumentError("Unknown direction for MPS = :$dir")) end From c933053f9475569603128310eca724877fd1c902 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 14:42:45 -0400 Subject: [PATCH 22/75] Fix typo --- src/Ansatz/Product.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index 8c86f570c..0c5bde2be 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -32,7 +32,7 @@ function Product(arrays::Vector{<:AbstractMatrix}) gen = IndexCounter() symbols = [nextindex!(gen) for _ in 1:(2 * length(arrays))] _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i + n], symbols[i]], []) + Tensor(array, [symbols[i + n], symbols[i]]) end sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) From 538e698c59b725493cd427dde956283e27cbc585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 14:43:19 -0400 Subject: [PATCH 23/75] Refactor `adapt_structure` method to support additional types --- ext/TenetAdaptExt.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ext/TenetAdaptExt.jl b/ext/TenetAdaptExt.jl index 6facff448..06ed5bc9f 100644 --- a/ext/TenetAdaptExt.jl +++ b/ext/TenetAdaptExt.jl @@ -7,7 +7,13 @@ Adapt.adapt_structure(to, x::Tensor) = Tensor(adapt(to, parent(x)), inds(x)) Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tensors(x))) Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites) -Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Quantum(x))) -Adapt.adapt_structure(to, x::Chain) = Chain(adapt(to, Quantum(x)), boundary(x)) +Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), lattice(x)) + +Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x))) +Adapt.adapt_structure(to, x::Dense) = Dense(adapt(to, Ansatz(x))) +Adapt.adapt_structure(to, x::MPS) = MPS(adapt(to, Ansatz(x)), form(x)) +Adapt.adapt_structure(to, x::MPO) = MPO(adapt(to, Ansatz(x)), form(x)) +Adapt.adapt_structure(to, x::PEPS) = PEPS(adapt(to, Ansatz(x)), form(x)) +Adapt.adapt_structure(to, x::PEPO) = PEPO(adapt(to, Ansatz(x)), form(x)) end From 23fc8d43e0e04fa2bb2bc8c6966f44a87b5436b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 14:44:51 -0400 Subject: [PATCH 24/75] Refactor `Reactant.make_tracer`, `Reactant.create_result` methods on top of recent changes --- ext/TenetReactantExt.jl | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 47d97beaa..1c11fec22 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -31,17 +31,23 @@ function Reactant.make_tracer(seen::IdDict, prev::Quantum, path::Tuple, mode::Re return Quantum(tracetn, copy(prev.sites)) end +function Reactant.make_tracer(seen::IdDict, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...) + tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return Ansatz(tracetn, copy(Tenet.lattice(prev))) +end + +# TODO try rely on generic fallback for ansatzes function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...) + tracequantum = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Tenet.Product(tracequantum) end -# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO -function Reactant.make_tracer(seen::IdDict, prev::Tenet.Chain, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...) - return Tenet.Chain(tracequantum, boundary(prev)) +for A in (MPS, MPO) + @eval function Reactant.make_tracer(seen::IdDict, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) + tracequantum = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return $A(tracequantum, form(prev)) + end end - function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores) return :($Tensor($data, $(inds(tocopy)))) @@ -59,10 +65,22 @@ function Reactant.create_result(tocopy::Quantum, @nospecialize(path), result_sto return :($Quantum($tn, $(copy(tocopy.sites)))) end -# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO -function Reactant.create_result(tocopy::Tenet.Chain, @nospecialize(path), result_stores) - qtn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :super), result_stores) - return :($(Tenet.Chain)($qtn, $(boundary(tocopy)))) +function Reactant.create_result(tocopy::Ansatz, @nospecialize(path), result_stores) + tn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($Ansatz($tn, $(copy(Tenet.lattice(tocopy))))) +end + +# TODO try rely on generic fallback for ansatzes +function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores) + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($(Tenet.Product)($tn)) +end + +for A in (MPS, MPO) + @eval function Reactant.create_result(tocopy::$A, @nospecialize(path), result_stores) + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($A($tn, form(tocopy))) + end end function Reactant.push_val!(ad_inputs, x::TensorNetwork, path) From 203c87d1dd0605cc7bd1c882627130a11518783b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 14:54:01 -0400 Subject: [PATCH 25/75] Refactor `ChainRules` methods on top of new types --- ext/TenetChainRulesCoreExt/frules.jl | 39 ++++++++++------ ext/TenetChainRulesCoreExt/projectors.jl | 7 +-- ext/TenetChainRulesCoreExt/rrules.jl | 59 +++++++++++++----------- ext/TenetChainRulesTestUtilsExt.jl | 31 +++++++++---- ext/TenetFiniteDifferencesExt.jl | 14 ++++++ test/integration/ChainRules_test.jl | 48 +++++++++++-------- 6 files changed, 124 insertions(+), 74 deletions(-) diff --git a/ext/TenetChainRulesCoreExt/frules.jl b/ext/TenetChainRulesCoreExt/frules.jl index 941a723b6..f9a6bc76a 100644 --- a/ext/TenetChainRulesCoreExt/frules.jl +++ b/ext/TenetChainRulesCoreExt/frules.jl @@ -1,16 +1,39 @@ +using Tenet: AbstractTensorNetwork, AbstractQuantum + # `Tensor` constructor ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds) # `TensorNetwork` constructor ChainRulesCore.frule((_, Δ), ::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetworkTangent(Δ) +# `Quantum` constructor +function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites) + return Quantum(x, sites), Tangent{Quantum}(; tn=ẋ, sites=NoTangent()) +end + +# `Ansatz` constructor +function ChainRulesCore.frule((_, ẋ), ::Type{Ansatz}, x::Quantum, lattice) + return Ansatz(x, lattice), Tangent{Ansatz}(; tn=ẋ, lattice=NoTangent()) +end + +# `AbstractAnsatz`-subtype constructors +ChainRulesCore.frule((_, ẋ), ::Type{Product}, x::Ansatz) = Product(x), Tangent{Product}(; tn=ẋ) +ChainRulesCore.frule((_, ẋ), ::Type{Dense}, x::Ansatz) = Dense(x, form), Tangent{Dense}(; tn=ẋ) +ChainRulesCore.frule((_, ẋ), ::Type{MPS}, x::Ansatz, form) = MPS(x, form), Tangent{MPS}(; tn=ẋ, lattice=NoTangent()) +ChainRulesCore.frule((_, ẋ), ::Type{MPO}, x::Ansatz, form) = MPO(x, form), Tangent{MPO}(; tn=ẋ, lattice=NoTangent()) +function ChainRulesCore.frule((_, ẋ), ::Type{PEPS}, x::Ansatz, form) + return PEPS(x, form), Tangent{PEPS}(; tn=ẋ, lattice=NoTangent()) +end + # `Base.conj` methods ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ) -ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::TensorNetwork) = conj(tn), conj(Δ) +ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::AbstractTensorNetwork) = conj(tn), conj(Δ) # `Base.merge` methods -ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::TensorNetwork, b::TensorNetwork) = merge(a, b), merge(ȧ, ḃ) +function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::AbstractTensorNetwork, b::AbstractTensorNetwork) + return merge(a, b), merge(ȧ, ḃ) +end # `contract` methods function ChainRulesCore.frule((_, ẋ), ::typeof(contract), x::Tensor; kwargs...) @@ -22,15 +45,3 @@ function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(contract), a::Tensor, b::T ċ = contract(ȧ, b; kwargs...) + contract(a, ḃ; kwargs...) return c, ċ end - -function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites) - y = Quantum(x, sites) - ẏ = Tangent{Quantum}(; tn=ẋ) - return y, ẏ -end - -ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Quantum) where {T<:Ansatz} = T(x), Tangent{T}(; super=ẋ) - -function ChainRulesCore.frule((_, ẋ, _), ::Type{T}, x::Quantum, boundary) where {T<:Ansatz} - return T(x, boundary), Tangent{T}(; super=ẋ, boundary=NoTangent()) -end diff --git a/ext/TenetChainRulesCoreExt/projectors.jl b/ext/TenetChainRulesCoreExt/projectors.jl index acd488fad..f974f90cc 100644 --- a/ext/TenetChainRulesCoreExt/projectors.jl +++ b/ext/TenetChainRulesCoreExt/projectors.jl @@ -36,8 +36,5 @@ end ChainRulesCore.ProjectTo(x::Quantum) = ProjectTo{Quantum}(; tn=ProjectTo(TensorNetwork(x)), sites=x.sites) (projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn(Δ), projector.sites) -ChainRulesCore.ProjectTo(x::T) where {T<:Ansatz} = ProjectTo{T}(; super=ProjectTo(Quantum(x))) -(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Ansatz} = T(projector.super(Δ.super), Δ.boundary) - -# NOTE edge case: `Product` has no `boundary`. should it? -(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Product} = T(projector.super(Δ.super)) +ChainRulesCore.ProjectTo(x::Ansatz) = ProjectTo{Ansatz}(; tn=ProjectTo(Quantum(x))) #, lattice=x.lattice) +(projector::ProjectTo{Ansatz})(Δ) = Ansatz(projector.tn(Δ), Δ.lattice) diff --git a/ext/TenetChainRulesCoreExt/rrules.jl b/ext/TenetChainRulesCoreExt/rrules.jl index 9992c8774..f6c89d055 100644 --- a/ext/TenetChainRulesCoreExt/rrules.jl +++ b/ext/TenetChainRulesCoreExt/rrules.jl @@ -9,6 +9,38 @@ TensorNetwork_pullback(Δ::TensorNetworkTangent) = (NoTangent(), tensors(Δ)) TensorNetwork_pullback(Δ::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ)) ChainRulesCore.rrule(::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetwork_pullback +# `Quantum` constructor +Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent()) +Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback + +# `Ansatz` constructor +Ansatz_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{Ansatz}, x::Quantum, lattice) = Ansatz(x, lattice), Ansatz_pullback + +# `AbstractAnsatz`-subtype constructors +Product_pullback(ȳ) = (NoTangent(), ȳ.tn) +Product_pullback(ȳ::AbstractThunk) = Product_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{Product}, x::Ansatz) = Product(x), Product_pullback + +Dense_pullback(ȳ) = (NoTangent(), ȳ.tn) +Dense_pullback(ȳ::AbstractThunk) = Dense_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{Dense}, x::Ansatz) = Dense(x), Dense_pullback + +MPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +MPS_pullback(ȳ::AbstractThunk) = MPS_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{MPS}, x::Ansatz, form) = MPS(x, form), MPS_pullback + +MPO_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +MPO_pullback(ȳ::AbstractThunk) = MPO_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{MPO}, x::Ansatz, form) = MPO(x, form), MPO_pullback + +PEPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +PEPS_pullback(ȳ::AbstractThunk) = PEPS_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{PEPS}, x::Ansatz, form) = PEPS(x, form), PEPS_pullback + # `Base.conj` methods conj_pullback(Δ::Tensor) = (NoTangent(), conj(Δ)) conj_pullback(Δ::Tangent{Tensor}) = (NoTangent(), conj(Δ)) @@ -93,33 +125,6 @@ function ChainRulesCore.rrule(::typeof(contract), a::Tensor, b::Tensor; kwargs.. return c, contract_pullback end -Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) -Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent()) -Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ)) -ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback - -Ansatz_pullback(ȳ) = (NoTangent(), ȳ.super) -Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ)) -function ChainRulesCore.rrule(::Type{T}, x::Quantum) where {T<:Ansatz} - y = T(x) - return y, Ansatz_pullback -end - -Ansatz_boundary_pullback(ȳ) = (NoTangent(), ȳ.super, NoTangent()) -Ansatz_boundary_pullback(ȳ::AbstractThunk) = Ansatz_boundary_pullback(unthunk(ȳ)) -function ChainRulesCore.rrule(::Type{T}, x::Quantum, boundary) where {T<:Ansatz} - return T(x, boundary), Ansatz_boundary_pullback -end - -Ansatz_from_arrays_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent(), parent.(tensors(ȳ.super.tn))) -Ansatz_from_arrays_pullback(ȳ::AbstractThunk) = Ansatz_from_arrays_pullback(unthunk(ȳ)) -function ChainRulesCore.rrule( - ::Type{T}, socket::Tenet.Socket, boundary::Tenet.Boundary, arrays; kwargs... -) where {T<:Ansatz} - y = T(socket, boundary, arrays; kwargs...) - return y, Ansatz_from_arrays_pullback -end - copy_pullback(ȳ) = (NoTangent(), ȳ) copy_pullback(ȳ::AbstractThunk) = unthunk(ȳ) function ChainRulesCore.rrule(::typeof(copy), x::Quantum) diff --git a/ext/TenetChainRulesTestUtilsExt.jl b/ext/TenetChainRulesTestUtilsExt.jl index 1565387e0..ab3decd49 100644 --- a/ext/TenetChainRulesTestUtilsExt.jl +++ b/ext/TenetChainRulesTestUtilsExt.jl @@ -6,14 +6,16 @@ using Tenet using ChainRulesCore using ChainRulesTestUtils using Random +using Graphs +using MetaGraphsNext const TensorNetworkTangent = Base.get_extension(Tenet, :TenetChainRulesCoreExt).TensorNetworkTangent -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Vector{T}) where {T<:Tensor} - if isempty(x) +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Vector{T}) where {T<:Tensor} + if isempty(tn) return Vector{T}() else - @invoke rand_tangent(rng::AbstractRNG, x::AbstractArray) + @invoke rand_tangent(rng::AbstractRNG, tn::AbstractArray) end end @@ -21,12 +23,25 @@ function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork) return TensorNetworkTangent(Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)]) end -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Quantum) - return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(x)), sites=NoTangent()) +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Quantum) + return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(tn)), sites=NoTangent()) end -# WARN type-piracy -# NOTE used in `Quantum` constructor -ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::Dict{<:Site,Symbol}) = NoTangent() +# WARN type-piracy, used in `Quantum` constructor +ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::Dict{<:Site,Symbol}) = NoTangent() + +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Ansatz) + return Tangent{Ansatz}(; tn=rand_tangent(rng, Quantum(tn)), lattice=NoTangent()) +end + +# WARN not really type-piracy but almost, used in `Ansatz` constructor +ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::T) where {V,T<:MetaGraph{V,SimpleGraph{V},<:Site}} = NoTangent() + +# WARN not really type-piracy but almost, used when testing `Ansatz` +function ChainRulesTestUtils.test_approx( + actual::G, expected::G, msg; kwargs... +) where {G<:MetaGraph{Int64,SimpleGraph{Int64},<:Site}} + return actual == expected +end end diff --git a/ext/TenetFiniteDifferencesExt.jl b/ext/TenetFiniteDifferencesExt.jl index 7b3308428..a0355eed9 100644 --- a/ext/TenetFiniteDifferencesExt.jl +++ b/ext/TenetFiniteDifferencesExt.jl @@ -20,4 +20,18 @@ function FiniteDifferences.to_vec(x::Dict{Vector{Symbol},Tensor}) return x_vec, Dict_from_vec end +function FiniteDifferences.to_vec(x::Quantum) + x_vec, back = to_vec(TensorNetwork(x)) + Quantum_from_vec(v) = Quantum(back(v), copy(x.sites)) + + return x_vec, Quantum_from_vec +end + +function FiniteDifferences.to_vec(x::Ansatz) + x_vec, back = to_vec(Quantum(x)) + Ansatz_from_vec(v) = Ansatz(back(v), copy(x.lattice)) + + return x_vec, Ansatz_from_vec +end + end diff --git a/test/integration/ChainRules_test.jl b/test/integration/ChainRules_test.jl index 79b763867..e6219005d 100644 --- a/test/integration/ChainRules_test.jl +++ b/test/integration/ChainRules_test.jl @@ -1,6 +1,8 @@ @testset "ChainRules" begin using Tenet: Tensor, contract using ChainRulesTestUtils + using Graphs + using MetaGraphsNext @testset "Tensor" begin test_frule(Tensor, ones(), Symbol[]) @@ -190,30 +192,36 @@ end @testset "Ansatz" begin - @testset "Product" begin - tn = TensorNetwork([Tensor(ones(2), [:i]), Tensor(ones(2), [:j]), Tensor(ones(2), [:k])]) - qtn = Quantum(tn, Dict([site"1" => :i, site"2" => :j, site"3" => :k])) + tn = Quantum(TensorNetwork([Tensor(ones(2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + lattice = MetaGraph(Graph(1), Pair{Site,Nothing}[site"1" => nothing], Pair{Tuple{Site,Site},Nothing}[]) + test_frule(Ansatz, tn, lattice) + test_rrule(Ansatz, tn, lattice) + end - test_frule(Product, qtn) - test_rrule(Product, qtn) - end + @testset "Product" begin + tn = Product([ones(2), ones(2), ones(2)]) - @testset "Chain" begin - tn = Chain(State(), Open(), [ones(2, 2), ones(2, 2, 2), ones(2, 2)]) - # test_frule(Chain, Quantum(tn), Open()) - test_rrule(Chain, Quantum(tn), Open()) + test_frule(Product, Ansatz(tn)) + test_rrule(Product, Ansatz(tn)) + end - tn = Chain(State(), Periodic(), [ones(2, 2, 2), ones(2, 2, 2), ones(2, 2, 2)]) - # test_frule(Chain, Quantum(tn), Periodic()) - test_rrule(Chain, Quantum(tn), Periodic()) + @testset "MPS" begin + tn = MPS([ones(2, 2), ones(2, 2, 2), ones(2, 2)]) + # test_frule(MPS, Ansatz(tn), form(tn)) + test_rrule(MPS, Ansatz(tn), form(tn)) - tn = Chain(Operator(), Open(), [ones(2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2)]) - # test_frule(Chain, Quantum(tn), Open()) - test_rrule(Chain, Quantum(tn), Open()) + # TODO reenable periodic MPS + # tn = MPS([ones(2, 2, 2), ones(2, 2, 2), ones(2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Periodic()) + # test_rrule(Chain, Quantum(tn), Periodic()) - tn = Chain(Operator(), Periodic(), [ones(2, 2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2, 2)]) - # test_frule(Chain, Quantum(tn), Periodic()) - test_rrule(Chain, Quantum(tn), Periodic()) - end + tn = MPO([ones(2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2)]) + # test_frule(MPO, Ansatz(tn), form(tn)) + test_rrule(MPO, Ansatz(tn), form(tn)) + + # TODO reenable periodic MPO + # tn = Chain(Operator(), Periodic(), [ones(2, 2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Periodic()) + # test_rrule(Chain, Quantum(tn), Periodic()) end end From 7520a2e39529bb405826168d3d99857dac611ab8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 15:47:17 -0400 Subject: [PATCH 26/75] Refactor `ProjectTo` for `Ansatz` --- ext/TenetChainRulesCoreExt/projectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TenetChainRulesCoreExt/projectors.jl b/ext/TenetChainRulesCoreExt/projectors.jl index f974f90cc..299a3aaab 100644 --- a/ext/TenetChainRulesCoreExt/projectors.jl +++ b/ext/TenetChainRulesCoreExt/projectors.jl @@ -36,5 +36,5 @@ end ChainRulesCore.ProjectTo(x::Quantum) = ProjectTo{Quantum}(; tn=ProjectTo(TensorNetwork(x)), sites=x.sites) (projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn(Δ), projector.sites) -ChainRulesCore.ProjectTo(x::Ansatz) = ProjectTo{Ansatz}(; tn=ProjectTo(Quantum(x))) #, lattice=x.lattice) +ChainRulesCore.ProjectTo(x::Ansatz) = ProjectTo{Ansatz}(; tn=ProjectTo(Quantum(x)), lattice=x.lattice) (projector::ProjectTo{Ansatz})(Δ) = Ansatz(projector.tn(Δ), Δ.lattice) From 64b34b95c9b1a6946c6331202b2c3f69c41ff23a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 15:47:47 -0400 Subject: [PATCH 27/75] Refactor `rand` for `MPS`, `MPO` --- src/Ansatz/MPO.jl | 9 +++++---- src/Ansatz/MPS.jl | 5 +++-- src/TensorNetwork.jl | 4 ++++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl index b98f2d534..44d704618 100644 --- a/src/Ansatz/MPO.jl +++ b/src/Ansatz/MPO.jl @@ -81,9 +81,10 @@ Base.adjoint(tn::MPO) = MPO(adjoint(Ansatz(tn)), form(tn)) # TODO different input/output physical dims # TODO let choose the orthogonality center -function Base.rand(rng::Random.AbstractRNG, ::Type{MPO}, n, χ; eltype=Float64, physical_dim=2) +function Base.rand(rng::Random.AbstractRNG, ::Type{MPO}; n, maxdim, eltype=Float64, physdim=2) T = eltype - ip = op = physical_dim + ip = op = physdim + χ = maxdim arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 @@ -100,8 +101,8 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{MPO}, n, χ; eltype=Float64, end # reshape boundary sites - arrays[1] = reshape(arrays[1], p, p, min(χ, ip * op)) - arrays[n] = reshape(arrays[n], min(χ, ip * op), p, p) + arrays[1] = reshape(arrays[1], ip, op, min(χ, ip * op)) + arrays[n] = reshape(arrays[n], min(χ, ip * op), ip, op) # TODO order might not be the best for performance return MPO(arrays; order=(:l, :i, :o, :r)) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index c6262c3a6..6b83faf16 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -79,9 +79,10 @@ Base.adjoint(tn::MPS) = MPS(adjoint(Ansatz(tn)), form(tn)) # TODO different input/output physical dims # TODO let choose the orthogonality center -function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}, n, χ; eltype=Float64, physical_dim=2) - p = physical_dim +function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float64, physdim=2) + p = physdim T = eltype + χ = maxdim arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 8778de50b..826be2a34 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -727,6 +727,10 @@ function Base.rand(::Type{TensorNetwork}, n::Integer, regularity::Integer; kwarg return rand(Random.default_rng(), TensorNetwork, n, regularity; kwargs...) end +function Base.rand(::Type{T}, args...; kwargs...) where {T<:AbstractTensorNetwork} + return rand(Random.default_rng(), T, args...; kwargs...) +end + function Serialization.serialize(s::AbstractSerializer, obj::TensorNetwork) Serialization.writetag(s.io, Serialization.OBJECT_TAG) return serialize(s, tensors(obj)) From 0c79f0bb138d9e76b656ad715f4e068998168f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 15:48:03 -0400 Subject: [PATCH 28/75] Export `canonize_site`, `canonize_site!` methods --- src/Tenet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Tenet.jl b/src/Tenet.jl index 79d13304b..c9a7797d7 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -46,7 +46,7 @@ export PEPS include("Ansatz/PEPO.jl") export PEPO -export canonize, canonize!, mixed_canonize, mixed_canonize!, truncate! +export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed_canonize!, truncate! export evolve!, expect, overlap # reexports from EinExprs From 076ee5b15fa36202d21a54a640d4b303e7004a1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 15:49:30 -0400 Subject: [PATCH 29/75] Refactor `Chain` tests on top of `MPS`, `MPO` --- test/Chain_test.jl | 436 --------------------------------------------- test/MPO_test.jl | 72 ++++++++ test/MPS_test.jl | 258 +++++++++++++++++++++++++++ test/runtests.jl | 3 +- 4 files changed, 332 insertions(+), 437 deletions(-) delete mode 100644 test/Chain_test.jl create mode 100644 test/MPO_test.jl create mode 100644 test/MPS_test.jl diff --git a/test/Chain_test.jl b/test/Chain_test.jl deleted file mode 100644 index d2b946896..000000000 --- a/test/Chain_test.jl +++ /dev/null @@ -1,436 +0,0 @@ -@testset "Chain ansatz" begin - @testset "Periodic boundary" begin - @testset "State" begin - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - arrays = [rand(2, 1, 4), rand(2, 4, 3), rand(2, 3, 1)] - qtn = Chain(State(), Periodic(), arrays) # Default order (:o, :l, :r) - - @test size(tensors(qtn; at=Site(1))) == (2, 1, 4) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 3, 1) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l) - qtn = Chain(State(), Periodic(), arrays; order=[:r, :o, :l]) - - @test size(tensors(qtn; at=Site(1))) == (4, 2, 1) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 4) - @test size(tensors(qtn; at=Site(3))) == (1, 2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - for i in 1:nsites(qtn) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - end - end - - @testset "Operator" begin - qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == 3 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - arrays = [rand(2, 4, 1, 3), rand(2, 4, 3, 6), rand(2, 4, 6, 1)] # Default order (:o, :i, :l, :r) - qtn = Chain(Operator(), Periodic(), arrays) - - @test size(tensors(qtn; at=Site(1))) == (2, 4, 1, 3) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 3, 6) - @test size(tensors(qtn; at=Site(3))) == (2, 4, 6, 1) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - - arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i) - qtn = Chain(Operator(), Periodic(), arrays; order=[:r, :o, :l, :i]) - - @test size(tensors(qtn; at=Site(1))) == (3, 2, 1, 4) - @test size(tensors(qtn; at=Site(2))) == (6, 2, 3, 4) - @test size(tensors(qtn; at=Site(3))) == (1, 2, 6, 4) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) !== nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - end - end - - @testset "Open boundary" begin - @testset "State" begin - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] - qtn = Chain(State(), Open(), arrays) # Default order (:o, :l, :r) - - @test size(tensors(qtn; at=Site(1))) == (2, 1) - @test size(tensors(qtn; at=Site(2))) == (2, 1, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) - qtn = Chain(State(), Open(), arrays; order=[:r, :o, :l]) - - @test size(tensors(qtn; at=Site(1))) == (1, 2) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 1) - @test size(tensors(qtn; at=Site(3))) == (2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:nsites(qtn) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - end - end - @testset "Operator" begin - qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == 3 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) - qtn = Chain(Operator(), Open(), arrays) - - @test size(tensors(qtn; at=Site(1))) == (2, 4, 1) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 1, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 4, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - - arrays = [ - permutedims(arrays[1], (3, 1, 2)), - permutedims(arrays[2], (4, 1, 3, 2)), - permutedims(arrays[3], (1, 3, 2)), - ] # now we have (:r, :o, :l, :i) - qtn = Chain(Operator(), Open(), arrays; order=[:r, :o, :l, :i]) - - @test size(tensors(qtn; at=Site(1))) == (1, 2, 4) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 1, 4) - @test size(tensors(qtn; at=Site(3))) == (2, 3, 4) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - end - end - - @testset "Site" begin - using Tenet: leftsite, rightsite - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - - @test leftsite(qtn, Site(1)) == Site(3) - @test leftsite(qtn, Site(2)) == Site(1) - @test leftsite(qtn, Site(3)) == Site(2) - - @test rightsite(qtn, Site(1)) == Site(2) - @test rightsite(qtn, Site(2)) == Site(3) - @test rightsite(qtn, Site(3)) == Site(1) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - - @test isnothing(leftsite(qtn, Site(1))) - @test isnothing(rightsite(qtn, Site(3))) - - @test leftsite(qtn, Site(2)) == Site(1) - @test leftsite(qtn, Site(3)) == Site(2) - - @test rightsite(qtn, Site(2)) == Site(3) - @test rightsite(qtn, Site(1)) == Site(2) - end - - @testset "truncate" begin - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - canonize_site!(qtn, Site(2); direction=:right, method=:svd) - - @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(qtn, [Site(1), Site(2)]; maxdim=1) - # @test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)]) - - truncated = Tenet.truncate(qtn, [Site(2), Site(3)]; maxdim=1) - @test size(truncated, rightindex(truncated, Site(2))) == 1 - @test size(truncated, leftindex(truncated, Site(3))) == 1 - - singular_values = tensors(qtn; between=(Site(2), Site(3))) - truncated = Tenet.truncate(qtn, [Site(2), Site(3)]; threshold=singular_values[2] + 0.1) - @test size(truncated, rightindex(truncated, Site(2))) == 1 - @test size(truncated, leftindex(truncated, Site(3))) == 1 - end - - @testset "rand" begin - using LinearAlgebra: norm - - @testset "State" begin - n = 8 - χ = 10 - - qtn = rand(Chain, Open, State; n, p=2, χ) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == n - @test issetequal(sites(qtn), map(Site, 1:n)) - @test boundary(qtn) == Open() - @test isapprox(norm(qtn), 1.0) - @test maximum(last, size(qtn)) <= χ - end - - @testset "Operator" begin - n = 8 - χ = 10 - - qtn = rand(Chain, Open, Operator; n, p=2, χ) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == n - @test nsites(qtn; set=:outputs) == n - @test issetequal(sites(qtn), vcat(map(Site, 1:n), map(adjoint ∘ Site, 1:n))) - @test boundary(qtn) == Open() - @test isapprox(norm(qtn), 1.0) - @test maximum(last, size(qtn)) <= χ - end - end - - @testset "Canonization" begin - using Tenet - - @testset "contract" begin - qtn = rand(Chain, Open, State; n=5, p=2, χ=20) - let canonized = canonize(qtn) - @test_throws ArgumentError contract!(canonized; between=(Site(1), Site(2)), direction=:dummy) - end - - canonized = canonize(qtn) - - for i in 1:4 - contract_some = contract(canonized; between=(Site(i), Site(i + 1))) - Bᵢ = tensors(contract_some; at=Site(i)) - - @test isapprox(contract(contract_some), contract(qtn)) - @test_throws ArgumentError tensors(contract_some; between=(Site(i), Site(i + 1))) - - @test isrightcanonical(contract_some, Site(i)) - @test isleftcanonical( - contract(canonized; between=(Site(i), Site(i + 1)), direction=:right), Site(i + 1) - ) - - Γᵢ = tensors(canonized; at=Site(i)) - Λᵢ₊₁ = tensors(canonized; between=(Site(i), Site(i + 1))) - @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims=()) - end - end - - @testset "canonize_site" begin - qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)]) - - @test_throws ArgumentError canonize_site!(qtn, Site(1); direction=:left) - @test_throws ArgumentError canonize_site!(qtn, Site(3); direction=:right) - - for method in [:qr, :svd] - canonized = canonize_site(qtn, site"1"; direction=:right, method=method) - @test isleftcanonical(canonized, site"1") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"2"; direction=:right, method=method) - @test isleftcanonical(canonized, site"2") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"2"; direction=:left, method=method) - @test isrightcanonical(canonized, site"2") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"3"; direction=:left, method=method) - @test isrightcanonical(canonized, site"3") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - end - - # Ensure that svd creates a new tensor - @test length(tensors(canonize_site(qtn, Site(2); direction=:left, method=:svd))) == 4 - end - - @testset "canonize" begin - using Tenet: isleftcanonical, isrightcanonical - - qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = canonize(qtn) - - @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - @test isapprox(norm(qtn), norm(canonized)) - - # Extract the singular values between each adjacent pair of sites in the canonized chain - Λ = [tensors(canonized; between=(Site(i), Site(i + 1))) for i in 1:4] - @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 - - for i in 1:5 - canonized = canonize(qtn) - - if i == 1 - @test isleftcanonical(canonized, Site(i)) - elseif i == 5 # in the limits of the chain, we get the norm of the state - contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) - tensor = tensors(canonized; at=Site(i)) - replace!(canonized, tensor => tensor / norm(canonized)) - @test isleftcanonical(canonized, Site(i)) - else - contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) - @test isleftcanonical(canonized, Site(i)) - end - end - - for i in 1:5 - canonized = canonize(qtn) - - if i == 1 # in the limits of the chain, we get the norm of the state - contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) - tensor = tensors(canonized; at=Site(i)) - replace!(canonized, tensor => tensor / norm(canonized)) - @test isrightcanonical(canonized, Site(i)) - elseif i == 5 - @test isrightcanonical(canonized, Site(i)) - else - contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) - @test isrightcanonical(canonized, Site(i)) - end - end - end - - @testset "mixed_canonize" begin - qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = mixed_canonize(qtn, Site(3)) - - @test length(tensors(canonized)) == length(tensors(qtn)) + 1 - - @test isleftcanonical(canonized, Site(1)) - @test isleftcanonical(canonized, Site(2)) - @test isrightcanonical(canonized, Site(3)) - @test isrightcanonical(canonized, Site(4)) - @test isrightcanonical(canonized, Site(5)) - - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - end - end - - @test begin - qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - normalize!(qtn, Site(3)) - isapprox(norm(qtn), 1.0) - end - - @testset "adjoint" begin - qtn = rand(Chain, Open, State; n=5, p=2, χ=10) - adjoint_qtn = adjoint(qtn) - - for i in 1:nsites(qtn) - i < nsites(qtn) && - @test rightindex(adjoint_qtn, Site(i; dual=true)) == Symbol(String(rightindex(qtn, Site(i))) * "'") - i > 1 && @test leftindex(adjoint_qtn, Site(i; dual=true)) == Symbol(String(leftindex(qtn, Site(i))) * "'") - end - - @test isapprox(contract(qtn), contract(adjoint_qtn)) - end - - @testset "evolve!" begin - @testset "one site" begin - i = 2 - mat = reshape(LinearAlgebra.I(2), 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(i; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @testset "canonical form" begin - canonized = canonize(qtn) - - evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) - @test isapprox(contract(evolved), contract(canonized)) - @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - - @testset "arbitrary chain" begin - evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false) - @test length(tensors(evolved)) == 5 - @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - end - - @testset "two sites" begin - i, j = 2, 3 - mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @testset "canonical form" begin - canonized = canonize(qtn) - - evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) - @test isapprox(contract(evolved), contract(canonized)) - @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - - @testset "arbitrary chain" begin - evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false) - @test length(tensors(evolved)) == 5 - @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - end - end - - @testset "expect" begin - i, j = 2, 3 - mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @test isapprox(expect(qtn, [gate]), norm(qtn)^2) - end -end diff --git a/test/MPO_test.jl b/test/MPO_test.jl new file mode 100644 index 000000000..1d461597b --- /dev/null +++ b/test/MPO_test.jl @@ -0,0 +1,72 @@ +@testset "MPO" begin + H = MPO([rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) + @test socket(H) == Operator() + @test nsites(H; set=:inputs) == 3 + @test nsites(H; set=:outputs) == 3 + @test issetequal(sites(H), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(H) == Open() + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) == nothing + + arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) + H = MPO(arrays) + + @test size(tensors(H; at=Site(1))) == (2, 4, 1) + @test size(tensors(H; at=Site(2))) == (2, 4, 1, 3) + @test size(tensors(H; at=Site(3))) == (2, 4, 3) + + @test inds(H; at=Site(1), dir=:left) == inds(H; at=Site(3), dir=:right) === nothing + @test inds(H; at=Site(2), dir=:left) == inds(H; at=Site(1), dir=:right) !== nothing + @test inds(H; at=Site(3), dir=:left) == inds(H; at=Site(2), dir=:right) !== nothing + + for i in 1:length(arrays) + @test size(H, inds(H; at=Site(i))) == 2 + @test size(H, inds(H; at=Site(i; dual=true))) == 4 + end + + arrays = [ + permutedims(arrays[1], (3, 1, 2)), permutedims(arrays[2], (4, 1, 3, 2)), permutedims(arrays[3], (1, 3, 2)) + ] # now we have (:r, :o, :l, :i) + H = MPO(arrays; order=[:r, :o, :l, :i]) + + @test size(tensors(H; at=Site(1))) == (1, 2, 4) + @test size(tensors(H; at=Site(2))) == (3, 2, 1, 4) + @test size(tensors(H; at=Site(3))) == (2, 3, 4) + + @test inds(H; at=Site(1), dir=:left) == inds(H; at=Site(3), dir=:right) === nothing + @test inds(H; at=Site(2), dir=:left) == inds(H; at=Site(1), dir=:right) !== nothing + @test inds(H; at=Site(3), dir=:left) == inds(H; at=Site(2), dir=:right) !== nothing + + for i in 1:length(arrays) + @test size(H, inds(H; at=Site(i))) == 2 + @test size(H, inds(H; at=Site(i; dual=true))) == 4 + end + + @testset "Site" begin + H = MPO([rand(2, 2, 2), rand(2, 2, 2, 2), rand(2, 2, 2)]) + + @test isnothing(sites(H, Site(1); dir=:left)) + @test isnothing(sites(H, Site(3); dir=:right)) + + @test sites(H, Site(2); dir=:left) == Site(1) + @test sites(H, Site(3); dir=:left) == Site(2) + + @test sites(H, Site(2); dir=:right) == Site(3) + @test sites(H, Site(1); dir=:right) == Site(2) + end + + @testset "norm" begin + using LinearAlgebra: norm + + n = 8 + χ = 10 + H = rand(MPO; n, maxdim=χ) + + @test socket(H) == Operator() + @test nsites(H; set=:inputs) == n + @test nsites(H; set=:outputs) == n + @test issetequal(sites(H), vcat(map(Site, 1:n), map(adjoint ∘ Site, 1:n))) + @test boundary(H) == Open() + @test isapprox(norm(H), 1.0) + @test maximum(last, size(H)) <= χ + end +end diff --git a/test/MPS_test.jl b/test/MPS_test.jl new file mode 100644 index 000000000..7428af781 --- /dev/null +++ b/test/MPS_test.jl @@ -0,0 +1,258 @@ +@testset "MPS" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + @test socket(ψ) == State() + @test nsites(ψ; set=:inputs) == 0 + @test nsites(ψ; set=:outputs) == 3 + @test issetequal(sites(ψ), [site"1", site"2", site"3"]) + @test boundary(ψ) == Open() + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) == nothing + + arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] + ψ = MPS(arrays) # Default order (:o, :l, :r) + @test size(tensors(ψ; at=Site(1))) == (2, 1) + @test size(tensors(ψ; at=Site(2))) == (2, 1, 3) + @test size(tensors(ψ; at=Site(3))) == (2, 3) + @test inds(ψ; at=Site(1), dir=:left) == inds(ψ; at=Site(3), dir=:right) === nothing + @test inds(ψ; at=Site(2), dir=:left) == inds(ψ; at=Site(1), dir=:right) + @test inds(ψ; at=Site(3), dir=:left) == inds(ψ; at=Site(2), dir=:right) + + arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) + ψ = MPS(arrays; order=[:r, :o, :l]) + @test size(tensors(ψ; at=Site(1))) == (1, 2) + @test size(tensors(ψ; at=Site(2))) == (3, 2, 1) + @test size(tensors(ψ; at=Site(3))) == (2, 3) + @test inds(ψ; at=Site(1), dir=:left) == inds(ψ; at=Site(3), dir=:right) === nothing + @test inds(ψ; at=Site(2), dir=:left) == inds(ψ; at=Site(1), dir=:right) !== nothing + @test inds(ψ; at=Site(3), dir=:left) == inds(ψ; at=Site(2), dir=:right) !== nothing + @test all(i -> size(ψ, inds(ψ; at=Site(i))) == 2, 1:nsites(ψ)) + + @testset "Site" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test isnothing(sites(ψ, Site(1); dir=:left)) + @test isnothing(sites(ψ, Site(3); dir=:right)) + + @test sites(ψ, Site(2); dir=:left) == Site(1) + @test sites(ψ, Site(3); dir=:left) == Site(2) + + @test sites(ψ, Site(2); dir=:right) == Site(3) + @test sites(ψ, Site(1); dir=:right) == Site(2) + end + + @testset "adjoint" begin + ψ = rand(MPS; n=3, maxdim=2, eltype=ComplexF64) + @test socket(ψ') == State(; dual=true) + @test isapprox(contract(ψ), conj(contract(ψ'))) + end + + @testset "truncate" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + canonize_site!(ψ, Site(2); direction=:right, method=:svd) + + @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(ψ, [Site(1), Site(2)]; maxdim=1) + # @test_throws ArgumentError truncate!(ψ, [Site(2), Site(3)]) + + truncated = truncate(ψ, [Site(2), Site(3)]; maxdim=1) + @test size(truncated, inds(truncated; at=Site(2), dir=:right)) == 1 + @test size(truncated, inds(truncated; at=Site(3), dir=:left)) == 1 + + singular_values = tensors(ψ; between=(Site(2), Site(3))) + truncated = truncate(ψ, [Site(2), Site(3)]; threshold=singular_values[2] + 0.1) + @test size(truncated, inds(truncated; at=Site(2), dir=:right)) == 1 + @test size(truncated, inds(truncated; at=Site(3), dir=:left)) == 1 + end + + @testset "norm" begin + using LinearAlgebra: norm + + n = 8 + χ = 10 + ψ = rand(MPS; n, maxdim=χ) + + @test socket(ψ) == State() + @test nsites(ψ; set=:inputs) == 0 + @test nsites(ψ; set=:outputs) == n + @test issetequal(sites(ψ), map(Site, 1:n)) + @test boundary(ψ) == Open() + @test isapprox(norm(ψ), 1.0) + @test maximum(last, size(ψ)) <= χ + end + + @testset "normalize!" begin + using LinearAlgebra: normalize! + + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + normalize!(ψ, Site(3)) + @test isapprox(norm(ψ), 1.0) + end + + @testset "canonize_site!" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4)]) + + @test_throws ArgumentError canonize_site!(ψ, Site(1); direction=:left) + @test_throws ArgumentError canonize_site!(ψ, Site(3); direction=:right) + + for method in [:qr, :svd] + canonized = canonize_site(ψ, site"1"; direction=:right, method=method) + @test isleftcanonical(canonized, site"1") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"2"; direction=:right, method=method) + @test isleftcanonical(canonized, site"2") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"2"; direction=:left, method=method) + @test isrightcanonical(canonized, site"2") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"3"; direction=:left, method=method) + @test isrightcanonical(canonized, site"3") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + end + + # Ensure that svd creates a new tensor + @test length(tensors(canonize_site(ψ, Site(2); direction=:left, method=:svd))) == 4 + end + + @testset "canonize!" begin + using Tenet: isleftcanonical, isrightcanonical + + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = canonize(ψ) + + @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + @test isapprox(norm(ψ), norm(canonized)) + + # Extract the singular values between each adjacent pair of sites in the canonized chain + Λ = [tensors(canonized; between=(Site(i), Site(i + 1))) for i in 1:4] + @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 + + for i in 1:5 + canonized = canonize(ψ) + + if i == 1 + @test isleftcanonical(canonized, Site(i)) + elseif i == 5 # in the limits of the chain, we get the norm of the state + contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) + tensor = tensors(canonized; at=Site(i)) + replace!(canonized, tensor => tensor / norm(canonized)) + @test isleftcanonical(canonized, Site(i)) + else + contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) + @test isleftcanonical(canonized, Site(i)) + end + end + + for i in 1:5 + canonized = canonize(ψ) + + if i == 1 # in the limits of the chain, we get the norm of the state + contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) + tensor = tensors(canonized; at=Site(i)) + replace!(canonized, tensor => tensor / norm(canonized)) + @test isrightcanonical(canonized, Site(i)) + elseif i == 5 + @test isrightcanonical(canonized, Site(i)) + else + contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) + @test isrightcanonical(canonized, Site(i)) + end + end + end + + @testset "mixed_canonize!" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = mixed_canonize(ψ, Site(3)) + + @test length(tensors(canonized)) == length(tensors(ψ)) + 1 + + @test isleftcanonical(canonized, Site(1)) + @test isleftcanonical(canonized, Site(2)) + @test isrightcanonical(canonized, Site(3)) + @test isrightcanonical(canonized, Site(4)) + @test isrightcanonical(canonized, Site(5)) + + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + end + + @testset "expect" begin + i, j = 2, 3 + mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test isapprox(expect(ψ, [gate]), norm(ψ)^2) + end + + @testset "evolve!" begin + @testset "one site" begin + i = 2 + mat = reshape(LinearAlgebra.I(2), 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(i; dual=true)]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @testset "canonical form" begin + canonized = canonize(ψ) + evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) + @test isapprox(contract(evolved), contract(canonized)) + @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) + @test isapprox(contract(evolved), contract(ψ)) + end + + @testset "arbitrary chain" begin + evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14, iscanonical=false) + @test length(tensors(evolved)) == 5 + @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)]) + @test isapprox(contract(evolved), contract(ψ)) + end + end + + @testset "two sites" begin + i, j = 2, 3 + mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) + gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @testset "canonical form" begin + canonized = canonize(ψ) + evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) + @test isapprox(contract(evolved), contract(canonized)) + @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) + @test isapprox(contract(evolved), contract(ψ)) + end + + @testset "arbitrary chain" begin + evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14, iscanonical=false) + @test length(tensors(evolved)) == 5 + @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)]) + @test isapprox(contract(evolved), contract(ψ)) + end + end + end + + # TODO rename when method is renamed + @testset "contract between" begin + ψ = rand(MPS; n=5, maxdim=20) + let canonized = canonize(ψ) + @test_throws ArgumentError contract!(canonized; between=(Site(1), Site(2)), direction=:dummy) + end + + canonized = canonize(ψ) + + for i in 1:4 + contract_some = contract(canonized; between=(Site(i), Site(i + 1))) + Bᵢ = tensors(contract_some; at=Site(i)) + + @test isapprox(contract(contract_some), contract(ψ)) + @test_throws ArgumentError tensors(contract_some; between=(Site(i), Site(i + 1))) + + @test isrightcanonical(contract_some, Site(i)) + @test isleftcanonical(contract(canonized; between=(Site(i), Site(i + 1)), direction=:right), Site(i + 1)) + + Γᵢ = tensors(canonized; at=Site(i)) + Λᵢ₊₁ = tensors(canonized; between=(Site(i), Site(i + 1))) + @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims=()) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 6a9bc8e71..d0b57277a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,8 @@ using OMEinsum include("Site_test.jl") include("Quantum_test.jl") include("Product_test.jl") - include("Chain_test.jl") + include("MPS_test.jl") + include("MPO_test.jl") end # CI hangs on these tests for some unknown reason on Julia 1.9 From 5e823c33da267040fc2672770846c883e64e2126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 15:49:49 -0400 Subject: [PATCH 30/75] Add `Graphs`, `MetaGraphsNext` as test dependencies --- test/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index f8588675a..366561448 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,10 +8,12 @@ Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Permutations = "2ae35dd2-176d-5d53-8349-f30d82d94d4f" From 60ae718d6115a4f5627f1e20abe129e629a01ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 16:50:39 -0400 Subject: [PATCH 31/75] Try using more `@site_str` instead of `Site` in MPS tests --- test/MPS_test.jl | 68 ++++++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/test/MPS_test.jl b/test/MPS_test.jl index 7428af781..160436fab 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -9,34 +9,34 @@ arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] ψ = MPS(arrays) # Default order (:o, :l, :r) - @test size(tensors(ψ; at=Site(1))) == (2, 1) - @test size(tensors(ψ; at=Site(2))) == (2, 1, 3) - @test size(tensors(ψ; at=Site(3))) == (2, 3) - @test inds(ψ; at=Site(1), dir=:left) == inds(ψ; at=Site(3), dir=:right) === nothing - @test inds(ψ; at=Site(2), dir=:left) == inds(ψ; at=Site(1), dir=:right) - @test inds(ψ; at=Site(3), dir=:left) == inds(ψ; at=Site(2), dir=:right) + @test size(tensors(ψ; at=site"1")) == (2, 1) + @test size(tensors(ψ; at=site"2")) == (2, 1, 3) + @test size(tensors(ψ; at=site"3")) == (2, 3) + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) === nothing + @test inds(ψ; at=site"2", dir=:left) == inds(ψ; at=site"1", dir=:right) + @test inds(ψ; at=site"3", dir=:left) == inds(ψ; at=site"2", dir=:right) arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) ψ = MPS(arrays; order=[:r, :o, :l]) - @test size(tensors(ψ; at=Site(1))) == (1, 2) - @test size(tensors(ψ; at=Site(2))) == (3, 2, 1) - @test size(tensors(ψ; at=Site(3))) == (2, 3) - @test inds(ψ; at=Site(1), dir=:left) == inds(ψ; at=Site(3), dir=:right) === nothing - @test inds(ψ; at=Site(2), dir=:left) == inds(ψ; at=Site(1), dir=:right) !== nothing - @test inds(ψ; at=Site(3), dir=:left) == inds(ψ; at=Site(2), dir=:right) !== nothing + @test size(tensors(ψ; at=site"1")) == (1, 2) + @test size(tensors(ψ; at=site"2")) == (3, 2, 1) + @test size(tensors(ψ; at=site"3")) == (2, 3) + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) === nothing + @test inds(ψ; at=site"2", dir=:left) == inds(ψ; at=site"1", dir=:right) !== nothing + @test inds(ψ; at=site"3", dir=:left) == inds(ψ; at=site"2", dir=:right) !== nothing @test all(i -> size(ψ, inds(ψ; at=Site(i))) == 2, 1:nsites(ψ)) @testset "Site" begin ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - @test isnothing(sites(ψ, Site(1); dir=:left)) - @test isnothing(sites(ψ, Site(3); dir=:right)) + @test isnothing(sites(ψ, site"1"; dir=:left)) + @test isnothing(sites(ψ, site"3"; dir=:right)) - @test sites(ψ, Site(2); dir=:left) == Site(1) - @test sites(ψ, Site(3); dir=:left) == Site(2) + @test sites(ψ, site"2"; dir=:left) == site"1" + @test sites(ψ, site"3"; dir=:left) == site"2" - @test sites(ψ, Site(2); dir=:right) == Site(3) - @test sites(ψ, Site(1); dir=:right) == Site(2) + @test sites(ψ, site"2"; dir=:right) == site"3" + @test sites(ψ, site"1"; dir=:right) == site"2" end @testset "adjoint" begin @@ -49,17 +49,17 @@ ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) canonize_site!(ψ, Site(2); direction=:right, method=:svd) - @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(ψ, [Site(1), Site(2)]; maxdim=1) - # @test_throws ArgumentError truncate!(ψ, [Site(2), Site(3)]) + @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(ψ, [site"1", site"2"]; maxdim=1) + # @test_throws ArgumentError truncate!(ψ, [site"2", site"3"]) - truncated = truncate(ψ, [Site(2), Site(3)]; maxdim=1) - @test size(truncated, inds(truncated; at=Site(2), dir=:right)) == 1 - @test size(truncated, inds(truncated; at=Site(3), dir=:left)) == 1 + truncated = truncate(ψ, [site"2", site"3"]; maxdim=1) + @test size(truncated, inds(truncated; at=site"2", dir=:right)) == 1 + @test size(truncated, inds(truncated; at=site"3", dir=:left)) == 1 - singular_values = tensors(ψ; between=(Site(2), Site(3))) - truncated = truncate(ψ, [Site(2), Site(3)]; threshold=singular_values[2] + 0.1) - @test size(truncated, inds(truncated; at=Site(2), dir=:right)) == 1 - @test size(truncated, inds(truncated; at=Site(3), dir=:left)) == 1 + singular_values = tensors(ψ; between=(site"2", site"3")) + truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) + @test size(truncated, inds(truncated; at=site"2", dir=:right)) == 1 + @test size(truncated, inds(truncated; at=site"3", dir=:left)) == 1 end @testset "norm" begin @@ -163,15 +163,15 @@ @testset "mixed_canonize!" begin ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = mixed_canonize(ψ, Site(3)) + canonized = mixed_canonize(ψ, site"3") @test length(tensors(canonized)) == length(tensors(ψ)) + 1 - @test isleftcanonical(canonized, Site(1)) - @test isleftcanonical(canonized, Site(2)) - @test isrightcanonical(canonized, Site(3)) - @test isrightcanonical(canonized, Site(4)) - @test isrightcanonical(canonized, Site(5)) + @test isleftcanonical(canonized, site"1") + @test isleftcanonical(canonized, site"2") + @test isrightcanonical(canonized, site"3") + @test isrightcanonical(canonized, site"4") + @test isrightcanonical(canonized, site"5") @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) end @@ -235,7 +235,7 @@ @testset "contract between" begin ψ = rand(MPS; n=5, maxdim=20) let canonized = canonize(ψ) - @test_throws ArgumentError contract!(canonized; between=(Site(1), Site(2)), direction=:dummy) + @test_throws ArgumentError contract!(canonized; between=(site"1", site"2"), direction=:dummy) end canonized = canonize(ψ) From cca0cff310d68a122ccbad834838ab3f428e88f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 16:51:15 -0400 Subject: [PATCH 32/75] Implement some `sites`, `inds` methods for `MPO` --- src/Ansatz/MPO.jl | 24 ++++++++++++++++++++++++ src/Ansatz/MPS.jl | 1 + 2 files changed, 25 insertions(+) diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl index 44d704618..395a1d665 100644 --- a/src/Ansatz/MPO.jl +++ b/src/Ansatz/MPO.jl @@ -108,4 +108,28 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{MPO}; n, maxdim, eltype=Float return MPO(arrays; order=(:l, :i, :o, :r)) end +# TODO change it to `lanes`? +# TODO refactor common code with `MPS` +function sites(ψ::MPO, site::Site; dir) + if dir === :left + return site <= site"1" ? nothing : Site(id(site) - 1) + elseif dir === :right + return site >= Site(nlanes(ψ)) ? nothing : Site(id(site) + 1) + else + throw(ArgumentError("Unknown direction for MPO = :$dir")) + end +end + +@kwmethod function inds(ψ::MPO; at, dir) + if dir === :left && at == site"1" + return nothing + elseif dir === :right && at == Site(nlanes(ψ); dual=isdual(at)) + return nothing + elseif dir ∈ (:left, :right) + return inds(ψ; bond=(at, sites(ψ, at; dir))) + else + throw(ArgumentError("Unknown direction for MPO = :$dir")) + end +end + function evolve!(ψ::MPS, op::MPO; threshold=nothing, maxdim=nothing, renormalize=false) end diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 6b83faf16..8a9acc2c6 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -139,6 +139,7 @@ end @kwmethod contract(tn::MPS; between, direction) = contract(tn; between, direction, delete_Λ=true) @kwmethod contract!(tn::MPS; between, direction) = contract!(tn; between, direction, delete_Λ=true) +# TODO rename it to `lanes`? function sites(ψ::MPS, site::Site; dir) if dir === :left return site == site"1" ? nothing : Site(id(site) - 1) From 24800b860ceadd5ba1463c1f2dcc26cc950f7be8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 16:52:18 -0400 Subject: [PATCH 33/75] Try using more `@site_str` in MPO tests --- test/MPO_test.jl | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/test/MPO_test.jl b/test/MPO_test.jl index 1d461597b..f6e6fcc67 100644 --- a/test/MPO_test.jl +++ b/test/MPO_test.jl @@ -10,13 +10,13 @@ arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) H = MPO(arrays) - @test size(tensors(H; at=Site(1))) == (2, 4, 1) - @test size(tensors(H; at=Site(2))) == (2, 4, 1, 3) - @test size(tensors(H; at=Site(3))) == (2, 4, 3) + @test size(tensors(H; at=site"1")) == (2, 4, 1) + @test size(tensors(H; at=site"2")) == (2, 4, 1, 3) + @test size(tensors(H; at=site"3")) == (2, 4, 3) - @test inds(H; at=Site(1), dir=:left) == inds(H; at=Site(3), dir=:right) === nothing - @test inds(H; at=Site(2), dir=:left) == inds(H; at=Site(1), dir=:right) !== nothing - @test inds(H; at=Site(3), dir=:left) == inds(H; at=Site(2), dir=:right) !== nothing + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) === nothing + @test inds(H; at=site"2", dir=:left) == inds(H; at=site"1", dir=:right) !== nothing + @test inds(H; at=site"3", dir=:left) == inds(H; at=site"2", dir=:right) !== nothing for i in 1:length(arrays) @test size(H, inds(H; at=Site(i))) == 2 @@ -28,13 +28,13 @@ ] # now we have (:r, :o, :l, :i) H = MPO(arrays; order=[:r, :o, :l, :i]) - @test size(tensors(H; at=Site(1))) == (1, 2, 4) - @test size(tensors(H; at=Site(2))) == (3, 2, 1, 4) - @test size(tensors(H; at=Site(3))) == (2, 3, 4) + @test size(tensors(H; at=site"1")) == (1, 2, 4) + @test size(tensors(H; at=site"2")) == (3, 2, 1, 4) + @test size(tensors(H; at=site"3")) == (2, 3, 4) - @test inds(H; at=Site(1), dir=:left) == inds(H; at=Site(3), dir=:right) === nothing - @test inds(H; at=Site(2), dir=:left) == inds(H; at=Site(1), dir=:right) !== nothing - @test inds(H; at=Site(3), dir=:left) == inds(H; at=Site(2), dir=:right) !== nothing + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) === nothing + @test inds(H; at=site"2", dir=:left) == inds(H; at=site"1", dir=:right) !== nothing + @test inds(H; at=site"3", dir=:left) == inds(H; at=site"2", dir=:right) !== nothing for i in 1:length(arrays) @test size(H, inds(H; at=Site(i))) == 2 @@ -44,14 +44,14 @@ @testset "Site" begin H = MPO([rand(2, 2, 2), rand(2, 2, 2, 2), rand(2, 2, 2)]) - @test isnothing(sites(H, Site(1); dir=:left)) - @test isnothing(sites(H, Site(3); dir=:right)) + @test isnothing(sites(H, site"1"; dir=:left)) + @test isnothing(sites(H, site"3"; dir=:right)) - @test sites(H, Site(2); dir=:left) == Site(1) - @test sites(H, Site(3); dir=:left) == Site(2) + @test sites(H, site"2"; dir=:left) == site"1" + @test sites(H, site"3"; dir=:left) == site"2" - @test sites(H, Site(2); dir=:right) == Site(3) - @test sites(H, Site(1); dir=:right) == Site(2) + @test sites(H, site"2"; dir=:right) == site"3" + @test sites(H, site"1"; dir=:right) == site"2" end @testset "norm" begin From 6f8a0a069f307d478a5ad6bda361497094edf864 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:18:37 -0400 Subject: [PATCH 34/75] Fix `tensors(; bond)` --- src/Ansatz/Ansatz.jl | 10 +++++++++- src/TensorNetwork.jl | 3 ++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index ef882ddef..5fb59fba7 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -81,7 +81,15 @@ end return only(inds(tensor1) ∩ inds(tensor2)) end -@kwmethod tensors(tn::AbstractAnsatz; bond) = tn[inds(tn; bond)] +@kwmethod function tensors(tn::AbstractAnsatz; bond) + vind = inds(tn; bond) + return only( + tensors(tn, [vind]) do vinds, indices + indices == vinds + end, + ) +end + @kwmethod function tensors(tn::AbstractAnsatz; between) Base.depwarn( "`tensors(tn; between)` is deprecated, use `tensors(tn; bond)` instead.", diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 826be2a34..78f9920f2 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -182,7 +182,8 @@ end return tensors(!isdisjoint, TensorNetwork(tn), intersects) end -function tensors(selector, tn::TensorNetwork, is::AbstractVecOrTuple{Symbol}) +function tensors(selector, tn::AbstractTensorNetwork, is::AbstractVecOrTuple{Symbol}) + tn = TensorNetwork(tn) return filter(Base.Fix1(selector, is) ∘ inds, tn.indexmap[first(is)]) end From 955361c7e8cf87b597782576dd49f0d8b39bef2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:25:16 -0400 Subject: [PATCH 35/75] Fix typo in `mixed_canonize!` --- src/Ansatz/MPS.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 8a9acc2c6..cd1e980f7 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -285,17 +285,17 @@ and in the `orthog_center` there is a matrix with singular values. """ function mixed_canonize!(tn::MPS, orthog_center) # left-to-right QR sweep (left-canonical tensors) - for i in 1:(id(center) - 1) + for i in 1:(id(orthog_center) - 1) canonize_site!(tn, Site(i); direction=:right, method=:qr) end # right-to-left QR sweep (right-canonical tensors) - for i in nsites(tn):-1:(id(center) + 1) + for i in nsites(tn):-1:(id(orthog_center) + 1) canonize_site!(tn, Site(i); direction=:left, method=:qr) end # center SVD sweep to get singular values - canonize_site!(tn, center; direction=:left, method=:svd) + canonize_site!(tn, orthog_center; direction=:left, method=:svd) return tn end From 531002f1642679fb4a2f7b9306b4c97d904374cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:27:00 -0400 Subject: [PATCH 36/75] Export `isleftcanonical`, `isrightcanonical` --- src/Tenet.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Tenet.jl b/src/Tenet.jl index c9a7797d7..9e6338982 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -46,8 +46,10 @@ export PEPS include("Ansatz/PEPO.jl") export PEPO -export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed_canonize!, truncate! -export evolve!, expect, overlap +# `truncate` not exported because it clashes with `Base.truncate` +export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed_canonize! +export isleftcanonical, isrightcanonical +export evolve!, expect, overlap, truncate! # reexports from EinExprs export einexpr, inds From efc1c11b230bb2d969928d88d63b7a4bfd9f8204 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:27:40 -0400 Subject: [PATCH 37/75] Fix `truncate!` --- src/Ansatz/Ansatz.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 5fb59fba7..bd09222c2 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -127,10 +127,16 @@ function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing) vind = inds(tn; bond) maxdim = isnothing(maxdim) ? size(tn, vind) : maxdim - threshold = isnothing(threshold) ? 1e-16 : threshold - extent = findfirst(1:maxdim) do i - abs(spectrum[i]) < threshold + extent = if isnothing(threshold) + 1:maxdim + else + 1:something( + findfirst(1:maxdim) do i + abs(spectrum[i]) < threshold + end, + maxdim, + ) end slice!(tn, vind, extent) From 08b3d454ceb385f2424efd4c0c3c706503026516 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:28:01 -0400 Subject: [PATCH 38/75] Fix `truncate` tests on `MPS` --- test/MPS_test.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/MPS_test.jl b/test/MPS_test.jl index 160436fab..ba22dee3d 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -49,15 +49,15 @@ ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) canonize_site!(ψ, Site(2); direction=:right, method=:svd) - @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(ψ, [site"1", site"2"]; maxdim=1) - # @test_throws ArgumentError truncate!(ψ, [site"2", site"3"]) + # @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(ψ, [site"1", site"2"]; maxdim=1) + @test_throws ArgumentError truncate!(ψ, [site"1", site"2"]; maxdim=1) - truncated = truncate(ψ, [site"2", site"3"]; maxdim=1) + truncated = Tenet.truncate(ψ, [site"2", site"3"]; maxdim=1) @test size(truncated, inds(truncated; at=site"2", dir=:right)) == 1 @test size(truncated, inds(truncated; at=site"3", dir=:left)) == 1 singular_values = tensors(ψ; between=(site"2", site"3")) - truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) + truncated = Tenet.truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) @test size(truncated, inds(truncated; at=site"2", dir=:right)) == 1 @test size(truncated, inds(truncated; at=site"3", dir=:left)) == 1 end From c986ae2801bb4de548e05be852724f16c7ba5c02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:37:15 -0400 Subject: [PATCH 39/75] Fix `Dense` constructors --- src/Ansatz/Dense.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Ansatz/Dense.jl b/src/Ansatz/Dense.jl index 0d8176d48..936794995 100644 --- a/src/Ansatz/Dense.jl +++ b/src/Ansatz/Dense.jl @@ -28,9 +28,9 @@ function Dense(::State, array::AbstractArray; sites=Site.(1:ndims(array))) tn = TensorNetwork([tensor]) qtn = Quantum(tn, sitemap) lattice = MetaGraph( - complete_graph(n), - Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], - Pair{Tuple{Site,Site},Nothing}[(Site(i), Site(j)) => nothing for (i, j) in combinations(1:n, 2)], + complete_graph(nlanes(qtn)), + Pair{Site,Nothing}[i => nothing for i in lanes(qtn)], + Pair{Tuple{Site,Site},Nothing}[(i, j) => nothing for (i, j) in combinations(lanes(qtn), 2)], ) ansatz = Ansatz(qtn, lattice) return Dense(ansatz) @@ -50,9 +50,9 @@ function Dense(::Operator, array::AbstractArray; sites) sitemap = Dict{Site,Symbol}(map(splat(Pair), zip(sites, tensor_inds))) qtn = Quantum(tn, sitemap) lattice = MetaGraph( - complete_graph(n), - Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], - Pair{Tuple{Site,Site},Nothing}[(Site(i), Site(j)) => nothing for (i, j) in combinations(1:n, 2)], + complete_graph(nlanes(qtn)), + Pair{Site,Nothing}[i => nothing for i in lanes(qtn)], + Pair{Tuple{Site,Site},Nothing}[(i, j) => nothing for (i, j) in combinations(lanes(qtn), 2)], ) ansatz = Ansatz(qtn, lattice) return Dense(ansatz) From f03bb1827e6efb30b28e7f96d4bf4bbd79b533ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:37:35 -0400 Subject: [PATCH 40/75] Fix `truncate!` extension when using `threshold` --- src/Ansatz/Ansatz.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index bd09222c2..22b2344f1 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -131,12 +131,9 @@ function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing) extent = if isnothing(threshold) 1:maxdim else - 1:something( - findfirst(1:maxdim) do i + 1:something(findfirst(1:maxdim) do i abs(spectrum[i]) < threshold - end, - maxdim, - ) + end - 1, maxdim) end slice!(tn, vind, extent) From 7db568d10914eabe890110833c61bcc3430c9235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:38:06 -0400 Subject: [PATCH 41/75] Refactor some tests of `MPS` to simplify --- test/MPS_test.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/MPS_test.jl b/test/MPS_test.jl index ba22dee3d..89b91f7f7 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -53,13 +53,11 @@ @test_throws ArgumentError truncate!(ψ, [site"1", site"2"]; maxdim=1) truncated = Tenet.truncate(ψ, [site"2", site"3"]; maxdim=1) - @test size(truncated, inds(truncated; at=site"2", dir=:right)) == 1 - @test size(truncated, inds(truncated; at=site"3", dir=:left)) == 1 + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 singular_values = tensors(ψ; between=(site"2", site"3")) truncated = Tenet.truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) - @test size(truncated, inds(truncated; at=site"2", dir=:right)) == 1 - @test size(truncated, inds(truncated; at=site"3", dir=:left)) == 1 + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 end @testset "norm" begin @@ -209,9 +207,8 @@ end @testset "two sites" begin - i, j = 2, 3 mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) + gate = Dense(Tenet.Operator(), mat; sites=[site"2", site"3", site"2'", site"3'"]) ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) @testset "canonical form" begin From fba30e907be949e01fac890c517fc10afd55f7b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:38:21 -0400 Subject: [PATCH 42/75] Format code --- src/Ansatz/Ansatz.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 22b2344f1..709e25108 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -132,7 +132,7 @@ function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing) 1:maxdim else 1:something(findfirst(1:maxdim) do i - abs(spectrum[i]) < threshold + abs(spectrum[i]) < threshold end - 1, maxdim) end From edc0a1627fbaeec3ad1335ace9e1a8666a54022b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 17:38:56 -0400 Subject: [PATCH 43/75] Fix typo in `normalize!` on `MPS` method --- src/Ansatz/MPS.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index cd1e980f7..eeb227570 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -303,6 +303,6 @@ end # TODO normalize! methods function LinearAlgebra.normalize!(ψ::MPS, orthog_center=site"1") mixed_canonize!(ψ, orthog_center) - normalize!(tensors(tn; between=(Site(id(root) - 1), root)), 2) + normalize!(tensors(tn; between=(Site(id(orthog_center) - 1), orthog_center)), 2) return ψ end From 8dc2795a8bcd0adbee48c7d537546d9c0d5e2634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 18:04:27 -0400 Subject: [PATCH 44/75] Fix typo --- src/Ansatz/MPS.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index eeb227570..b43dc8f79 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -303,6 +303,6 @@ end # TODO normalize! methods function LinearAlgebra.normalize!(ψ::MPS, orthog_center=site"1") mixed_canonize!(ψ, orthog_center) - normalize!(tensors(tn; between=(Site(id(orthog_center) - 1), orthog_center)), 2) + normalize!(tensors(ψ; between=(Site(id(orthog_center) - 1), orthog_center)), 2) return ψ end From c3e7ca0ddfec36eebd63fea1a304b81ee0626350 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 18:14:28 -0400 Subject: [PATCH 45/75] Deprecate `isleftcanonical`, `isrightcanonical` in favor of `isisometry` --- src/Ansatz/MPS.jl | 42 ++++++++++++++---------------------------- src/Tenet.jl | 2 +- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index b43dc8f79..6eb11ef55 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -162,42 +162,28 @@ end end end -function isleftcanonical(ψ::MPS, site; atol::Real=1e-12) - right_ind = inds(ψ; at=site, dir=:right) +function isisometry(ψ::MPS, site; dir, atol::Real=1e-12) tensor = tensors(ψ; at=site) + dirind = inds(ψ; at=site, dir) - # we are at right-most site, we need to add an extra dummy dimension to the tensor - if isnothing(right_ind) - right_ind = gensym(:dummy) - tensor = Tensor(reshape(parent(tensor), size(tensor)..., 1), (inds(tensor)..., right_ind)) + if isnothing(dirind) + @show parent(contract(tensor, conj(tensor))) + return isapprox(parent(contract(tensor, conj(tensor))), fill(true); atol) end - # TODO is replace(conj(A)...) copying too much? - contracted = contract(tensor, replace(conj(tensor), right_ind => gensym(:new_ind))) - n = size(tensor, right_ind) - identity_matrix = Matrix(I, n, n) + inda, indb = gensym(:a), gensym(:b) + a = replace(tensor, dirind => inda) + b = replace(conj(tensor), dirind => indb) - return isapprox(contracted, identity_matrix; atol) -end - -function isrightcanonical(ψ::MPS, site; atol::Real=1e-12) - left_ind = inds(ψ; at=site, dir=:left) - tensor = tensors(ψ; at=site) + n = size(tensor, dirind) + contracted = contract(a, b; out=[inda, indb]) - # we are at left-most site, we need to add an extra dummy dimension to the tensor - if isnothing(left_ind) - left_ind = gensym(:dummy) - tensor = Tensor(reshape(parent(tensor), 1, size(tensor)...), (left_ind, inds(tensor)...)) - end - - #TODO is replace(conj(A)...) copying too much? - contracted = contract(tensor, replace(conj(tensor), left_ind => gensym(:new_ind))) - n = size(tensor, left_ind) - identity_matrix = Matrix(I, n, n) - - return isapprox(contracted, identity_matrix; atol) + return isapprox(contracted, I(n); atol) end +@deprecate isleftcanonical(ψ::MPS, site; atol::Real=1e-12) isisometry(ψ, site; dir=:right, atol) +@deprecate isrightcanonical(ψ::MPS, site; atol::Real=1e-12) isisometry(ψ, site; dir=:left, atol) + # NOTE: in method == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! function canonize_site!(ψ::MPS, site::Site; direction::Symbol, method=:qr) left_inds = Symbol[] diff --git a/src/Tenet.jl b/src/Tenet.jl index 9e6338982..7effdbf68 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -48,7 +48,7 @@ export PEPO # `truncate` not exported because it clashes with `Base.truncate` export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed_canonize! -export isleftcanonical, isrightcanonical +export isisometry, isleftcanonical, isrightcanonical export evolve!, expect, overlap, truncate! # reexports from EinExprs From 380d7d30de1be3845f1d0338d8351d18b06f6a2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 18 Sep 2024 18:14:43 -0400 Subject: [PATCH 46/75] Comment `renormalize` kwarg of `evolve!` --- src/Ansatz/Ansatz.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 709e25108..f7433d7bf 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -155,7 +155,8 @@ end overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)')) function evolve!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false) - return simple_update!(ψ, gate; threshold, maxdim, renormalize) + # TODO renormalize not yet implemented + return simple_update!(ψ, gate; threshold, maxdim) end # by popular demand (Stefano, I'm looking at you), I aliased `apply!` to `evolve!` From 3cb2db35c075c126a0e5632debf205255693b478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 19 Sep 2024 08:42:21 -0400 Subject: [PATCH 47/75] Fix `simple_update!` on single site gates --- src/Ansatz/Ansatz.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index f7433d7bf..cac842871 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -164,12 +164,13 @@ const apply! = evolve! function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, kwargs...) @assert issetequal(adjoint.(sites(gate; set=:inputs)), sites(gate; set=:outputs)) "Inputs of the gate must match outputs" - @assert isneighbor(ψ, lanes(gate)...) "Gate must act on neighboring sites" if nlanes(gate) == 1 return simple_update_1site!(ψ, gate) end + @assert isneighbor(ψ, lanes(gate)...) "Gate must act on neighboring sites" + return simple_update!(form(ψ), ψ, gate; kwargs...) end From 53122c181fb83a329312691c88ed021536046932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 19 Sep 2024 08:43:58 -0400 Subject: [PATCH 48/75] Fix `isleftcanonical`, `isrightcanonical` tests on boundary sites --- test/MPS_test.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/MPS_test.jl b/test/MPS_test.jl index 89b91f7f7..a35aec258 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -132,9 +132,8 @@ if i == 1 @test isleftcanonical(canonized, Site(i)) elseif i == 5 # in the limits of the chain, we get the norm of the state + normalize!(tensors(canonized; bond=(Site(i - 1), Site(i)))) contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) - tensor = tensors(canonized; at=Site(i)) - replace!(canonized, tensor => tensor / norm(canonized)) @test isleftcanonical(canonized, Site(i)) else contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) @@ -146,9 +145,8 @@ canonized = canonize(ψ) if i == 1 # in the limits of the chain, we get the norm of the state + normalize!(tensors(canonized; bond=(Site(i), Site(i + 1)))) contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) - tensor = tensors(canonized; at=Site(i)) - replace!(canonized, tensor => tensor / norm(canonized)) @test isrightcanonical(canonized, Site(i)) elseif i == 5 @test isrightcanonical(canonized, Site(i)) From daec4d1964b8099218883b7477c4d7cd599b4078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 19 Sep 2024 12:19:32 -0400 Subject: [PATCH 49/75] Fix `evolve!` calls in tests --- test/MPS_test.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/MPS_test.jl b/test/MPS_test.jl index a35aec258..43b665bb8 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -197,7 +197,7 @@ end @testset "arbitrary chain" begin - evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14, iscanonical=false) + evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14) @test length(tensors(evolved)) == 5 @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)]) @test isapprox(contract(evolved), contract(ψ)) @@ -218,7 +218,7 @@ end @testset "arbitrary chain" begin - evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14, iscanonical=false) + evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14) @test length(tensors(evolved)) == 5 @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)]) @test isapprox(contract(evolved), contract(ψ)) From c8fe5dae89e7531c25df0d7829ffaafba87090ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 19 Sep 2024 13:51:35 -0400 Subject: [PATCH 50/75] Fix indexing problems in `simple_update_1site!` --- src/Ansatz/Ansatz.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index cac842871..89d070f2f 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -174,6 +174,7 @@ function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=noth return simple_update!(form(ψ), ψ, gate; kwargs...) end +# TODO a lot of problems with merging... maybe we don't to merge manually function simple_update_1site!(ψ::AbstractAnsatz, gate) @assert nlanes(gate) == 1 "Gate must act only on one lane" @assert ninputs(gate) == 1 "Gate must have only one input" @@ -181,19 +182,20 @@ function simple_update_1site!(ψ::AbstractAnsatz, gate) # shallow copy to avoid problems if errors in mid execution gate = copy(gate) + resetindex!(gate; init=ninds(ψ)) contracting_index = gensym(:tmp) targetsite = only(sites(gate; set=:inputs))' + # reindex output of gate to match TN sitemap + replace!(gate, inds(gate; at=only(sites(gate; set=:outputs))) => inds(ψ; at=targetsite)) + # reindex contracting index replace!(ψ, inds(ψ; at=targetsite) => contracting_index) replace!(gate, inds(gate; at=targetsite') => contracting_index) - # reindex output of gate to match TN sitemap - replace!(gate, inds(gate; at=only(sites(gate; set=:outputs))) => inds(ψ; at=targetsite)) - # contract gate with TN - merge!(ψ, gate) + merge!(ψ, gate; reset=false) return contract!(ψ, contracting_index) end From 0a5408a9048f0735a70909b795e6786427883b53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 22 Sep 2024 01:09:53 -0400 Subject: [PATCH 51/75] Refactor MPO tests --- test/MPO_test.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/MPO_test.jl b/test/MPO_test.jl index f6e6fcc67..1f7332aca 100644 --- a/test/MPO_test.jl +++ b/test/MPO_test.jl @@ -7,8 +7,8 @@ @test boundary(H) == Open() @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) == nothing - arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) - H = MPO(arrays) + # Default order (:o :i, :l, :r) + H = MPO([rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)]) @test size(tensors(H; at=site"1")) == (2, 4, 1) @test size(tensors(H; at=site"2")) == (2, 4, 1, 3) @@ -23,10 +23,15 @@ @test size(H, inds(H; at=Site(i; dual=true))) == 4 end - arrays = [ - permutedims(arrays[1], (3, 1, 2)), permutedims(arrays[2], (4, 1, 3, 2)), permutedims(arrays[3], (1, 3, 2)) - ] # now we have (:r, :o, :l, :i) - H = MPO(arrays; order=[:r, :o, :l, :i]) + # now we have (:r, :o, :l, :i) + H = MPO( + [ + permutedims(arrays(H)[1], (3, 1, 2)), + permutedims(arrays(H)[2], (4, 1, 3, 2)), + permutedims(arrays(H)[3], (1, 3, 2)), + ]; + order=[:r, :o, :l, :i], + ) @test size(tensors(H; at=site"1")) == (1, 2, 4) @test size(tensors(H; at=site"2")) == (3, 2, 1, 4) From cab5fa445d1fa3388b0a4fb9ca512cf3c251f42d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 22 Sep 2024 01:11:45 -0400 Subject: [PATCH 52/75] Some fixes on `simple_update!` --- src/Ansatz/Ansatz.jl | 56 ++++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 89d070f2f..9a29edf77 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -155,8 +155,7 @@ end overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)')) function evolve!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false) - # TODO renormalize not yet implemented - return simple_update!(ψ, gate; threshold, maxdim) + return simple_update!(ψ, gate; threshold, maxdim, renormalize) end # by popular demand (Stefano, I'm looking at you), I aliased `apply!` to `evolve!` @@ -174,7 +173,7 @@ function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=noth return simple_update!(form(ψ), ψ, gate; kwargs...) end -# TODO a lot of problems with merging... maybe we don't to merge manually +# TODO a lot of problems with merging... maybe we shouldn't merge manually function simple_update_1site!(ψ::AbstractAnsatz, gate) @assert nlanes(gate) == 1 "Gate must act only on one lane" @assert ninputs(gate) == 1 "Gate must have only one input" @@ -199,24 +198,45 @@ function simple_update_1site!(ψ::AbstractAnsatz, gate) return contract!(ψ, contracting_index) end -function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing) +# TODO remove `renormalize` argument? +function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=nothing, renormalize=false) @assert nlanes(gate) == 2 "Only 2-site gates are supported currently" + @assert isneighbor(ψ, lanes(gate)...) "Gate must act on neighboring sites" # shallow copy to avoid problems if errors in mid execution gate = copy(gate) + resetindex!(gate; init=ninds(ψ)) + @reindex! outputs(gate) => inputs(gate) - merge!(ψ, gate) + # contract involved sites + bond = (sitel, siter) = extrema(lanes(gate)) + vind = inds(ψ; bond) + linds = filter(==(vind), inds(tensors(ψ; at=sitel))) + rinds = filter(==(vind), inds(tensors(ψ; at=siter))) + contract!(ψ; bond) + + # contract physical inds with gate + merge!(ψ, gate; reset=false) contract!(ψ, inds(gate; set=:inputs)) - # TODO split + # decompose using SVD + svd!(ψ; left_inds=linds, right_inds=rinds, virtualind=vind) + + # truncate virtual index + if any(!isnothing, (threshold, maxdim)) + truncate!(ψ, bond; threshold, maxdim) + renormalize && normalize!(ψ, bond[1]) + end return ψ end -# TODO move non-canonical code to method above -# TODO remove `renormalize` argument +# TODO remove `renormalize` argument? # TODO refactor code -function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false, iscanonical=false) +function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false) + @assert nlanes(gate) == 2 "Only 2-site gates are supported currently" + @assert isneighbor(ψ, lanes(gate)...) "Gate must act on neighboring sites" + # shallow copy to avoid problems if errors in mid execution gate = copy(gate) @@ -226,7 +246,7 @@ function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim virtualind::Symbol = inds(ψ; bond=bond) - iscanonical ? contract_2sitewf!(ψ, bond) : contract!(TensorNetwork(ψ), virtualind) + contract_2sitewf!(ψ, bond) # reindex contracting index contracting_inds = [gensym(:tmp) for _ in sites(gate; set=:inputs)] @@ -267,22 +287,12 @@ function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim push!(left_inds, inds(ψ; at=sitel)) push!(right_inds, inds(ψ; at=siter)) - if iscanonical - unpack_2sitewf!(ψ, bond, left_inds, right_inds, virtualind) - else - svd!(ψ; left_inds, right_inds, virtualind) - end + unpack_2sitewf!(ψ, bond, left_inds, right_inds, virtualind) + # truncate virtual index if any(!isnothing, [threshold, maxdim]) truncate!(ψ, bond; threshold, maxdim) - - # renormalize the bond - if renormalize && iscanonical - λ = tensors(ψ; between=bond) - replace!(ψ, λ => normalize(λ)) # TODO this can be replaced by `normalize!(λ)` - elseif renormalize && !iscanonical - normalize!(ψ, bond[1]) - end + renormalize && normalize!(tensors(ψ; between=bond)) end return ψ From 2a6ebfd820826cf2723ef4ac97ab2e1bc57d2f5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 22 Sep 2024 01:13:23 -0400 Subject: [PATCH 53/75] Prototype tests for `Ansatz` --- test/Ansatz_test.jl | 2 ++ test/runtests.jl | 1 + 2 files changed, 3 insertions(+) create mode 100644 test/Ansatz_test.jl diff --git a/test/Ansatz_test.jl b/test/Ansatz_test.jl new file mode 100644 index 000000000..4c9f64a0b --- /dev/null +++ b/test/Ansatz_test.jl @@ -0,0 +1,2 @@ +@testset "Ansatz" begin +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d0b57277a..4183a75c9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using OMEinsum include("Transformations_test.jl") include("Site_test.jl") include("Quantum_test.jl") + include("Ansatz_test.jl") include("Product_test.jl") include("MPS_test.jl") include("MPO_test.jl") From 54312c5e9ef1d8ae710ada9f65bf1cd13725bc83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 26 Sep 2024 12:33:28 -0400 Subject: [PATCH 54/75] Some fixes for `PEPS` constructor --- src/Ansatz/PEPS.jl | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/Ansatz/PEPS.jl b/src/Ansatz/PEPS.jl index 1cb061d4d..7a72ab9ad 100644 --- a/src/Ansatz/PEPS.jl +++ b/src/Ansatz/PEPS.jl @@ -34,7 +34,7 @@ function PEPS(arrays::Matrix{<:AbstractArray}; order=defaultorder(PEPS)) m, n = size(arrays) - predicate = all(eachindex(arrays)) do I + predicate = all(eachindex(IndexCartesian(), arrays)) do I i, j = Tuple(I) array = arrays[i, j] @@ -54,19 +54,21 @@ function PEPS(arrays::Matrix{<:AbstractArray}; order=defaultorder(PEPS)) vvinds = [nextindex!(gen) for _ in 1:(m - 1), _ in 1:n] hvinds = [nextindex!(gen) for _ in 1:m, _ in 1:(n - 1)] - _tensors = map(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - - array = arrays[i, j] - pind = pinds[i, j] - up = i == 1 ? missing : vvinds[i - 1, j] - down = i == m ? missing : vvinds[i, j] - left = j == 1 ? missing : hvinds[i, j - 1] - right = j == n ? missing : hvinds[i, j] - - # TODO customize order - Tensor(array, collect(skipmissing([pind, up, down, left, right]))) - end + tn = TensorNetwork( + map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + pind = pinds[i, j] + up = i == 1 ? missing : vvinds[i - 1, j] + down = i == m ? missing : vvinds[i, j] + left = j == 1 ? missing : hvinds[i, j - 1] + right = j == n ? missing : hvinds[i, j] + + # TODO customize order + Tensor(array, collect(skipmissing([pind, up, down, left, right]))) + end, + ) sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) qtn = Quatum(tn, sitemap) From 9781717c73e7dff04f7e0af1a9ae4efb0ba983fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 26 Sep 2024 12:39:38 -0400 Subject: [PATCH 55/75] Remove check in `PEPS` constructor --- src/Ansatz/PEPS.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/Ansatz/PEPS.jl b/src/Ansatz/PEPS.jl index 7a72ab9ad..ff8939949 100644 --- a/src/Ansatz/PEPS.jl +++ b/src/Ansatz/PEPS.jl @@ -34,20 +34,20 @@ function PEPS(arrays::Matrix{<:AbstractArray}; order=defaultorder(PEPS)) m, n = size(arrays) - predicate = all(eachindex(IndexCartesian(), arrays)) do I - i, j = Tuple(I) - array = arrays[i, j] + # predicate = all(eachindex(IndexCartesian(), arrays)) do I + # i, j = Tuple(I) + # array = arrays[i, j] - N = ndims(array) - 1 - (i == 1 || i == m) && (N -= 1) - (j == 1 || j == n) && (N -= 1) + # N = ndims(array) - 1 + # (i == 1 || i == m) && (N -= 1) + # (j == 1 || j == n) && (N -= 1) - N > 0 - end + # N > 0 + # end - if !predicate - throw(DimensionMismatch()) - end + # if !predicate + # throw(DimensionMismatch()) + # end gen = IndexCounter() pinds = map(_ -> nextindex!(gen), arrays) From 5b8da11e659dfd63a97c5c0f059989d1b48765bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 26 Sep 2024 17:58:15 -0400 Subject: [PATCH 56/75] Fix reference to lattice in `adapt_structure` for `Ansatz` --- ext/TenetAdaptExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TenetAdaptExt.jl b/ext/TenetAdaptExt.jl index 06ed5bc9f..534355bb8 100644 --- a/ext/TenetAdaptExt.jl +++ b/ext/TenetAdaptExt.jl @@ -7,7 +7,7 @@ Adapt.adapt_structure(to, x::Tensor) = Tensor(adapt(to, parent(x)), inds(x)) Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tensors(x))) Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites) -Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), lattice(x)) +Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), Tenet.lattice(x)) Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x))) Adapt.adapt_structure(to, x::Dense) = Dense(adapt(to, Ansatz(x))) From 63edfc228ad06cbf3bea1fbe36d7dcba9a128d2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 26 Sep 2024 17:58:56 -0400 Subject: [PATCH 57/75] Stop orthogonalization to index on `mixed_canonize!` --- src/Ansatz/MPS.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 6eb11ef55..2fa92a36d 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -281,7 +281,7 @@ function mixed_canonize!(tn::MPS, orthog_center) end # center SVD sweep to get singular values - canonize_site!(tn, orthog_center; direction=:left, method=:svd) + # canonize_site!(tn, orthog_center; direction=:left, method=:svd) return tn end @@ -289,6 +289,6 @@ end # TODO normalize! methods function LinearAlgebra.normalize!(ψ::MPS, orthog_center=site"1") mixed_canonize!(ψ, orthog_center) - normalize!(tensors(ψ; between=(Site(id(orthog_center) - 1), orthog_center)), 2) + normalize!(tensors(ψ; at=orthog_center), 2) return ψ end From 7921a49142c5ac9f35b4249cae62e54e2ced65ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 26 Sep 2024 18:20:21 -0400 Subject: [PATCH 58/75] Aesthetic name fix --- ext/TenetReactantExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 1c11fec22..8309e9ec6 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -38,14 +38,14 @@ end # TODO try rely on generic fallback for ansatzes function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracequantum = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) - return Tenet.Product(tracequantum) + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return Tenet.Product(tracetn) end for A in (MPS, MPO) @eval function Reactant.make_tracer(seen::IdDict, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracequantum = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) - return $A(tracequantum, form(prev)) + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return $A(tracetn, form(prev)) end end function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) From eb95d07522098e96eacddfab443da42857c2c3f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 28 Sep 2024 16:44:22 -0400 Subject: [PATCH 59/75] Stop using `IdDict` on Reactant extension --- ext/TenetReactantExt.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 8309e9ec6..eb096340f 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -10,13 +10,13 @@ const stablehlo = MLIR.Dialects.stablehlo const Enzyme = Reactant.Enzyme function Reactant.make_tracer( - seen::IdDict, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... + seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... ) where {RT<:Tensor} tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...) return Tensor(tracedata, inds(prev)) end -function Reactant.make_tracer(seen::IdDict, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetensors = Vector{Tensor}(undef, Tenet.ntensors(prev)) for (i, tensor) in enumerate(tensors(prev)) tracetensors[i] = Reactant.make_tracer(seen, tensor, Reactant.append_path(path, i), mode; kwargs...) @@ -26,24 +26,24 @@ end Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i] -function Reactant.make_tracer(seen::IdDict, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, TensorNetwork(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Quantum(tracetn, copy(prev.sites)) end -function Reactant.make_tracer(seen::IdDict, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Ansatz(tracetn, copy(Tenet.lattice(prev))) end # TODO try rely on generic fallback for ansatzes -function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) +function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Tenet.Product(tracetn) end for A in (MPS, MPO) - @eval function Reactant.make_tracer(seen::IdDict, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) + @eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) return $A(tracetn, form(prev)) end From 22824d8efa6bd44c6d6891915d228f1635882f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 29 Sep 2024 21:49:48 -0400 Subject: [PATCH 60/75] Fix `create_result` on `MPS`, `MPO` --- ext/TenetReactantExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index eb096340f..215f98943 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -77,9 +77,9 @@ function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), resu end for A in (MPS, MPO) - @eval function Reactant.create_result(tocopy::$A, @nospecialize(path), result_stores) + @eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A} tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) - return :($A($tn, form(tocopy))) + return :($A($tn, $(Tenet.form(tocopy)))) end end From dfb5c2e9b2c26fe1d258d92254427d14c64e3a9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 30 Sep 2024 01:17:20 -0400 Subject: [PATCH 61/75] Refactor lattice generation in constructors of `Dense`, `Product`, `MPS`, `MPO`, `PEPS` --- src/Ansatz/Dense.jl | 14 ++++---------- src/Ansatz/MPO.jl | 2 +- src/Ansatz/MPS.jl | 2 +- src/Ansatz/PEPS.jl | 2 +- src/Ansatz/Product.jl | 6 ++++-- 5 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/Ansatz/Dense.jl b/src/Ansatz/Dense.jl index 936794995..97768b4a5 100644 --- a/src/Ansatz/Dense.jl +++ b/src/Ansatz/Dense.jl @@ -27,11 +27,8 @@ function Dense(::State, array::AbstractArray; sites=Site.(1:ndims(array))) tn = TensorNetwork([tensor]) qtn = Quantum(tn, sitemap) - lattice = MetaGraph( - complete_graph(nlanes(qtn)), - Pair{Site,Nothing}[i => nothing for i in lanes(qtn)], - Pair{Tuple{Site,Site},Nothing}[(i, j) => nothing for (i, j) in combinations(lanes(qtn), 2)], - ) + graph = complete_graph(nlanes(qtn)) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return Dense(ansatz) end @@ -49,11 +46,8 @@ function Dense(::Operator, array::AbstractArray; sites) sitemap = Dict{Site,Symbol}(map(splat(Pair), zip(sites, tensor_inds))) qtn = Quantum(tn, sitemap) - lattice = MetaGraph( - complete_graph(nlanes(qtn)), - Pair{Site,Nothing}[i => nothing for i in lanes(qtn)], - Pair{Tuple{Site,Site},Nothing}[(i, j) => nothing for (i, j) in combinations(lanes(qtn), 2)], - ) + graph = complete_graph(nlanes(qtn)) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return Dense(ansatz) end diff --git a/src/Ansatz/MPO.jl b/src/Ansatz/MPO.jl index 395a1d665..ddbe82a1d 100644 --- a/src/Ansatz/MPO.jl +++ b/src/Ansatz/MPO.jl @@ -59,7 +59,7 @@ function MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) qtn = Quantum(tn, sitemap) graph = path_graph(n) - lattice = MetaGraph(graph, Site.(vertices(graph)) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return MPO(ansatz, NonCanonical()) end diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 2fa92a36d..51851a793 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -57,7 +57,7 @@ function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) sitemap = Dict(Site(i) => symbols[i] for i in 1:n) qtn = Quantum(tn, sitemap) graph = path_graph(n) - lattice = MetaGraph(graph, Site.(vertices(graph)) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return MPS(ansatz, NonCanonical()) end diff --git a/src/Ansatz/PEPS.jl b/src/Ansatz/PEPS.jl index ff8939949..ef262b1d3 100644 --- a/src/Ansatz/PEPS.jl +++ b/src/Ansatz/PEPS.jl @@ -74,7 +74,7 @@ function PEPS(arrays::Matrix{<:AbstractArray}; order=defaultorder(PEPS)) qtn = Quatum(tn, sitemap) graph = grid((m, n)) # TODO fix this - lattice = MetaGraph(graph, Site.(vertices(graph)) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return PEPS(ansatz, NonCanonical()) end diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index 0c5bde2be..6412006d6 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -22,7 +22,8 @@ function Product(arrays::Vector{<:AbstractVector}) sitemap = Dict(Site(i) => symbols[i] for i in 1:n) qtn = Quantum(TensorNetwork(_tensors), sitemap) - lattice = MetaGraph(Graph(n), Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], Pair{Tuple{Site,Site},Nothing}[]) + graph = Graph(n) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return Product(ansatz) end @@ -37,7 +38,8 @@ function Product(arrays::Vector{<:AbstractMatrix}) sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) qtn = Quantum(TensorNetwork(_tensors), sitemap) - lattice = MetaGraph(Graph(n), Pair{Site,Nothing}[Site(i) => nothing for i in 1:n], Pair{Tuple{Site,Site},Nothing}[]) + graph = Graph(n) + lattice = MetaGraph(graph, lanes(qtn) .=> nothing, map(x -> Site.(Tuple(x)) => nothing, edges(graph))) ansatz = Ansatz(qtn, lattice) return Product(ansatz) end From ef25de523cd561a1fae407603336418256981707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 30 Sep 2024 01:18:01 -0400 Subject: [PATCH 62/75] Fix `make_tracer`, `create_result` from Reactant on `Product`, `Dense` --- ext/TenetReactantExt.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 215f98943..29a1b948c 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -37,9 +37,11 @@ function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.Tr end # TODO try rely on generic fallback for ansatzes -function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) - return Tenet.Product(tracetn) +for A in (Product, Dense) + @eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return $A(tracetn) + end end for A in (MPS, MPO) @@ -48,6 +50,7 @@ for A in (MPS, MPO) return $A(tracetn, form(prev)) end end + function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores) return :($Tensor($data, $(inds(tocopy)))) @@ -71,9 +74,11 @@ function Reactant.create_result(tocopy::Ansatz, @nospecialize(path), result_stor end # TODO try rely on generic fallback for ansatzes -function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores) - tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) - return :($(Tenet.Product)($tn)) +for A in (Product, Dense) + @eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A} + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($A($tn)) + end end for A in (MPS, MPO) From 9ffddf384afc236114cd7bc4664c2f1f00ef4ab9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 30 Sep 2024 01:18:29 -0400 Subject: [PATCH 63/75] Implement `rand`, `normalize!`, `overlap` for `Dense` states --- src/Ansatz/Dense.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/Ansatz/Dense.jl b/src/Ansatz/Dense.jl index 97768b4a5..8bc2b69f4 100644 --- a/src/Ansatz/Dense.jl +++ b/src/Ansatz/Dense.jl @@ -51,3 +51,22 @@ function Dense(::Operator, array::AbstractArray; sites) ansatz = Ansatz(qtn, lattice) return Dense(ansatz) end + +function Base.rand(rng::Random.AbstractRNG, ::Type{Dense}, ::State; n, eltype=Float64, physdim=2) + array = rand(rng, eltype, fill(physdim, n)...) + normalize!(array) + return Dense(State(), array; sites=Site.(1:n)) +end + +function LinearAlgebra.normalize!(ψ::Dense) + normalize!(only(arrays(ψ))) + return ψ +end + +function overlap(ϕ::Dense, ψ::Dense) + @assert lanes(ϕ) == lanes(ψ) + @assert socket(ϕ) == State() && socket(ψ) == State() + ψ = copy(ψ) + @reindex! outputs(ϕ) => outputs(ψ) + return contract(only(tensors(ϕ)), only(tensors(ψ))) +end From 7cd5fae8cb9df97aab8b496c0121b03bec0fc7a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 2 Oct 2024 04:31:59 +0200 Subject: [PATCH 64/75] Set temporarily a more concrete type of `lattice` in graph to circunvent a Julia bug --- src/Ansatz/Ansatz.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl index 9a29edf77..22ba076b5 100644 --- a/src/Ansatz/Ansatz.jl +++ b/src/Ansatz/Ansatz.jl @@ -38,7 +38,7 @@ abstract type AbstractAnsatz <: AbstractQuantum end """ struct Ansatz <: AbstractAnsatz tn::Quantum - lattice::MetaGraph + lattice::MetaGraph{Int,G,Site{1},Nothing} where {G<:Graphs.AbstractGraph{Int}} function Ansatz(tn, lattice) if !issetequal(lanes(tn), labels(lattice)) From 71964b0735a9a972b2f33e9cb73ae85a86999943 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 6 Oct 2024 16:23:31 -0400 Subject: [PATCH 65/75] Fix pairwise `contract` between `TracedRArray` and `Array` --- ext/TenetReactantExt.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 29a1b948c..bcf918088 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -205,4 +205,9 @@ function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) return Tensor(data, ic) end +Tenet.contract(a::Tensor, b::Tensor{T,N,Reactant.TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...) +function Tenet.contract(a::Tensor{Ta,Na,Reactant.TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb} + return contract(a, Reactant.promote_to(Reactant.TracedRArray{Tb,Nb}, b); kwargs...) +end + end From 20d7183675670f019c2f8eb9150abf2df6f5f795 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 6 Oct 2024 16:47:42 -0400 Subject: [PATCH 66/75] Small fix to `promote_to` of `Tensor` --- ext/TenetReactantExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index bcf918088..5bc6a8b7c 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -207,7 +207,7 @@ end Tenet.contract(a::Tensor, b::Tensor{T,N,Reactant.TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...) function Tenet.contract(a::Tensor{Ta,Na,Reactant.TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb} - return contract(a, Reactant.promote_to(Reactant.TracedRArray{Tb,Nb}, b); kwargs...) + return contract(a, Tensor(Reactant.promote_to(Reactant.TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...) end end From 4121133897462bc27e6a957f9a017cd9abc29caa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 6 Oct 2024 17:05:42 -0400 Subject: [PATCH 67/75] Refactor by importing `Reactant.TracedRArray` --- ext/TenetReactantExt.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 5bc6a8b7c..f341abdf4 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -3,7 +3,7 @@ module TenetReactantExt using Tenet using EinExprs using Reactant -using Reactant: @reactant_override +using Reactant: @reactant_override, TracedRArray const MLIR = Reactant.MLIR const stablehlo = MLIR.Dialects.stablehlo @@ -148,7 +148,7 @@ end function Tenet.contract( a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}; dims=(∩(inds(a), inds(b))), out=nothing -) where {Ta,Na,Aa<:Reactant.TracedRArray,Tb,Nb,Ab<:Reactant.TracedRArray} +) where {Ta,Na,Aa<:TracedRArray,Tb,Nb,Ab<:TracedRArray} ia = collect(inds(a)) ib = collect(inds(b)) i = ∩(dims, ia, ib) @@ -177,12 +177,12 @@ function Tenet.contract( result = Reactant.MLIR.IR.result(stablehlo.einsum(op_a, op_b; result_0, einsum_config)) - data = Reactant.TracedRArray{T,length(ic)}((), result, rsize) + data = TracedRArray{T,length(ic)}((), result, rsize) _res = Tensor(data, ic) return _res end -function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:Reactant.TracedRArray} +function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:TracedRArray} ia = inds(a) i = ∩(dims, ia) @@ -201,13 +201,13 @@ function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) result = Reactant.MLIR.IR.result(stablehlo.unary_einsum(operand; result_0, einsum_config)) - data = Reactant.TracedRArray{T,length(ic)}((), result, rsize) + data = TracedRArray{T,length(ic)}((), result, rsize) return Tensor(data, ic) end -Tenet.contract(a::Tensor, b::Tensor{T,N,Reactant.TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...) -function Tenet.contract(a::Tensor{Ta,Na,Reactant.TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb} - return contract(a, Tensor(Reactant.promote_to(Reactant.TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...) +Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...) +function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb} + return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...) end end From 4aa2830e0002a83c509d558ee7f116aa5092ba54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 6 Oct 2024 17:07:07 -0400 Subject: [PATCH 68/75] Try remove ambiguity on `contract` with `TracedRArray` --- ext/TenetReactantExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index f341abdf4..b0b326d3e 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -147,8 +147,8 @@ end end function Tenet.contract( - a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}; dims=(∩(inds(a), inds(b))), out=nothing -) where {Ta,Na,Aa<:TracedRArray,Tb,Nb,Ab<:TracedRArray} + a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb,TracedRArray{Tb,Nb}}; dims=(∩(inds(a), inds(b))), out=nothing +) where {Ta,Na,Tb,Nb} ia = collect(inds(a)) ib = collect(inds(b)) i = ∩(dims, ia, ib) @@ -182,7 +182,7 @@ function Tenet.contract( return _res end -function Tenet.contract(a::Tensor{T,N,A}; dims=nonunique(inds(a)), out=nothing) where {T,N,A<:TracedRArray} +function Tenet.contract(a::Tensor{T,N,TracedRArray{T,N}}; dims=nonunique(inds(a)), out=nothing) where {T,N} ia = inds(a) i = ∩(dims, ia) From 25791371ae57534b3471e4a8ac7b00d19e5c0964 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 7 Oct 2024 20:37:06 -0400 Subject: [PATCH 69/75] Dispatch `det`, `logdet`, `tr` methods to underlying array on matrix `Tensor`s --- src/Tensor.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Tensor.jl b/src/Tensor.jl index bc45fcef4..45b191520 100644 --- a/src/Tensor.jl +++ b/src/Tensor.jl @@ -243,3 +243,7 @@ function __expand_repeat(array, axis, size) end LinearAlgebra.opnorm(x::Tensor, p::Real) = opnorm(parent(x), p) + +LinearAlgebra.det(x::Tensor{T,2}) where {T} = det(parent(x)) +LinearAlgebra.logdet(x::Tensor{T,2}) where {T} = logdet(parent(x)) +LinearAlgebra.tr(x::Tensor{T,2}) where {T} = tr(parent(x)) From 1c65921e7893581f87dd4134faf4ba65553de33d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 7 Oct 2024 20:40:16 -0400 Subject: [PATCH 70/75] Implement Eigendecomposition for `Tensor` --- src/Numerics.jl | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/Numerics.jl b/src/Numerics.jl index 0e9a7c02c..ad14675e4 100644 --- a/src/Numerics.jl +++ b/src/Numerics.jl @@ -119,6 +119,72 @@ function factorinds(tensor, left_inds, right_inds) return left_inds, right_inds end +# TODO is this an `AbstractTensorNetwork`? +# TODO add fancier `show` method +struct TensorEigen{T,V,Nᵣ,S<:AbstractVector{V},U<:AbstractArray{T,Nᵣ}} <: Factorization{T} + values::Tensor{V,1,S} + vectors::Tensor{T,Nᵣ,U} + right_inds::Vector{Symbol} +end + +function Base.getproperty(obj::TensorEigen, name::Symbol) + if name === :U + return obj.vectors + elseif name === :Λ + return obj.values + elseif name ∈ [:Uinv, :U⁻¹] + U = reshape(parent(obj.vectors), prod(size(obj.vectors)[1:(end - 1)]), size(obj.vectors)[end]) + Uinv = inv(U) + return Tensor(Uinv, [only(inds(obj.values)), obj.right_inds...]) + end + return getfield(obj, name) +end + +function Base.inv(F::TensorEigen) + U = reshape(parent(F.vectors), prod(size(F.vectors)[1:(end - 1)]), size(F.vectors)[end]) + return Tensor(U * inv(Diagonal(F.values)) / U, [F.left_inds..., F.right_inds...]) +end +LinearAlgebra.det(x::TensorEigen) = prod(x.values) + +Base.iterate(x::TensorEigen) = (x.values, :vectors) +Base.iterate(x::TensorEigen, state) = state == :vectors ? (x.vectors, nothing) : nothing + +LinearAlgebra.eigen(t::Tensor{<:Any,2}; kwargs...) = @invoke eigen(t::Tensor; left_inds=(first(inds(t)),), kwargs...) +function LinearAlgebra.eigen(tensor::Tensor; left_inds=(), right_inds=(), virtualind=Symbol(uuid4()), kwargs...) + left_inds, right_inds = factorinds(tensor, left_inds, right_inds) + + virtualind ∉ inds(tensor) || + throw(ArgumentError("new virtual bond name ($virtualind) cannot be already be present")) + + # permute array + left_sizes = map(Base.Fix1(size, tensor), left_inds) + right_sizes = map(Base.Fix1(size, tensor), right_inds) + tensor = permutedims(tensor, [left_inds..., right_inds...]) + data = reshape(parent(tensor), prod(left_sizes), prod(right_sizes)) + + # compute eigendecomposition + Λ, U = eigen(data; kwargs...) + + # tensorify results + Λ = Tensor(Λ, [virtualind]) + U = Tensor(reshape(U, left_sizes..., size(U, 2)), [left_inds..., virtualind]) + + return TensorEigen(Λ, U, right_inds) +end + +# TODO document when it returns a `Tensor` and when returns an `Array` +LinearAlgebra.eigvals(t::Tensor{<:Any,2}; kwargs...) = eigvals(parent(t); kwargs...) +function LinearAlgebra.eigvals(tensor::Tensor; left_inds=(), right_inds=(), kwargs...) + F = eigen(tensor; left_inds, right_inds, kwargs...) + return parent(F.values) +end + +LinearAlgebra.eigvecs(t::Tensor{<:Any,2}; kwargs...) = eigvecs(parent(t); kwargs...) +function LinearAlgebra.eigvecs(tensor::Tensor; left_inds=(), right_inds=(), kwargs...) + F = eigen(tensor; left_inds, right_inds, kwargs...) + return F.vectors +end + LinearAlgebra.svd(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke svd(t::Tensor; left_inds=(first(inds(t)),), kwargs...) """ From 3670dfea25ec35db8988df260566eb900d1771f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 7 Oct 2024 20:41:03 -0400 Subject: [PATCH 71/75] Implement `eigen!` for `TensorNetwork` --- src/TensorNetwork.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 78f9920f2..a7f47110f 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -643,6 +643,13 @@ contract(t::Tensor, tn::AbstractTensorNetwork; kwargs...) = contract(tn, t; kwar return contract(intermediates...; dims=suminds(path)) end +function LinearAlgebra.eigen!(tn::AbstractTensorNetwork; left_inds=Symbol[], right_inds=Symbol[], kwargs...) + tensor = tn[left_inds ∪ right_inds...] + (; U, Λ, U⁻¹) = eigen(tensor; left_inds, right_inds, kwargs...) + replace!(tn, tensor => TensorNetwork([U, Λ, U⁻¹])) + return tn +end + function LinearAlgebra.svd!(tn::AbstractTensorNetwork; left_inds=Symbol[], right_inds=Symbol[], kwargs...) tensor = tn[left_inds ∪ right_inds...] U, s, Vt = svd(tensor; left_inds, right_inds, kwargs...) From cd937af768be4f9f522d30b922c23c623e878be6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 7 Oct 2024 20:53:46 -0400 Subject: [PATCH 72/75] small fix --- src/Numerics.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Numerics.jl b/src/Numerics.jl index ad14675e4..7cdd81a2a 100644 --- a/src/Numerics.jl +++ b/src/Numerics.jl @@ -142,7 +142,8 @@ end function Base.inv(F::TensorEigen) U = reshape(parent(F.vectors), prod(size(F.vectors)[1:(end - 1)]), size(F.vectors)[end]) - return Tensor(U * inv(Diagonal(F.values)) / U, [F.left_inds..., F.right_inds...]) + left_inds = inds(F.vectors)[1:(end - 1)] + return Tensor(U * inv(Diagonal(F.values)) / U, [left_inds..., F.right_inds...]) end LinearAlgebra.det(x::TensorEigen) = prod(x.values) From fee3b07957e96a6ffa7fae10629d89eee4f95f6f Mon Sep 17 00:00:00 2001 From: Todorbsc <145352308+Todorbsc@users.noreply.github.com> Date: Mon, 14 Oct 2024 21:36:29 +0200 Subject: [PATCH 73/75] Implement an MPS method initializing the tensors to identity (copy-tensors) (#218) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Format code * Implement MPS identity initialization * Add tests for all dispatches of MPS identity init * Format julia code * Rename function header & add docstring * Fix test set for identity MPS * Format code * Rewrite MPS identity init function to nsites instead of arrays' dimensions * Format julia code * Update docstring of identity * Clean code in test (suggested by Jofre) Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com> * Format julia code * Refactor virtualdims in identity (suggested by Sergio) * Update src/Ansatz/MPS.jl * Restrict to default order in identity MPS * Update src/Ansatz/MPS.jl * Remove order parameter in identity --------- Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com> Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/Ansatz/MPS.jl | 33 ++++++++++++++++++++++++++++++++ test/Ansatz_test.jl | 3 +-- test/MPS_test.jl | 46 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/MPS.jl b/src/Ansatz/MPS.jl index 51851a793..28f83f7f6 100644 --- a/src/Ansatz/MPS.jl +++ b/src/Ansatz/MPS.jl @@ -62,6 +62,39 @@ function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) return MPS(ansatz, NonCanonical()) end +""" + Base.identity(::Type{MPS}, n::Integer; physdim=2, maxdim=physdim^(n ÷ 2)) + +Returns an [`MPS`](@ref) of `n` sites whose tensors are initialized to COPY-tensors. + +# Keyword Arguments + + - `physdim` The physical or output dimension of each site. Defaults to 2. + - `maxdim` The maximum bond dimension. Defaults to `physdim^(n ÷ 2)`. +""" +function Base.identity(::Type{MPS}, n::Integer; physdim=2, maxdim=physdim^(n ÷ 2)) + # Create bond dimensions until the middle of the MPS considering maxdim + virtualdims = min.(maxdim, physdim .^ (1:(n ÷ 2))) + + # Complete the bond dimensions of the other half of the MPS + virtualdims = vcat(virtualdims, virtualdims[(isodd(n) ? end : end - 1):-1:1]) + + # Create each site dimensions in default order (:o, :l, :r) + arraysdims = [[physdim, virtualdims[1]]] + append!(arraysdims, [[physdim, virtualdims[i], virtualdims[i + 1]] for i in 1:(length(virtualdims) - 1)]) + push!(arraysdims, [physdim, virtualdims[end]]) + + # Create the MPS with copy-tensors according to the tensors dimensions + return MPS( + map(arraysdims) do arrdims + arr = zeros(ComplexF64, arrdims...) + deltas = [fill(i, length(arrdims)) for i in 1:physdim] + broadcast(delta -> arr[delta...] = 1.0, deltas) + arr + end, + ) +end + function Base.convert(::Type{MPS}, tn::Product) @assert socket(tn) == State() diff --git a/test/Ansatz_test.jl b/test/Ansatz_test.jl index 4c9f64a0b..7c9218851 100644 --- a/test/Ansatz_test.jl +++ b/test/Ansatz_test.jl @@ -1,2 +1 @@ -@testset "Ansatz" begin -end \ No newline at end of file +@testset "Ansatz" begin end diff --git a/test/MPS_test.jl b/test/MPS_test.jl index 43b665bb8..9ac2f7350 100644 --- a/test/MPS_test.jl +++ b/test/MPS_test.jl @@ -26,6 +26,52 @@ @test inds(ψ; at=site"3", dir=:left) == inds(ψ; at=site"2", dir=:right) !== nothing @test all(i -> size(ψ, inds(ψ; at=Site(i))) == 2, 1:nsites(ψ)) + @testset "Base.identity" begin + nsites_cases = [6, 7, 6, 7] + physdim_cases = [3, 2, 3, 2] + maxdim_cases = [nothing, nothing, 9, 4] # nothing means default + expected_tensorsizes_cases = [ + [(3, 3), (3, 3, 9), (3, 9, 27), (3, 27, 9), (3, 9, 3), (3, 3)], + [(2, 2), (2, 2, 4), (2, 4, 8), (2, 8, 8), (2, 8, 4), (2, 4, 2), (2, 2)], + [(3, 3), (3, 3, 9), (3, 9, 9), (3, 9, 9), (3, 9, 3), (3, 3)], + [(2, 2), (2, 2, 4), (2, 4, 4), (2, 4, 4), (2, 4, 4), (2, 4, 2), (2, 2)], + ] + + for (nsites, physdim, expected_tensorsizes, maxdim) in + zip(nsites_cases, physdim_cases, expected_tensorsizes_cases, maxdim_cases) + ψ = if isnothing(maxdim) + identity(MPS, nsites; physdim=physdim) + else + identity(MPS, nsites; physdim=physdim, maxdim=maxdim) + end + + # Test the tensor dimensions + obtained_tensorsizes = size.(tensors(ψ)) + @test obtained_tensorsizes == expected_tensorsizes + + # Test whether all tensors are the identity + alltns = tensors(ψ) + + # - Test extreme tensors (2D) equal identity + diagonal_2D = [fill(i, 2) for i in 1:physdim] + @test all(delta -> alltns[1][delta...] == 1, diagonal_2D) + @test sum(alltns[1]) == physdim + @test all(delta -> alltns[end][delta...] == 1, diagonal_2D) + @test sum(alltns[end]) == physdim + + # - Test bulk tensors (3D) equal identity + diagonal_3D = [fill(i, 3) for i in 1:physdim] + @test all(tns -> all(delta -> tns[delta...] == 1, diagonal_3D), alltns[2:(end - 1)]) + @test all(tns -> sum(tns) == physdim, alltns[2:(end - 1)]) + + # Test whether the contraction gives the identity + contracted_ψ = contract(ψ) + diagonal_nsitesD = [fill(i, nsites) for i in 1:physdim] + @test all(delta -> contracted_ψ[delta...] == 1, diagonal_nsitesD) + @test sum(contracted_ψ) == physdim + end + end + @testset "Site" begin ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) From f39b6f80b030450c14675f2e78ad9fa68aaaf1c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 18 Oct 2024 16:47:12 +0200 Subject: [PATCH 74/75] Implement `traced_type` for `AbstractTensorNetwork` Fixes some problems when calling `make_tracer` on an object that wraps an `Ansatz`, like `Tuple{MPS}` for example. --- ext/TenetReactantExt.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index b0b326d3e..152557820 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -9,6 +9,8 @@ const stablehlo = MLIR.Dialects.stablehlo const Enzyme = Reactant.Enzyme +Reactant.traced_type(::Type{T}, _, _) where {T<:Tenet.AbstractTensorNetwork} = T + function Reactant.make_tracer( seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... ) where {RT<:Tensor} From f5aa926921a07267d2b7a1abeab3a6c1f5dbd6a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 18 Oct 2024 16:47:47 +0200 Subject: [PATCH 75/75] Format code --- ext/TenetReactantExt.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 152557820..62bbf336b 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -10,6 +10,7 @@ const stablehlo = MLIR.Dialects.stablehlo const Enzyme = Reactant.Enzyme Reactant.traced_type(::Type{T}, _, _) where {T<:Tenet.AbstractTensorNetwork} = T +Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i] function Reactant.make_tracer( seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... @@ -26,8 +27,6 @@ function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reac return TensorNetwork(tracetensors) end -Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i] - function Reactant.make_tracer(seen, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...) tracetn = Reactant.make_tracer(seen, TensorNetwork(prev), Reactant.append_path(path, :tn), mode; kwargs...) return Quantum(tracetn, copy(prev.sites))