Skip to content

Commit

Permalink
Fix reindex! and replace! (#178)
Browse files Browse the repository at this point in the history
* Fix reindex! by returing the final index replace mapping

* Small fix in reindex function

* Swap site indices to ensure proper input/output matching

* Add tests for reindex! function

* Fix code and tests

* Refactor `replace!` methods for `TensorNetwork`

* Implement `replace!` for `Quantum` and make `reindex!` use it

* Enhance tests by checking tensor inds

* Add complicated replacement test

* Fix `replace!` on `TensorNetwork` with overlapped indices

* Refactor tests

* Refactor test

* Fix test

* Implement `resetindex!` for `Quantum`

* Fix `reindex!`

Now it also modifies the indices of `a`

* Fix typo in tests

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
jofrevalles and mofeing authored Aug 6, 2024
1 parent e8ef246 commit b067885
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 45 deletions.
45 changes: 37 additions & 8 deletions src/Quantum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,34 @@ end

Tenet.tensors(tn::Quantum, ::Val{:at}, site::Site) = only(tensors(tn; intersects=inds(tn; at=site)))

# TODO use interfaces/abstract types for better composition of functionality
@inline function Base.replace!(tn::Quantum, old_new::P...) where {P<:Pair}
return invoke(replace!, Tuple{Quantum,Base.AbstractVecOrTuple{P}}, tn, old_new)
end
@inline Base.replace!(tn::Quantum, old_new::Dict) = replace!(tn, collect(old_new))

function Base.replace!(tn::Quantum, old_new::Base.AbstractVecOrTuple{Pair{Symbol,Symbol}})
# replace indices in underlying Tensor Network
replace!(TensorNetwork(tn), old_new)

# replace indices in site information
from, to = first.(old_new), last.(old_new)
for (site, index) in tn.sites
i = findfirst(==(index), from)
if !isnothing(i)
tn.sites[site] = to[i]
end
end

return tn
end

function reindex!(a::Quantum, ioa, b::Quantum, iob)
ioa [:inputs, :outputs] || error("Invalid argument: :$ioa")

resetindex!(a)
resetindex!(b; init=ninds(TensorNetwork(a)) + 1)

sitesb = if iob === :inputs
inputs(b)
elseif iob === :outputs
Expand All @@ -256,21 +281,25 @@ function reindex!(a::Quantum, ioa, b::Quantum, iob)
return b
end

resetindex_mapping = resetindex!(Val(:return_mapping), TensorNetwork(b); init=ninds(TensorNetwork(a)))
replacements = merge!(resetindex_mapping, Dict(replacements))
replace!(TensorNetwork(b), replacements...)

for site in sitesb
b.sites[site] = inds(a; at=ioa != iob ? site' : site)
end
replace!(b, replacements)

return b
end

function resetindex!(tn::Quantum; init=1)
mapping = resetindex!(Val(:return_mapping), TensorNetwork(tn); init)

replace!(TensorNetwork(tn), mapping)

for (site, index) in tn.sites
tn.sites[site] = mapping[index]
end
end

"""
@reindex! a => b
Reindexes the input/output sites of a [`Quantum`](@ref) Tensor Network `b` to match the input/output sites of another [`Quantum`](@ref) Tensor Network `a`.
Reindexes the input/output sites of two [`Quantum`](@ref) Tensor Networks to be able to connect between them.
"""
macro reindex!(expr)
@assert Meta.isexpr(expr, :call) && expr.args[1] == :(=>)
Expand Down
72 changes: 36 additions & 36 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,27 +376,37 @@ Replace the element in `old` with the one in `new`. Depending on the types of `o
- If `Symbol`s, it will correspond to a index renaming.
- If `Tensor`s, first element that satisfies _egality_ (`≡` or `===`) will be replaced.
"""
Base.replace!(tn::TensorNetwork, old_new::Pair...) = replace!(tn, old_new)
@inline function Base.replace!(tn::TensorNetwork, old_new::P...) where {P<:Pair}
return invoke(replace!, Tuple{TensorNetwork,Base.AbstractVecOrTuple{P}}, tn, old_new)
end
@inline Base.replace!(tn::TensorNetwork, old_new::Dict) = replace!(tn, collect(old_new))

function Base.replace!(tn::TensorNetwork, old_new::Base.AbstractVecOrTuple{Pair})
for pair in old_new
replace!(tn, pair)
end
return tn
end

Base.replace(tn::TensorNetwork, old_new::Pair...) = replace(tn, old_new)
Base.replace(tn::TensorNetwork, old_new) = replace!(copy(tn), old_new)

function Base.replace!(tn::TensorNetwork, pair::Pair{<:Tensor,<:Tensor})
old_tensor, new_tensor = pair
issetequal(inds(new_tensor), inds(old_tensor)) || throw(ArgumentError("replacing tensor indices don't match"))

push!(tn, new_tensor)
delete!(tn, old_tensor)

function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol})
old, new = old_new
old keys(tn.indexmap) || throw(ArgumentError("index $old does not exist"))
old == new && return tn
new keys(tn.indexmap) || throw(ArgumentError("index $new is already present"))
# NOTE `copy` because collection underneath is mutated
for tensor in copy(tn.indexmap[old])
# NOTE do not `delete!` before `push!` as indices can be lost due to `tryprune!`
push!(tn, replace(tensor, old_new))
delete!(tn, tensor)
end
delete!(tn.indexmap, old)
return tn
end

function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol}...)
function Base.replace!(tn::TensorNetwork, old_new::Base.AbstractVecOrTuple{Pair{Symbol,Symbol}})
from, to = first.(old_new), last.(old_new)
allinds = inds(tn)

Expand All @@ -410,43 +420,33 @@ function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol}...)
),
)

from′ = setdiff(from, to)
to′ = setdiff(to, from)

# no overlap so easy replacement
for (f, t) in zip(from′, to′)
replace!(tn, f => t)
end

# overlap between old and new indices => need a temporary name `replace!`
overlap = from to
if !isempty(overlap)
tmp = Dict([i => gensym(i) for i in overlap])
if isempty(overlap)
# no overlap so easy replacement
for (f, t) in zip(from, to)
replace!(tn, f => t)
end
else
# overlap between old and new indices => need a temporary name `replace!`
tmp = Dict([i => gensym(i) for i in from])

# replace old indices with temporary names
replace!(tn, pairs(tmp)...)
replace!(tn, tmp)

# replace temporary names with new indices
replace!(tn, [tmp[i] => i for i in Iterators.filter((overlap), to)]...)
replace!(tn, [tmp[f] => t for (f, t) in zip(from, to)])
end

# return the final index mapping
return tn
end

function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol})
old, new = old_new
old keys(tn.indexmap) || throw(ArgumentError("index $old does not exist"))
old == new && return tn
new keys(tn.indexmap) || throw(ArgumentError("index $new is already present"))

# NOTE `copy` because collection underneath is mutated
for tensor in copy(tn.indexmap[old])
# NOTE do not `delete!` before `push!` as indices can be lost due to `tryprune!`
push!(tn, replace(tensor, old_new))
delete!(tn, tensor)
end
function Base.replace!(tn::TensorNetwork, pair::Pair{<:Tensor,<:Tensor})
old_tensor, new_tensor = pair
issetequal(inds(new_tensor), inds(old_tensor)) || throw(ArgumentError("replacing tensor indices don't match"))

delete!(tn.indexmap, old)
push!(tn, new_tensor)
delete!(tn, old_tensor)

return tn
end
Expand All @@ -456,7 +456,7 @@ function Base.replace!(tn::TensorNetwork, old_new::Pair{<:Tensor,<:TensorNetwork
issetequal(inds(new; set=:open), inds(old)) || throw(ArgumentError("indices don't match"))

# rename internal indices so there is no accidental hyperedge
replace!(new, [index => Symbol(uuid4()) for index in filter((inds(tn)), inds(new; set=:inner))]...)
replace!(new, [index => Symbol(uuid4()) for index in filter((inds(tn)), inds(new; set=:inner))])

merge!(tn, new)
delete!(tn, old)
Expand Down
76 changes: 76 additions & 0 deletions test/Quantum_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,80 @@
tn = TensorNetwork(_tensors)
@test_throws ErrorException Quantum(tn, Dict(site"1" => :j))
@test_throws ErrorException Quantum(tn, Dict(site"1" => :i))

@testset "reindex!" begin
@testset "manual indices" begin
# mps-like tensor network
mps = Quantum(
TensorNetwork(
Tensor[
Tensor(rand(2, 2), [:i, :j]), Tensor(rand(2, 2, 2), [:j, :k, :l]), Tensor(rand(2, 2), [:l, :m])
],
),
Dict(site"1" => :i, site"2" => :k, site"3" => :m),
)

# mpo-like tensor network
mpo = Quantum(
TensorNetwork(
Tensor[
Tensor(rand(2, 2, 2), [:i, :j, :k]),
Tensor(rand(2, 2, 2, 2), [:l, :m, :k, :n]),
Tensor(rand(2, 2, 2), [:o, :p, :n]),
],
),
Dict(site"1" => :i, site"1'" => :j, site"2" => :l, site"2'" => :m, site"3" => :o, site"3'" => :p),
)

Tenet.@reindex! outputs(mps) => inputs(mpo)

@test issetequal([inds(mps; at=i) for i in outputs(mps)], [inds(mpo; at=i) for i in inputs(mpo)])

# test that the both inputs/outputs appear on the corresponding tensor
@test all(Site.(1:3)) do i
inds(mps; at=i) inds(tensors(mpo; at=i))
end

@test all(Site.(1:3)) do i
(inds(mpo; at=i), inds(mpo; at=i')) inds(tensors(mpo; at=i))
end
end

@testset "regular indices" begin
# mps-like tensor network
mps = Quantum(
TensorNetwork(
Tensor[
Tensor(rand(2, 2), [:A, :B]), Tensor(rand(2, 2, 2), [:C, :B, :D]), Tensor(rand(2, 2), [:E, :D])
],
),
Dict(site"1" => :A, site"2" => :C, site"3" => :E),
)

# mpo-like tensor network
mpo = Quantum(
TensorNetwork(
Tensor[
Tensor(rand(2, 2, 2), [:A, :B, :C]),
Tensor(rand(2, 2, 2, 2), [:D, :E, :C, :F]),
Tensor(rand(2, 2, 2), [:F, :G, :H]),
],
),
Dict(site"1" => :A, site"1'" => :B, site"2" => :D, site"2'" => :E, site"3" => :G, site"3'" => :H),
)

Tenet.@reindex! outputs(mps) => inputs(mpo)

@test issetequal([inds(mps; at=i) for i in outputs(mps)], [inds(mpo; at=i) for i in inputs(mpo)])

# test that the both inputs/outputs appear on the corresponding tensor
@test all(Site.(1:3)) do i
inds(mps; at=i) inds(tensors(mpo; at=i))
end

@test all(Site.(1:3)) do i
(inds(mpo; at=i), inds(mpo; at=i')) inds(tensors(mpo; at=i))
end
end
end
end
34 changes: 33 additions & 1 deletion test/TensorNetwork_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@
@test t_ilm === tn[:i, :l, :m]
@test t_lm === tn[:l, :m]

# NOTE although it should throw `KeyError`, it throws `ArgumentError` due to implementation
# NOTE although it should throw `KeyError`, it throws `ArgumentError` due to implementation
@test_throws ArgumentError tn[:i, :x]
@test_throws ArgumentError tn[:i, :j, :k]
end
Expand Down Expand Up @@ -526,6 +526,38 @@
@test issetequal(finalinds, inds(tnA))
@test issetequal(finaltensors, tensors(tnA))
end

@testset "overlapping replacement" begin
A = Tensor(rand(2, 2, 2), (:F, :A, :K))
B = Tensor(rand(2, 2, 2, 2), (:K, :G, :B, :L))
C = Tensor(rand(2, 2, 2, 2), (:L, :H, :C, :M))
D = Tensor(rand(2, 2, 2, 2), (:M, :I, :D, :N))
E = Tensor(rand(2, 2, 2), (:N, :J, :E))

old_new = [
:N => :N,
:F => :A,
:M => :O,
:A => :P,
:D => :J,
:B => :K,
:I => :D,
:H => :C,
:G => :B,
:J => :E,
:K => :L,
:L => :U,
:E => :M,
:C => :V,
]
tn = TensorNetwork([A, B, C, D, E])

replace!(tn, old_new...)

@test issetequal(
inds.(tensors(tn)), [[:A, :P, :L], [:L, :B, :K, :U], [:U, :C, :V, :O], [:O, :D, :J, :N], [:N, :E, :M]]
)
end
end

@testset "Base.in" begin
Expand Down

0 comments on commit b067885

Please sign in to comment.