Skip to content

Commit

Permalink
Move Quantum methods to AbstractQuantum
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Aug 6, 2024
1 parent 27fd010 commit 2d65c78
Showing 1 changed file with 141 additions and 142 deletions.
283 changes: 141 additions & 142 deletions src/Quantum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,158 +6,79 @@ Its subtypes must implement conversion or extraction of the underlying `Quantum`
"""
abstract type AbstractQuantum <: AbstractTensorNetwork end

"""
Quantum
Tensor Network with a notion of "causality". This leads to the notion of sites and directionality (input/output).
# `AbstractTensorNetwork` interface
TensorNetwork(tn::AbstractQuantum) = TensorNetwork(Quantum(tn))

# Notes
inds(tn::AbstractQuantum, ::Val{:at}, site::Site) = inds(Quantum(tn), Val(:at), site)
tensors(tn::AbstractQuantum, ::Val{:at}, site::Site) = only(tensors(tn; intersects=inds(tn; at=site)))

- Indices are referenced by `Site`s.
# `AbstractQuantum` interface
# TODO would be simpler and easier by overloading `Core.kwcall`? ⚠️ it's an internal implementation detail
"""
struct Quantum <: AbstractQuantum
tn::TensorNetwork

# WARN keep them synchronized
sites::Dict{Site,Symbol}
# sitetensors::Dict{Site,Tensor}
sites(q::AbstractQuantum)
function Quantum(tn::TensorNetwork, sites)
for (_, index) in sites
if !haskey(tn.indexmap, index)
error("Index $index not found in TensorNetwork")
elseif index inds(tn; set=:open)
error("Index $index must be open")
end
end

# sitetensors = map(sites) do (site, index)
# site => tn[index]
# end |> Dict{Site,Tensor}

return new(tn, sites)
end
end

Quantum(qtn::Quantum) = qtn

"""
TensorNetwork(q::Quantum)
Returns the underlying `TensorNetwork` of a [`Quantum`](@ref) Tensor Network.
Returns the sites of a [`AbstractQuantum`](@ref) Tensor Network.
"""
TensorNetwork(q::Quantum) = q.tn
TensorNetwork(q::AbstractQuantum) = TensorNetwork(Quantum(q))

Base.copy(q::Quantum) = Quantum(copy(TensorNetwork(q)), copy(q.sites))

Base.similar(q::Quantum) = Quantum(similar(TensorNetwork(q)), copy(q.sites))
Base.zero(q::Quantum) = Quantum(zero(TensorNetwork(q)), copy(q.sites))

Base.:(==)(a::Quantum, b::Quantum) = a.tn == b.tn && a.sites == b.sites
Base.isapprox(a::Quantum, b::Quantum; kwargs...) = isapprox(a.tn, b.tn; kwargs...) && a.sites == b.sites

"""
adjoint(q::Quantum)
Returns the adjoint of a [`Quantum`](@ref) Tensor Network; i.e. the conjugate Tensor Network with the inputs and outputs swapped.
"""
function Base.adjoint(qtn::Quantum)
sites = Dict{Site,Symbol}(
Iterators.map(qtn.sites) do (site, index)
site' => index
end,
)

tn = conj(TensorNetwork(qtn))

# rename inner indices
physical_inds = values(sites)
virtual_inds = setdiff(inds(tn), physical_inds)
replace!(tn, map(virtual_inds) do i
i => Symbol(i, "'")
end...)

return Quantum(tn, sites)
function sites(tn::AbstractQuantum; kwargs...)
isempty(kwargs) && return sites(tn, Val(:set), :all)
key = only(keys(kwargs))
value = values(kwargs)[key]
return sites(tn, Val(key), value)
end

"""
ninputs(q::Quantum)
Returns the number of input sites of a [`Quantum`](@ref) Tensor Network.
"""
ninputs(q::Quantum) = count(isdual, keys(q.sites))

"""
noutputs(q::Quantum)
Returns the number of output sites of a [`Quantum`](@ref) Tensor Network.
"""
noutputs(q::Quantum) = count(!isdual, keys(q.sites))
nsites(tn::AbstractQuantum; kwargs...) = nsites(Quantum(tn); kwargs...)

"""
inputs(q::Quantum)
Returns the input sites of a [`Quantum`](@ref) Tensor Network.
"""
inputs(q::Quantum) = sort!(collect(filter(isdual, keys(q.sites))))
# inputs(q::Quantum) = sort!(collect(filter(isdual, keys(q.sites))))
@deprecate inputs(tn::AbstractQuantum) sites(tn; set=:inputs)

"""
outputs(q::Quantum)
Returns the output sites of a [`Quantum`](@ref) Tensor Network.
"""
outputs(q::Quantum) = sort!(collect(filter(!isdual, keys(q.sites))))

Base.summary(io::IO, q::Quantum) = print(io, "$(length(q.tn.tensormap))-tensors Quantum")
Base.show(io::IO, q::Quantum) = print(io, "Quantum (inputs=$(ninputs(q)), outputs=$(noutputs(q)))")
# outputs(q::Quantum) = sort!(collect(filter(!isdual, keys(q.sites))))
@deprecate outputs(tn::AbstractQuantum) sites(tn; set=:outputs)

"""
sites(q::Quantum)
Returns the sites of a [`Quantum`](@ref) Tensor Network.
"""
function sites(tn::Quantum; kwargs...)
if isempty(kwargs)
collect(keys(tn.sites))
elseif keys(kwargs) === (:at,)
findfirst(i -> i === kwargs[:at], tn.sites)
else
throw(MethodError(sites, (Quantum,), kwargs))
end
end

"""
nsites(q::Quantum)
ninputs(q::Quantum)
Returns the number of sites of a [`Quantum`](@ref) Tensor Network.
Returns the number of input sites of a [`Quantum`](@ref) Tensor Network.
"""
nsites(tn::Quantum) = length(tn.sites)
# ninputs(q::Quantum) = count(isdual, keys(q.sites))
@deprecate ninputs(tn::AbstractQuantum) nsites(tn; set=:inputs)

"""
lanes(q::Quantum)
noutputs(q::Quantum)
Returns the lanes of a [`Quantum`](@ref) Tensor Network.
Returns the number of output sites of a [`Quantum`](@ref) Tensor Network.
"""
lanes(tn::Quantum) = unique(
Iterators.map(Iterators.flatten([inputs(tn), outputs(tn)])) do site
isdual(site) ? site' : site
end,
)
# noutputs(q::Quantum) = count(!isdual, keys(q.sites))
@deprecate noutputs(tn::AbstractQuantum) nsites(tn; set=:outputs)

"""
nlanes(q::Quantum)
lanes(q::AbstractQuantum)
Returns the number of lanes of a [`Quantum`](@ref) Tensor Network.
Returns the lanes of a [`AbstractQuantum`](@ref) Tensor Network.
"""
nlanes(tn::Quantum) = length(lanes(tn))
function lanes(tn::AbstractQuantum)
return unique(
Iterators.map(Iterators.flatten([inputs(tn), outputs(tn)])) do site
isdual(site) ? site' : site
end,
)
end

"""
getindex(q::Quantum, site::Site)
nlanes(q::AbstractQuantum)
Returns the index associated with a site in a [`Quantum`](@ref) Tensor Network.
Returns the number of lanes of a [`AbstractQuantum`](@ref) Tensor Network.
"""
Base.getindex(q::Quantum, site::Site) = inds(q; at=site)
nlanes(tn::AbstractQuantum) = length(lanes(tn))

"""
Socket
Expand All @@ -178,7 +99,7 @@ struct Scalar <: Socket end
Socket representing a state; i.e. a Tensor Network with only input sites (or only output sites if `dual = true`).
"""
Base.@kwdef struct State <: Socket
@kwdef struct State <: Socket
dual::Bool = false
end

Expand All @@ -194,7 +115,7 @@ struct Operator <: Socket end
Returns the socket of a [`Quantum`](@ref) Tensor Network; i.e. whether it is a [`Scalar`](@ref), [`State`](@ref) or [`Operator`](@ref).
"""
function socket(q::Quantum)
function socket(q::AbstractQuantum)
_sites = sites(q)
if isempty(_sites)
Scalar()
Expand All @@ -207,44 +128,122 @@ function socket(q::Quantum)
end
end

# forward `TensorNetwork` methods
for f in [:(Tenet.arrays), :(Base.collect)]
@eval $f(@nospecialize tn::Quantum) = $f(TensorNetwork(tn))
"""
Quantum
Tensor Network with a notion of "causality". This leads to the notion of sites and directionality (input/output).
# Notes
- Indices are referenced by `Site`s.
"""
struct Quantum <: AbstractQuantum
tn::TensorNetwork

# WARN keep them synchronized
sites::Dict{Site,Symbol}
# sitetensors::Dict{Site,Tensor}

function Quantum(tn::TensorNetwork, sites)
for (_, index) in sites
if !haskey(tn.indexmap, index)
error("Index $index not found in TensorNetwork")
elseif index inds(tn; set=:open)
error("Index $index must be open")
end
end

# sitetensors = map(sites) do (site, index)
# site => tn[index]
# end |> Dict{Site,Tensor}

return new(tn, sites)
end
end

Quantum(qtn::Quantum) = qtn

"""
inds(tn::Quantum, set::Symbol = :all, args...; kwargs...)
TensorNetwork(q::Quantum)
Returns the underlying `TensorNetwork` of a [`Quantum`](@ref) Tensor Network.
"""
TensorNetwork(q::Quantum) = q.tn

Options:
Base.copy(q::Quantum) = Quantum(copy(TensorNetwork(q)), copy(q.sites))

Base.similar(q::Quantum) = Quantum(similar(TensorNetwork(q)), copy(q.sites))
Base.zero(q::Quantum) = Quantum(zero(TensorNetwork(q)), copy(q.sites))

Base.:(==)(a::Quantum, b::Quantum) = a.tn == b.tn && a.sites == b.sites
Base.isapprox(a::Quantum, b::Quantum; kwargs...) = isapprox(a.tn, b.tn; kwargs...) && a.sites == b.sites

Base.summary(io::IO, q::Quantum) = print(io, "$(length(q.tn.tensormap))-tensors Quantum")
Base.show(io::IO, q::Quantum) = print(io, "Quantum (inputs=$(ninputs(q)), outputs=$(noutputs(q)))")

Tenet.inds(tn::Quantum, ::Val{:at}, site::Site) = Quantum(tn).sites[site]

- `:at`: index at a site
"""
function Tenet.inds(tn::Quantum; kwargs...)
if keys(kwargs) === (:at,)
inds(tn, Val(:at), kwargs[:at])
adjoint(q::Quantum)
Returns the adjoint of a [`Quantum`](@ref) Tensor Network; i.e. the conjugate Tensor Network with the inputs and outputs swapped.
"""
function Base.adjoint(qtn::Quantum)
sites = Dict{Site,Symbol}(
Iterators.map(qtn.sites) do (site, index)
site' => index
end,
)

tn = conj(TensorNetwork(qtn))

# rename inner indices
physical_inds = values(sites)
virtual_inds = setdiff(inds(tn), physical_inds)
replace!(tn, map(virtual_inds) do i
i => Symbol(i, "'")
end...)

return Quantum(tn, sites)
end

function sites(tn::AbstractQuantum, ::Val{:set}, query)
tn = Quantum(tn)

if query === :all
collect(keys(tn.sites))
elseif query === :inputs
filter(isdual, keys(tn.sites))
elseif query === :outputs
filter(!isdual, keys(tn.sites))
else
inds(TensorNetwork(tn); kwargs...)
throw(MethodError(sites, (Quantum,), kwargs))
end
end

Tenet.inds(tn::Quantum, ::Val{:at}, site::Site) = tn.sites[site]
# sites(tn::AbstractQuantum, ::Val{:at}, i) = findfirst(i -> i === kwargs[:at], tn.sites)

"""
tensors(tn::Quantum, query::Symbol, args...; kwargs...)
Options:
nsites(q::Quantum)
- `:at`: tensor at a site
Returns the number of sites of a [`Quantum`](@ref) Tensor Network.
"""
function Tenet.tensors(tn::Quantum; kwargs...)
if keys(kwargs) === (:at,)
tensors(tn, Val(:at), kwargs[:at])
else
tensors(TensorNetwork(tn); kwargs...)
function nsites(tn::Quantum; set=:all)
if set === :all
length(tn.sites)
elseif set === :inputs
length(sites(tn; set))
elseif set === :outputs
length(sites(tn; set))
end
end

Tenet.tensors(tn::Quantum, ::Val{:at}, site::Site) = only(tensors(tn; intersects=inds(tn; at=site)))
"""
getindex(q::Quantum, site::Site)
Returns the index associated with a site in a [`Quantum`](@ref) Tensor Network.
"""
@deprecate Base.getindex(q::Quantum, site::Site) inds(q; at=site)

# TODO use interfaces/abstract types for better composition of functionality
@inline function Base.replace!(tn::Quantum, old_new::P...) where {P<:Pair}
Expand Down Expand Up @@ -275,9 +274,9 @@ function reindex!(a::Quantum, ioa, b::Quantum, iob)
resetindex!(b; init=ninds(TensorNetwork(a)) + 1)

sitesb = if iob === :inputs
inputs(b)
collect(inputs(b))
elseif iob === :outputs
outputs(b)
collect(outputs(b))
else
error("Invalid argument: :$iob")
end
Expand Down

0 comments on commit 2d65c78

Please sign in to comment.