fix power=1.0 using abs on cost values when needed (#153)
* fix power=1.0 using abs on cost values when needed

* simplify and remove jax.lax.cond + docs

* small change in doc to clarify.
marcocuturi authored Oct 13, 2022
1 parent 9d76c04 commit fb778fc
Showing 5 changed files with 93 additions and 31 deletions.
18 changes: 18 additions & 0 deletions docs/references.bib
@@ -29,6 +29,24 @@ @InProceedings{peyre:16
url = {https://proceedings.mlr.press/v48/peyre16.html},
}


@InProceedings{feydy:19,
title = {Interpolating between Optimal Transport and MMD using Sinkhorn Divergences},
author = {Feydy, Jean and S\'{e}journ\'{e}, Thibault and Vialard, Fran\c{c}ois-Xavier and Amari, Shun-ichi and Trouv\'{e}, Alain and Peyr\'{e}, Gabriel},
booktitle = {Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics},
pages = {2681--2690},
year = {2019},
editor = {Chaudhuri, Kamalika and Sugiyama, Masashi},
volume = {89},
series = {Proceedings of Machine Learning Research},
month = {16--18 Apr},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf},
url = {https://proceedings.mlr.press/v89/feydy19a.html},
abstract = {Comparing probability distributions is a fundamental problem in data sciences. Simple norms and divergences such as the total variation and the relative entropy only compare densities in a point-wise manner and fail to capture the geometric nature of the problem. In sharp contrast, Maximum Mean Discrepancies (MMD) and Optimal Transport distances (OT) are two classes of distances between measures that take into account the geometry of the underlying space and metrize the convergence in law. This paper studies the Sinkhorn divergences, a family of geometric divergences that interpolates between MMD and OT. Relying on a new notion of geometric entropy, we provide theoretical guarantees for these divergences: positivity, convexity and metrization of the convergence in law. On the practical side, we detail a numerical scheme that enables the large scale application of these divergences for machine learning: on the GPU, gradients of the Sinkhorn loss can be computed for batches of a million samples.}
}


@InProceedings{cuturi:14,
title = {Fast Computation of Wasserstein Barycenters},
author = {Cuturi, Marco and Doucet, Arnaud},
14 changes: 10 additions & 4 deletions ott/geometry/pointcloud.py
@@ -41,7 +41,9 @@ class PointCloud(geometry.Geometry):
x : n x d array of n d-dimensional vectors
y : m x d array of m d-dimensional vectors. If `None`, use ``x``.
cost_fn: a CostFn function between two points in dimension d.
power: a power to raise (norm(x) + norm(y) + cost(x,y)) ** power / 2.0
power: exponent used to rescale the output of `cost_fn`: the cost is
raised to the power `power / 2.0`. As a result, the default `power`=2.0
means no change is applied to the output of `cost_fn`.
batch_size: When ``None``, the cost matrix corresponding to that point cloud
is computed, stored and later re-used at each application of
:meth:`apply_lse_kernel`. When ``batch_size`` is a positive integer,
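As a quick illustration of the `power` semantics documented above, here is a minimal sketch (the arrays and sizes are made up; the default cost function at the time of this commit computes squared Euclidean distances):

```python
import jax

from ott.geometry import pointcloud

rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(rngs[0], (5, 3))  # 5 points in dimension 3
y = jax.random.normal(rngs[1], (4, 3))  # 4 points in dimension 3

# Default power=2.0: cost ** (2.0 / 2.0), the squared Euclidean cost
# matrix is returned unchanged.
geom_sq = pointcloud.PointCloud(x, y)

# power=1.0: cost ** 0.5, i.e. plain Euclidean distances. This is the
# case that motivates the abs() guard added in this commit.
geom_eucl = pointcloud.PointCloud(x, y, power=1.0)
```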
@@ -184,7 +186,9 @@ def _compute_cost_matrix(self) -> jnp.ndarray:
cost_matrix = self._cost_fn.all_pairs_pairwise(self.x, self.y)
if self._axis_norm is not None:
cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :]
return cost_matrix ** (0.5 * self.power)
if self.power != 2.0:
cost_matrix = jnp.abs(cost_matrix) ** (0.5 * self.power)
return cost_matrix
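The `jnp.abs` guard matters because a cost assembled as `norm_x + norm_y + pairwise(x, y)` (e.g. ‖x‖² + ‖y‖² − 2⟨x, y⟩ for the squared Euclidean cost) can come out as a tiny negative number for (near-)identical points, and a fractional power of a negative float is NaN. A minimal, self-contained illustration, where −1e-7 stands in for such a cancellation error:

```python
import jax.numpy as jnp

# A cost entry that should be exactly 0.0 but, after floating-point
# cancellation, came out slightly negative:
bad_entry = jnp.float32(-1e-7)

print(bad_entry ** 0.5)           # nan: fractional power of a negative
print(jnp.abs(bad_entry) ** 0.5)  # ~3.16e-4: well defined
```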

def apply_lse_kernel(
self,
@@ -762,8 +766,10 @@ def _transport_from_scalings_xy(

def _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost):
one_line_pairwise = jax.vmap(cost_fn.pairwise, in_axes=[0, None])
return ((norm_x + norm_y + one_line_pairwise(x, y)) ** (0.5 * cost_pow) *
scale_cost)
cost = norm_x + norm_y + one_line_pairwise(x, y)
if cost_pow != 2.0:
cost = jnp.abs(cost) ** (0.5 * cost_pow)
return cost * scale_cost
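Per the commit message, the previous `jax.lax.cond` was dropped here: `cost_pow` is a static Python number, so the `if` above is resolved once while tracing and the compiled function contains only the selected branch. A toy sketch of that pattern (the function and names below are hypothetical, not this module's code):

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=1)
def rescale(cost: jnp.ndarray, power: float) -> jnp.ndarray:
  # `power` is static, so this Python-level branch is evaluated during
  # tracing; no runtime conditional (e.g. jax.lax.cond) is emitted.
  if power != 2.0:
    cost = jnp.abs(cost) ** (0.5 * power)
  return cost


c = jnp.array([4.0, -1e-7, 9.0])
print(rescale(c, 2.0))  # unchanged
print(rescale(c, 1.0))  # elementwise sqrt of |c|
```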


def _apply_cost_xy(
33 changes: 25 additions & 8 deletions ott/tools/sinkhorn_divergence.py
@@ -55,6 +55,7 @@ def sinkhorn_divergence(
sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}),
static_b: bool = False,
share_epsilon: bool = True,
symmetric_sinkhorn: bool = False,
**kwargs: Any,
) -> SinkhornDivergenceOutput:
"""Compute Sinkhorn divergence defined by a geometry, weights, parameters.
@@ -78,6 +79,8 @@
geometry). This flag is set to True by default, because in the default
setting, the epsilon regularization is a function of the mean of the cost
matrix.
symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
symmetric terms comparing x/x and y/y.
kwargs: keyword arguments to the generic class. These are specific to
each geometry.
@@ -97,7 +100,13 @@
a = jnp.ones(num_a) / num_a if a is None else a
b = jnp.ones(num_b) / num_b if b is None else b
return _sinkhorn_divergence(
geom_xy, geom_x, geom_y, a=a, b=b, **sinkhorn_kwargs
geom_xy,
geom_x,
geom_y,
a=a,
b=b,
symmetric_sinkhorn=symmetric_sinkhorn,
**sinkhorn_kwargs
)
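A hedged usage sketch of the new flag (the sample data below is made up; the call otherwise mirrors the tests further down in this commit):

```python
import jax

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

rngs = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.uniform(rngs[0], (30, 2))
y = jax.random.uniform(rngs[1], (25, 2))

div = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud,
    x,
    y,
    epsilon=1e-2,
    # Opt in to the symmetric updates of :cite:`feydy:19` for the x/x and
    # y/y terms; the default False runs plain Sinkhorn on all three terms.
    symmetric_sinkhorn=True,
)
print(div.divergence)  # weights a, b default to uniform
```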


@@ -107,6 +116,7 @@ def _sinkhorn_divergence(
geometry_yy: Optional[geometry.Geometry],
a: jnp.ndarray,
b: jnp.ndarray,
symmetric_sinkhorn: bool,
**kwargs: Any,
) -> SinkhornDivergenceOutput:
"""Compute the (unbalanced) sinkhorn divergence for the wrapper function.
@@ -125,6 +135,8 @@
all elements of b must match that of a to converge.
b: jnp.ndarray<float>[m]: the weight of each target point. The sum of
all elements of b must match that of a to converge.
symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
symmetric terms comparing x/x and y/y.
kwargs: Keyword arguments to :func:`ott.core.sinkhorn.sinkhorn`.
Returns:
@@ -141,13 +153,14 @@
# arising in implicit differentiation (if used) of the potentials computed for
# the symmetric parts should be marked as symmetric.
kwargs_symmetric = kwargs.copy()
kwargs_symmetric.update(
parallel_dual_updates=True,
momentum=0.5,
chg_momentum_from=0,
anderson_acceleration=0,
implicit_solver_symmetric=True
)
if symmetric_sinkhorn:
kwargs_symmetric.update(
parallel_dual_updates=True,
momentum=0.5,
chg_momentum_from=0,
anderson_acceleration=0,
implicit_solver_symmetric=True
)

out_xy = sinkhorn.sinkhorn(geometry_xy, a, b, **kwargs)
out_xx = sinkhorn.sinkhorn(geometry_xx, a, a, **kwargs_symmetric)
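For reference, the damped fixed-point update of Eq. 25 in :cite:`feydy:19` that these settings implement (restated here from the paper, not from this repository): for a symmetric term OT(a, a) with cost matrix C and regularization ε, a single potential f suffices, and each iteration averages f with its Sinkhorn image,

```latex
f_i \leftarrow \frac{1}{2}\left(f_i
  \;-\; \varepsilon \log \sum_j a_j
  \exp\!\left(\frac{f_j - C_{ij}}{\varepsilon}\right)\right),
```

which is what `parallel_dual_updates=True` combined with `momentum=0.5` realizes.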
@@ -184,6 +197,7 @@ def segment_sinkhorn_divergence(
sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}),
static_b: bool = False,
share_epsilon: bool = True,
symmetric_sinkhorn: bool = False,
**kwargs: Any
) -> jnp.ndarray:
"""Compute sinkhorn divergence between subsets of vectors given in `x` & `y`.
@@ -242,6 +256,8 @@
geometry). This flag is set to True by default, because in the default
setting, the epsilon regularization is a function of the mean of the cost
matrix.
symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
symmetric terms comparing x/x and y/y.
kwargs: keyword arguments passed to form
:class:`ott.geometry.pointcloud.PointCloud` geometry objects from the
subsets of points and masses selected in `x` and `y`, this could be for
@@ -274,6 +290,7 @@ def eval_fn(
sinkhorn_kwargs=sinkhorn_kwargs,
static_b=static_b,
share_epsilon=share_epsilon,
symmetric_sinkhorn=symmetric_sinkhorn,
cost_fn=cost_fn,
src_mask=mask_x,
tgt_mask=mask_y,
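A hedged sketch of the same flag on the segmented API (the segment parameters `num_segments`, `segment_ids_x` and `segment_ids_y` are assumed here to mirror ott's segment utilities; treat those names as assumptions rather than a confirmed signature):

```python
import jax.numpy as jnp

from ott.tools import sinkhorn_divergence

# Two point clouds per side, flattened into single arrays; ids mark
# which segment each row belongs to.
x = jnp.array([[0.0], [1.0], [5.0], [6.0]])
y = jnp.array([[0.5], [1.5], [5.5], [6.5]])
seg = jnp.array([0, 0, 1, 1])

divs = sinkhorn_divergence.segment_sinkhorn_divergence(
    x,
    y,
    num_segments=2,           # assumed parameter name
    segment_ids_x=seg,        # assumed parameter name
    segment_ids_y=seg,        # assumed parameter name
    symmetric_sinkhorn=True,  # flag added in this commit
    epsilon=1e-1,
)
print(divs)  # one divergence per segment, shape [2]
```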
11 changes: 6 additions & 5 deletions tests/core/sinkhorn_test.py
@@ -45,18 +45,19 @@ def initialize(self, rng: jnp.ndarray):
self.b = b / jnp.sum(b)

@pytest.mark.fast.with_args(
"lse_mode,momentum,chg_momentum_from,inner_iterations,norm_error",
[(True, 1.0, 29, 10, 1), (False, 1.0, 30, 10, 1), (True, 1.0, 60, 1, 2),
(True, 1.0, 12, 24, 4)],
"lse_mode,momentum,chg_momentum_from,inner_iterations,norm_error,power",
[(True, 1.0, 29, 10, 1, 2.0), (False, 1.0, 30, 10, 1, 2.2),
(True, 1.0, 60, 1, 2, 1.0), (True, 1.0, 12, 24, 4, 3.0)],
ids=["lse-Leh-mom", "scal-Leh-mom", "lse-Leh-1", "lse-Leh-24"],
only_fast=[0, -1],
)
def test_euclidean_point_cloud(
self, lse_mode, momentum, chg_momentum_from, inner_iterations, norm_error
self, lse_mode, momentum, chg_momentum_from, inner_iterations, norm_error,
power
):
"""Two point clouds, tested with various parameters."""
threshold = 1e-3
geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, power=power)
out = sinkhorn.sinkhorn(
geom,
a=self.a,
48 changes: 34 additions & 14 deletions tests/tools/sinkhorn_divergence_test.py
@@ -21,6 +21,7 @@
import numpy as np
import pytest

from ott.core import sinkhorn
from ott.geometry import costs, geometry, pointcloud
from ott.tools import sinkhorn_divergence
from ott.tools.gaussian_mixture import gaussian_mixture
@@ -38,32 +39,51 @@ def setUp(self, rng: jnp.ndarray):
self._a = a / jnp.sum(a)
self._b = b / jnp.sum(b)

def test_euclidean_point_cloud(self):
@pytest.mark.fast.with_args(
power=[1.0, 2.0, 2.7],
epsilon=[.01, .001],
only_fast={
"power": 2.0,
"epsilon": .01
},
)
def test_euclidean_point_cloud(self, power, epsilon):
rngs = jax.random.split(self.rng, 2)
x = jax.random.uniform(rngs[0], (self._num_points[0], self._dim))
y = jax.random.uniform(rngs[1], (self._num_points[1], self._dim))
geometry_xx = pointcloud.PointCloud(x, x, epsilon=0.01)
geometry_xy = pointcloud.PointCloud(x, y, epsilon=0.01)
geometry_yy = pointcloud.PointCloud(y, y, epsilon=0.01)
div = sinkhorn_divergence._sinkhorn_divergence(
geometry_xy, geometry_xx, geometry_yy, self._a, self._b, threshold=1e-2

div = sinkhorn_divergence.sinkhorn_divergence(
pointcloud.PointCloud,
x,
y,
a=self._a,
b=self._b,
epsilon=epsilon,
power=power
)
assert div.divergence > 0.0
assert len(div.potentials) == 3

# Test the symmetric setting,
# i.e. that symmetric evaluation converges earlier/better.
geometry_xy = pointcloud.PointCloud(x, y, epsilon=epsilon, power=power)
geometry_xx = pointcloud.PointCloud(x, epsilon=epsilon, power=power)
geometry_yy = pointcloud.PointCloud(y, epsilon=epsilon, power=power)

div2 = sinkhorn.sinkhorn(geometry_xy, self._a, self._b).reg_ot_cost
div2 -= 0.5 * sinkhorn.sinkhorn(geometry_xx, self._a, self._a).reg_ot_cost
div2 -= 0.5 * sinkhorn.sinkhorn(geometry_yy, self._b, self._b).reg_ot_cost

np.testing.assert_allclose(div.divergence, div2, rtol=1e-5, atol=1e-5)

# Test div of x to itself close to 0.
div = sinkhorn_divergence.sinkhorn_divergence(
pointcloud.PointCloud,
x,
x,
epsilon=1e-1,
sinkhorn_kwargs={'inner_iterations': 1}
epsilon=.1,
sinkhorn_kwargs={'inner_iterations': 1},
power=power
)
np.testing.assert_allclose(div.divergence, 0.0, rtol=1e-5, atol=1e-5)
iters_xx = jnp.sum(div.errors[0] > 0)
iters_xx_sym = jnp.sum(div.errors[1] > 0)
assert iters_xx > iters_xx_sym
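The `div2` check above reassembles the defining identity of the (balanced) Sinkhorn divergence from three regularized OT costs:

```latex
\mathrm{SD}_\varepsilon(a, b)
  \;=\; \mathrm{OT}_\varepsilon(a, b)
  \;-\; \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(a, a)
  \;-\; \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(b, b),
```

where each OT_ε term is the `reg_ot_cost` returned by `sinkhorn.sinkhorn`.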

@pytest.mark.fast
def test_euclidean_autoepsilon(self):
@@ -76,7 +96,7 @@
cloud_b,
a=self._a,
b=self._b,
sinkhorn_kwargs={"threshold": 1e-2}
sinkhorn_kwargs={"threshold": 1e-2},
)
assert div.divergence > 0.0
assert len(div.potentials) == 3