From 59957102f927e90ac8806d0f4cfc4bc3e07b1130 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 2 Aug 2024 17:50:20 +0200 Subject: [PATCH] Fix `replace!` when old an new indices intersect --- src/TensorNetwork.jl | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 4f3027b66..3389ba402 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -393,13 +393,39 @@ function Base.replace!(tn::TensorNetwork, pair::Pair{<:Tensor,<:Tensor}) end function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol}...) - first.(old_new) ⊆ keys(tn.indexmap) || - throw(ArgumentError("set of old indices must be a subset of current indices")) - isdisjoint(last.(old_new), keys(tn.indexmap)) || - throw(ArgumentError("set of new indices must be disjoint to current indices")) - for pair in old_new - replace!(tn, pair) + from, to = first.(old_new), last.(old_new) + allinds = inds(tn) + + # condition: from ⊆ allinds + from ⊆ allinds || throw(ArgumentError("set of old indices must be a subset of current indices")) + + # condition: from \ to ∩ allinds = ∅ + isdisjoint(setdiff(to, from), allinds) || throw( + ArgumentError( + "new indices must be either a element of the old indices or not an element of the TensorNetwork's indices", + ), + ) + + from′ = setdiff(from, to) + to′ = setdiff(to, from) + + # no overlap so easy replacement + for (f, t) in zip(from′, to′) + replace!(tn, f => t) + end + + # overlap between old and new indices => need a temporary name `replace!` + overlap = from ∩ to + if !isempty(overlap) + tmp = Dict([i => gensym(i) for i in overlap]) + + # replace old indices with temporary names + replace!(tn, pairs(tmp)...) + + # replace temporary names with new indices + replace!(tn, [tmp[i] => i for i in Iterators.filter(∈(overlap), to)]...) end + return tn end