From fb778fcddb6f6dc83a9761e864a9f884839d9518 Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Thu, 13 Oct 2022 21:59:49 +0200 Subject: [PATCH] fix power=1.0 using abs on cost values when needed (#153) * fix power=1.0 using abs on cost values when needed * simplify and remove jax.lax.cond + docs * small change in doc to clarify. --- docs/references.bib | 18 ++++++++++ ott/geometry/pointcloud.py | 14 +++++--- ott/tools/sinkhorn_divergence.py | 33 ++++++++++++----- tests/core/sinkhorn_test.py | 11 +++--- tests/tools/sinkhorn_divergence_test.py | 48 +++++++++++++++++-------- 5 files changed, 93 insertions(+), 31 deletions(-) diff --git a/docs/references.bib b/docs/references.bib index 826cdaf07..293009e02 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -29,6 +29,24 @@ @InProceedings{peyre:16 url = {https://proceedings.mlr.press/v48/peyre16.html}, } + +@InProceedings{feydy:19, + title = {Interpolating between Optimal Transport and MMD using Sinkhorn Divergences}, + author = {Feydy, Jean and S\'{e}journ\'{e}, Thibault and Vialard, Fran\c{c}ois-Xavier and Amari, Shun-ichi and Trouve, Alain and Peyr\'{e}, Gabriel}, + booktitle = {Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics}, + pages = {2681--2690}, + year = {2019}, + editor = {Chaudhuri, Kamalika and Sugiyama, Masashi}, + volume = {89}, + series = {Proceedings of Machine Learning Research}, + month = {16--18 Apr}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf}, + url = {https://proceedings.mlr.press/v89/feydy19a.html}, + abstract = {Comparing probability distributions is a fundamental problem in data sciences. Simple norms and divergences such as the total variation and the relative entropy only compare densities in a point-wise manner and fail to capture the geometric nature of the problem. In sharp contrast, Maximum Mean Discrepancies (MMD) and Optimal Transport distances (OT) are two classes of distances between measures that take into account the geometry of the underlying space and metrize the convergence in law. This paper studies the Sinkhorn divergences, a family of geometric divergences that interpolates between MMD and OT. Relying on a new notion of geometric entropy, we provide theoretical guarantees for these divergences: positivity, convexity and metrization of the convergence in law. On the practical side, we detail a numerical scheme that enables the large scale application of these divergences for machine learning: on the GPU, gradients of the Sinkhorn loss can be computed for batches of a million samples.} +} + + @InProceedings{cuturi:14, title = {Fast Computation of Wasserstein Barycenters}, author = {Cuturi, Marco and Doucet, Arnaud}, diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index 537cdfa43..64c10482b 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -41,7 +41,9 @@ class PointCloud(geometry.Geometry): x : n x d array of n d-dimensional vectors y : m x d array of m d-dimensional vectors. If `None`, use ``x``. cost_fn: a CostFn function between two points in dimension d. - power: a power to raise (norm(x) + norm(y) + cost(x,y)) ** + power: a power to raise `(cost_fn(x,y)) ** . / 2.0`. As a result, + `power`=2.0 is the default and means no change is applied to the output of + `cost_fn`. batch_size: When ``None``, the cost matrix corresponding to that point cloud is computed, stored and later re-used at each application of :meth:`apply_lse_kernel`. When ``batch_size`` is a positive integer, @@ -184,7 +186,9 @@ def _compute_cost_matrix(self) -> jnp.ndarray: cost_matrix = self._cost_fn.all_pairs_pairwise(self.x, self.y) if self._axis_norm is not None: cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :] - return cost_matrix ** (0.5 * self.power) + if self.power != 2.0: + cost_matrix = jnp.abs(cost_matrix) ** (0.5 * self.power) + return cost_matrix def apply_lse_kernel( self, @@ -762,8 +766,10 @@ def _transport_from_scalings_xy( def _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost): one_line_pairwise = jax.vmap(cost_fn.pairwise, in_axes=[0, None]) - return ((norm_x + norm_y + one_line_pairwise(x, y)) ** (0.5 * cost_pow) * - scale_cost) + cost = norm_x + norm_y + one_line_pairwise(x, y) + if cost_pow != 2.0: + cost = jnp.abs(cost) ** (0.5 * cost_pow) + return cost * scale_cost def _apply_cost_xy( diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index c82b05a10..b3564622d 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -55,6 +55,7 @@ def sinkhorn_divergence( sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), static_b: bool = False, share_epsilon: bool = True, + symmetric_sinkhorn: bool = False, **kwargs: Any, ) -> SinkhornDivergenceOutput: """Compute Sinkhorn divergence defined by a geometry, weights, parameters. @@ -78,6 +79,8 @@ def sinkhorn_divergence( geometry). This flag is set to True by default, because in the default setting, the epsilon regularization is a function of the mean of the cost matrix. + symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for + symmetric terms comparing x/x and y/y. kwargs: keywords arguments to the generic class. This is specific to each geometry. @@ -97,7 +100,13 @@ def sinkhorn_divergence( a = jnp.ones(num_a) / num_a if a is None else a b = jnp.ones(num_b) / num_b if b is None else b return _sinkhorn_divergence( - geom_xy, geom_x, geom_y, a=a, b=b, **sinkhorn_kwargs + geom_xy, + geom_x, + geom_y, + a=a, + b=b, + symmetric_sinkhorn=symmetric_sinkhorn, + **sinkhorn_kwargs ) @@ -107,6 +116,7 @@ def _sinkhorn_divergence( geometry_yy: Optional[geometry.Geometry], a: jnp.ndarray, b: jnp.ndarray, + symmetric_sinkhorn: bool, **kwargs: Any, ) -> SinkhornDivergenceOutput: """Compute the (unbalanced) sinkhorn divergence for the wrapper function. @@ -125,6 +135,8 @@ def _sinkhorn_divergence( all elements of b must match that of a to converge. b: jnp.ndarray[m]: the weight of each target point. The sum of all elements of b must match that of a to converge. + symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for + symmetric terms comparing x/x and y/y. kwargs: Keyword arguments to :func:`ott.core.sinkhorn.sinkhorn`. Returns: @@ -141,13 +153,14 @@ def _sinkhorn_divergence( # arising in implicit differentiation (if used) of the potentials computed for # the symmetric parts should be marked as symmetric. kwargs_symmetric = kwargs.copy() - kwargs_symmetric.update( - parallel_dual_updates=True, - momentum=0.5, - chg_momentum_from=0, - anderson_acceleration=0, - implicit_solver_symmetric=True - ) + if symmetric_sinkhorn: + kwargs_symmetric.update( + parallel_dual_updates=True, + momentum=0.5, + chg_momentum_from=0, + anderson_acceleration=0, + implicit_solver_symmetric=True + ) out_xy = sinkhorn.sinkhorn(geometry_xy, a, b, **kwargs) out_xx = sinkhorn.sinkhorn(geometry_xx, a, a, **kwargs_symmetric) @@ -184,6 +197,7 @@ def segment_sinkhorn_divergence( sinkhorn_kwargs: Mapping[str, Any] = MappingProxyType({}), static_b: bool = False, share_epsilon: bool = True, + symmetric_sinkhorn: bool = False, **kwargs: Any ) -> jnp.ndarray: """Compute sinkhorn divergence between subsets of vectors given in `x` & `y`. @@ -242,6 +256,8 @@ def segment_sinkhorn_divergence( geometry). This flag is set to True by default, because in the default setting, the epsilon regularization is a function of the mean of the cost matrix. + symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for + symmetric terms comparing x/x and y/y. kwargs: keywords arguments passed to form :class:`ott.geometry.pointcloud.PointCloud` geometry objects from the subsets of points and masses selected in `x` and `y`, this could be for @@ -274,6 +290,7 @@ def eval_fn( sinkhorn_kwargs=sinkhorn_kwargs, static_b=static_b, share_epsilon=share_epsilon, + symmetric_sinkhorn=symmetric_sinkhorn, cost_fn=cost_fn, src_mask=mask_x, tgt_mask=mask_y, diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py index 16b09d379..580b334d9 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/core/sinkhorn_test.py @@ -45,18 +45,19 @@ def initialize(self, rng: jnp.ndarray): self.b = b / jnp.sum(b) @pytest.mark.fast.with_args( - "lse_mode,momentum,chg_momentum_from,inner_iterations,norm_error", - [(True, 1.0, 29, 10, 1), (False, 1.0, 30, 10, 1), (True, 1.0, 60, 1, 2), - (True, 1.0, 12, 24, 4)], + "lse_mode,momentum,chg_momentum_from,inner_iterations,norm_error,power", + [(True, 1.0, 29, 10, 1, 2.0), (False, 1.0, 30, 10, 1, 2.2), + (True, 1.0, 60, 1, 2, 1.0), (True, 1.0, 12, 24, 4, 3.0)], ids=["lse-Leh-mom", "scal-Leh-mom", "lse-Leh-1", "lse-Leh-24"], only_fast=[0, -1], ) def test_euclidean_point_cloud( - self, lse_mode, momentum, chg_momentum_from, inner_iterations, norm_error + self, lse_mode, momentum, chg_momentum_from, inner_iterations, norm_error, + power ): """Two point clouds, tested with various parameters.""" threshold = 1e-3 - geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1) + geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, power=power) out = sinkhorn.sinkhorn( geom, a=self.a, diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index 6995103cd..152002638 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -21,6 +21,7 @@ import numpy as np import pytest +from ott.core import sinkhorn from ott.geometry import costs, geometry, pointcloud from ott.tools import sinkhorn_divergence from ott.tools.gaussian_mixture import gaussian_mixture @@ -38,32 +39,51 @@ def setUp(self, rng: jnp.ndarray): self._a = a / jnp.sum(a) self._b = b / jnp.sum(b) - def test_euclidean_point_cloud(self): + @pytest.mark.fast.with_args( + power=[1.0, 2.0, 2.7], + epsilon=[.01, .001], + only_fast={ + "power": 2.0, + "epsilon": .01 + }, + ) + def test_euclidean_point_cloud(self, power, epsilon): rngs = jax.random.split(self.rng, 2) x = jax.random.uniform(rngs[0], (self._num_points[0], self._dim)) y = jax.random.uniform(rngs[1], (self._num_points[1], self._dim)) - geometry_xx = pointcloud.PointCloud(x, x, epsilon=0.01) - geometry_xy = pointcloud.PointCloud(x, y, epsilon=0.01) - geometry_yy = pointcloud.PointCloud(y, y, epsilon=0.01) - div = sinkhorn_divergence._sinkhorn_divergence( - geometry_xy, geometry_xx, geometry_yy, self._a, self._b, threshold=1e-2 + + div = sinkhorn_divergence.sinkhorn_divergence( + pointcloud.PointCloud, + x, + y, + a=self._a, + b=self._b, + epsilon=epsilon, + power=power ) assert div.divergence > 0.0 assert len(div.potentials) == 3 - # Test symmetric setting, - # test that symmetric evaluation converges earlier/better. + geometry_xy = pointcloud.PointCloud(x, y, epsilon=epsilon, power=power) + geometry_xx = pointcloud.PointCloud(x, epsilon=epsilon, power=power) + geometry_yy = pointcloud.PointCloud(y, epsilon=epsilon, power=power) + + div2 = sinkhorn.sinkhorn(geometry_xy, self._a, self._b).reg_ot_cost + div2 -= 0.5 * sinkhorn.sinkhorn(geometry_xx, self._a, self._a).reg_ot_cost + div2 -= 0.5 * sinkhorn.sinkhorn(geometry_yy, self._b, self._b).reg_ot_cost + + np.testing.assert_allclose(div.divergence, div2, rtol=1e-5, atol=1e-5) + + # Test div of x to itself close to 0. div = sinkhorn_divergence.sinkhorn_divergence( pointcloud.PointCloud, x, x, - epsilon=1e-1, - sinkhorn_kwargs={'inner_iterations': 1} + epsilon=.1, + sinkhorn_kwargs={'inner_iterations': 1}, + power=power ) np.testing.assert_allclose(div.divergence, 0.0, rtol=1e-5, atol=1e-5) - iters_xx = jnp.sum(div.errors[0] > 0) - iters_xx_sym = jnp.sum(div.errors[1] > 0) - assert iters_xx > iters_xx_sym @pytest.mark.fast def test_euclidean_autoepsilon(self): @@ -76,7 +96,7 @@ def test_euclidean_autoepsilon(self): cloud_b, a=self._a, b=self._b, - sinkhorn_kwargs={"threshold": 1e-2} + sinkhorn_kwargs={"threshold": 1e-2}, ) assert div.divergence > 0.0 assert len(div.potentials) == 3