Skip to content

Commit

Permalink
Improve sinkhorn_gibbs (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jun 3, 2021
1 parent 7e8cbe2 commit dc4dd48
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 109 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimalTransport"
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
authors = ["zsteve <[email protected]>"]
version = "0.3.7"
version = "0.3.8"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand All @@ -25,6 +25,7 @@ StatsBase = "0.33.8"
julia = "1"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -33,4 +34,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"

[targets]
test = ["Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip"]
test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip"]
129 changes: 70 additions & 59 deletions src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,10 @@ export ot_cost, ot_plan, wasserstein, squared2wasserstein

const MOI = MathOptInterface

include("utils.jl")
include("exact.jl")
include("wasserstein.jl")

dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
function dot_matwise(x::AbstractArray, y::AbstractMatrix)
xmat = reshape(x, size(x, 1) * size(x, 2), :)
return reshape(reshape(y, 1, :) * xmat, size(x)[3:end])
end

"""
sinkhorn_gibbs(
μ, ν, K; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
Expand Down Expand Up @@ -58,11 +53,12 @@ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
The default `rtol` depends on the types of `μ`, `ν`, and `K`. After `maxiter` iterations,
the computation is stopped.
Note that for a common kernel `K`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
The output are then matrices `u` and `v` such that `u[:, i]` and `v[:, i]` are the dual variables for `μ[:, i]` and `ν[:, i]`.
In addition, the case where one of `μ` or `ν` is a single histogram and the other a matrix of histograms is supported.
Batch computations for multiple histograms with a common Gibbs kernel `K` can be performed
by passing `μ` or `ν` as matrices whose columns correspond to histograms. It is required
that the number of source and target marginals is equal or that a single source or single
target marginal is provided (either as matrix or as vector). The optimal transport plans are
returned as three-dimensional array where `γ[:, :, i]` is the optimal transport plan for the
`i`th pair of source and target marginals.
"""
function sinkhorn_gibbs(
μ,
Expand All @@ -87,43 +83,66 @@ function sinkhorn_gibbs(
:sinkhorn_gibbs,
)
end
if (size(μ, 2) != size(ν, 2)) && (min(size(μ, 2), size(ν, 2)) > 1)
throw(
DimensionMismatch(
"Error: number of columns in μ and ν must coincide, if both are matrix valued",
),
)
end
all(sum(μ; dims=1) .≈ sum(ν; dims=1)) ||
throw(ArgumentError("source and target marginals must have the same mass"))

# checks
size2 = checksize2(μ, ν)
checkbalanced(μ, ν)

# set default values of tolerances
T = float(Base.promote_eltype(μ, ν, K))
_atol = atol === nothing ? 0 : atol
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol

# initial iteration
u = if isequal(size(μ, 2), size(ν, 2))
similar(μ)
else
repeat(similar(μ[:, 1]); outer=(1, max(size(μ, 2), size(ν, 2))))
# initialize iterates
u = similar(μ, T, size(μ, 1), size2...)
v = similar(ν, T, size(ν, 1), size2...)
fill!(v, one(T))

# arrays for convergence check
Kv = similar(u)
mul!(Kv, K, v)
tmp = similar(u)
norm_μ = μ isa AbstractVector ? sum(abs, μ) : sum(abs, μ; dims=1)
if u isa AbstractMatrix
tmp2 = similar(u)
norm_uKv = similar(u, 1, size2...)
norm_diff = similar(u, 1, size2...)
_isconverged = similar(u, Bool, 1, size2...)
end
u .= μ ./ vec(sum(K; dims=2))
v = ν ./ (K' * u)
tmp1 = K * v
tmp2 = similar(u)

norm_μ = sum(abs, μ; dims=1) # for convergence check
isconverged = false
check_step = check_convergence === nothing ? 10 : check_convergence
for iter in 0:maxiter
if iter % check_step == 0
# check source marginal
# do not overwrite `tmp1` but reuse it for computing `u` if not converged
@. tmp2 = u * tmp1
norm_uKv = sum(abs, tmp2; dims=1)
@. tmp2 = μ - tmp2
norm_diff = sum(abs, tmp2; dims=1)
to_check_step = check_step
for iter in 1:maxiter
# reduce counter
to_check_step -= 1

# compute next iterate
u .= μ ./ Kv
mul!(v, K', u)
v .= ν ./ v
mul!(Kv, K, v)

# check source marginal
# always check convergence after the final iteration
if to_check_step <= 0 || iter == maxiter
# reset counter
to_check_step = check_step

# do not overwrite `Kv` but reuse it for computing `u` if not converged
tmp .= u .* Kv
if u isa AbstractMatrix
tmp2 .= abs.(tmp)
sum!(norm_uKv, tmp2)
else
norm_uKv = sum(abs, tmp)
end
tmp .= abs.(μ .- tmp)
if u isa AbstractMatrix
sum!(norm_diff, tmp)
else
norm_diff = sum(tmp)
end

@debug "Sinkhorn algorithm (" *
string(iter) *
Expand All @@ -133,20 +152,17 @@ function sinkhorn_gibbs(
string(maximum(norm_diff))

# check stopping criterion
if all(@. norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv)))
isconverged = if u isa AbstractMatrix
@. _isconverged = norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
all(_isconverged)
else
norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
end
if isconverged
@debug "Sinkhorn algorithm ($iter/$maxiter): converged"
isconverged = true
break
end
end

# perform next iteration
if iter < maxiter
@. u = μ / tmp1
mul!(v, K', u)
@. v = ν / v
mul!(tmp1, K, v)
end
end

if !isconverged
Expand All @@ -156,13 +172,6 @@ function sinkhorn_gibbs(
return u, v
end

function add_singleton(x::AbstractArray, ::Val{dim}) where {dim}
shape = ntuple(ndims(x) + 1) do i
return i < dim ? size(x, i) : (i > dim ? size(x, i - 1) : 1)
end
return reshape(x, shape)
end

"""
sinkhorn(
μ, ν, C, ε; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
Expand All @@ -188,10 +197,12 @@ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
The default `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations,
the computation is stopped.
Note that for a common cost `C`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
The output in this case is an `Array` `γ` of coupling matrices such that `γ[:, :, i]` is a coupling of `μ[:, i]` and `ν[:, i]`.
Batch computations for multiple histograms with a common cost matrix `C` can be performed by
passing `μ` or `ν` as matrices whose columns correspond to histograms. It is required that
the number of source and target marginals is equal or that a single source or single target
marginal is provided (either as matrix or as vector). The optimal transport plans are
returned as three-dimensional array where `γ[:, :, i]` is the optimal transport plan for the
`i`th pair of source and target marginals.
See also: [`sinkhorn2`](@ref)
"""
Expand Down
56 changes: 56 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
add_singleton(x::AbstractArray, ::Val{dim}) where {dim}
Add an additional dimension `dim` of size 1 to array `x`.
"""
function add_singleton(x::AbstractArray, ::Val{dim}) where {dim}
shape = ntuple(max(ndims(x) + 1, dim)) do i
return i < dim ? size(x, i) : (i > dim ? size(x, i - 1) : 1)
end
return reshape(x, shape)
end

"""
dot_matwise(x::AbstractArray, y::AbstractArray)
Compute the inner product of all matrices in `x` and `y`.
At least one of `x` and `y` has to be a matrix.
"""
dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
function dot_matwise(x::AbstractArray, y::AbstractMatrix)
xmat = reshape(x, size(x, 1) * size(x, 2), :)
return reshape(reshape(y, 1, :) * xmat, size(x)[3:end])
end
dot_matwise(x::AbstractMatrix, y::AbstractArray) = dot_matwise(y, x)

"""
checksize2(x::AbstractVecOrMat, y::AbstractVecOrMat)
Check if arrays `x` and `y` are compatible, then return a tuple of its broadcasted second
dimension.
"""
checksize2(::AbstractVector, ::AbstractVector) = ()
function checksize2::AbstractVecOrMat, ν::AbstractVecOrMat)
size_μ_2 = size(μ, 2)
size_ν_2 = size(ν, 2)
if size_μ_2 > 1 && size_ν_2 > 1 && size_μ_2 != size_ν_2
throw(DimensionMismatch("size of source and target marginals is not compatible"))
end
return (max(size_μ_2, size_ν_2),)
end

"""
checkbalanced(μ::AbstractVecOrMat, ν::AbstractVecOrMat)
Check that source and target marginals `μ` and `ν` are balanced.
"""
function checkbalanced::AbstractVector, ν::AbstractVector)
sum(μ) sum(ν) || throw(ArgumentError("source and target marginals are not balanced"))
return nothing
end
function checkbalanced(x::AbstractVecOrMat, y::AbstractVecOrMat)
all(isapprox.(sum(x; dims=1), sum(y; dims=1))) ||
throw(ArgumentError("source and target marginals are not balanced"))
return nothing
end
Loading

2 comments on commit dc4dd48

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/38091

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.8 -m "<description of version>" dc4dd48798d35dfd5c1ef017dafb2bf6b523cb24
git push origin v0.3.8

Please sign in to comment.