Skip to content

Commit

Permalink
implements sliced W (#576)
Browse files Browse the repository at this point in the history
* implements sliced W

* bug in rng handling in test

* typo

* incorporate comments

* typos in pydocs

* adding to docs

* reword

* small refactoring

* fix typos

* kwargs

* Update phrasing

* Clean tests a bit

* adding option to pass weights

---------

Co-authored-by: Michal Klein <[email protected]>
  • Loading branch information
marcocuturi and michalk8 authored Sep 13, 2024
1 parent 4aed3ec commit 27b639e
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 4 deletions.
9 changes: 9 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -1027,3 +1027,12 @@ @article{lin:22
volume = {23},
year = {2022},
}

@inproceedings{rabin:12,
author = {Rabin, Julien and Peyr{\'e}, Gabriel and Delon, Julie and Bernot, Marc},
title = {Wasserstein barycenter and its application to texture mixing},
booktitle = {Scale Space and Variational Methods in Computer Vision: Third International Conference, SSVM 2011, Ein-Gedi, Israel, May 29--June 2, 2011, Revised Selected Papers 3},
pages = {435--446},
year = {2012},
organization = {Springer}
}
21 changes: 17 additions & 4 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ ott.tools
.. module:: ott.tools
.. currentmodule:: ott.tools

The tools package contains high level functions that build on outputs produced
by core functions. They can be used to compute Sinkhorn divergences
:cite:`sejourne:19`, instantiate transport matrices, provide differentiable
approximations to ranks and quantile functions :cite:`cuturi:19`, etc.
The :mod:`~ott.tools` package contains high-level functions that build on
outputs produced by lower-level components in the toolbox, such as
:mod:`~ott.solvers`.

In particular, we provide user-friendly APIs to compute Sinkhorn divergences
:cite:`genevay:18,sejourne:19`, sliced Wasserstein distances :cite:`rabin:12`,
differentiable approximations to ranks and quantile functions :cite:`cuturi:19`,
and various tools to study Gaussians with the 2-Wasserstein metric
:cite:`gelbrich:90,delon:20`, etc.

Segmented Sinkhorn
------------------
Expand All @@ -23,6 +28,14 @@ Sinkhorn Divergence
sinkhorn_divergence.sinkhorn_divergence
sinkhorn_divergence.segment_sinkhorn_divergence

Sliced Wasserstein Distance
---------------------------
.. autosummary::
:toctree: _autosummary

sliced.random_proj_sphere
sliced.sliced_wasserstein

ProgOT
------
.. autosummary::
Expand Down
1 change: 1 addition & 0 deletions src/ott/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
progot,
segment_sinkhorn,
sinkhorn_divergence,
sliced,
soft_sort,
)
109 changes: 109 additions & 0 deletions src/ott/tools/sliced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.solvers.linear import univariate

# Public names exported by this module.
__all__ = ["random_proj_sphere", "sliced_wasserstein"]

# Type of a projection function: a callable taking an array, an integer and a
# PRNG key, and returning an array. :func:`random_proj_sphere` is one example.
Projector = Callable[[jnp.ndarray, int, jax.Array], jnp.ndarray]


def random_proj_sphere(
    x: jnp.ndarray,
    n_proj: int = 1000,
    rng: Optional[jax.Array] = None
) -> jnp.ndarray:
  """Project data on directions sampled randomly from the unit sphere.

  Directions are drawn as Gaussian vectors and rescaled to unit Euclidean
  norm before the data is projected onto them.

  Args:
    x: Array of size ``[n, dim]``.
    n_proj: Number of randomly generated projections.
    rng: Key used to sample the random directions.

  Returns:
    Array of size ``[n, n_proj]`` of features, one column per direction.
  """
  key = utils.default_prng_key(rng)
  # One Gaussian row per direction, living in the same space as the data.
  directions = jax.random.normal(key, (n_proj, x.shape[-1]))
  row_norms = jnp.linalg.norm(directions, axis=1, keepdims=True)
  unit_directions = directions / row_norms
  return x @ unit_directions.T


def sliced_wasserstein(
    x: jnp.ndarray,
    y: jnp.ndarray,
    a: Optional[jnp.ndarray] = None,
    b: Optional[jnp.ndarray] = None,
    cost_fn: Optional[costs.CostFn] = None,
    proj_fn: Optional[Projector] = None,
    weights: Optional[jnp.ndarray] = None,
    return_transport: bool = False,
    return_dual_variables: bool = False,
    **kwargs: Any,
) -> Tuple[jnp.ndarray, univariate.UnivariateOutput]:
  r"""Compute the sliced Wasserstein distance between two weighted point clouds.

  Follows the approach outlined in :cite:`rabin:12` to compute a proxy for OT
  distances: features are created (possibly randomly) for the data, e.g.,
  through projections, and the 1D Wasserstein distances between the
  univariate distributions of these features, on source and target samples,
  are then averaged.

  Args:
    x: Array of shape ``[n, dim]`` of source points' coordinates.
    y: Array of shape ``[m, dim]`` of target points' coordinates.
    a: Array of shape ``[n,]`` of source probability weights.
    b: Array of shape ``[m,]`` of target probability weights.
    cost_fn: Cost function. Must be a submodular function of two real
      arguments, i.e. such that
      :math:`\partial^2 c(x,y)/\partial x \partial y < 0`. If :obj:`None`,
      use :class:`~ott.geometry.costs.SqEuclidean`.
    proj_fn: Projection function, mapping any ``[b, dim]`` matrix of
      coordinates to a ``[b, n_proj]`` matrix of features, on which 1D
      transports (for ``n_proj`` directions) are subsequently computed
      independently. By default, use
      :func:`~ott.tools.sliced.random_proj_sphere`.
    weights: Array of shape ``[n_proj,]`` of weights used to average the
      ``n_proj`` 1D Wasserstein contributions (one for each feature) and form
      the sliced Wasserstein distance. Uniform by default, resulting in the
      plain average of all these values.
    return_transport: Whether to store ``n_proj`` transport plans in the
      output.
    return_dual_variables: Whether to store ``n_proj`` pairs of dual vectors
      in the output.
    kwargs: Keyword arguments to ``proj_fn``, passed identically when
      projecting ``x`` and ``y``. Could for instance include, as done with the
      default projector, the number ``n_proj`` of projections, as well as a
      ``rng`` key to sample as many directions.

  Returns:
    The sliced Wasserstein distance with the corresponding output object.
  """
  projector = random_proj_sphere if proj_fn is None else proj_fn

  # Source and target go through the projector with the same keyword
  # arguments, so that both are described by the same features.
  feats_x = projector(x, **kwargs)
  feats_y = projector(y, **kwargs)

  out = linear.solve_univariate(
      pointcloud.PointCloud(feats_x, feats_y, cost_fn=cost_fn),
      a,
      b,
      return_transport=return_transport,
      return_dual_variables=return_dual_variables
  )
  # Aggregate the per-feature 1D costs into a single (weighted) average.
  sliced_cost = jnp.average(out.ot_costs, weights=weights)
  return sliced_cost, out
113 changes: 113 additions & 0 deletions tests/tools/sliced_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple

import pytest

import jax
import jax.numpy as jnp
import numpy as np

from ott.geometry import costs, pointcloud
from ott.solvers import linear
from ott.tools import sliced

# Type of the projection functions the tests pass as ``proj_fn``.
Projector = Callable[[jnp.ndarray, int, jax.Array], jnp.ndarray]


def custom_proj(
    x: jnp.ndarray,
    rng: Optional[jax.Array] = None,
    n_proj: int = 27
) -> jnp.ndarray:
  """Toy non-linear projector: squared projections on uniform directions.

  Args:
    x: Array of shape ``[n, dim]``.
    rng: Key used to sample the directions; a fixed key is used if
      :obj:`None`.
    n_proj: Number of directions to sample.

  Returns:
    Array of shape ``[n, n_proj]`` of squared projections.
  """
  if rng is None:
    rng = jax.random.PRNGKey(42)
  directions = jax.random.uniform(rng, (n_proj, x.shape[1]))
  projected = x @ directions.T
  return projected ** 2


def gen_data(
    rng: jax.Array, n: int, m: int, dim: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Sample two uniformly distributed point clouds with random weights.

  Args:
    rng: Key used to sample points and weights.
    n: Number of source points.
    m: Number of target points.
    dim: Dimension of the points.

  Returns:
    The tuple ``(a, x, b, y)``: source weights ``[n,]``, source points
    ``[n, dim]``, target weights ``[m,]`` and target points ``[m, dim]``.
    Each weight vector is normalized to sum to one.
  """
  rng_x, rng_y, rng_a, rng_b = jax.random.split(rng, 4)
  x = jax.random.uniform(rng_x, (n, dim))
  y = jax.random.uniform(rng_y, (m, dim))
  a = jax.random.uniform(rng_a, (n,))
  b = jax.random.uniform(rng_b, (m,))
  # Turn the raw positive draws into probability vectors.
  a = a / jnp.sum(a)
  b = b / jnp.sum(b)
  return a, x, b, y


class TestSliced:
  """Tests for :func:`ott.tools.sliced.sliced_wasserstein`."""

  @pytest.mark.parametrize("proj_fn", [None, custom_proj])
  @pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
  def test_random_projs(
      self, rng: jax.Array, cost_fn: Optional[costs.CostFn],
      proj_fn: Optional[Projector]
  ):
    """Cost is positive and is the weighted average of the per-slice costs."""
    n, m, dim, n_proj = 12, 17, 5, 13
    rng1, rng2 = jax.random.split(rng, 2)
    a, x, b, y = gen_data(rng1, n, m, dim)
    # Shape is passed as a tuple, the documented form for `jax.random`.
    weights = jax.random.uniform(rng2, (n_proj,))

    # Test non-negative and returns output as needed.
    cost, out = sliced.sliced_wasserstein(
        x,
        y,
        a,
        b,
        cost_fn=cost_fn,
        proj_fn=proj_fn,
        n_proj=n_proj,
        rng=rng2,
        weights=weights
    )
    assert cost > 0.0
    np.testing.assert_array_equal(
        cost, jnp.average(out.ot_costs, weights=weights)
    )

  @pytest.mark.parametrize("cost_fn", [costs.SqPNorm(1.4), None])
  def test_consistency_with_id(
      self, rng: jax.Array, cost_fn: Optional[costs.CostFn]
  ):
    """With an identity projector, result matches the univariate solver."""
    n, m, dim = 11, 12, 4
    _, x, _, y = gen_data(rng, n, m, dim)

    # Test matches standard implementation when using identity.
    cost, _ = sliced.sliced_wasserstein(
        x, y, proj_fn=lambda x: x, cost_fn=cost_fn
    )
    geom = pointcloud.PointCloud(x=x, y=y, cost_fn=cost_fn)
    out_lin = jnp.mean(linear.solve_univariate(geom).ot_costs)
    np.testing.assert_allclose(out_lin, cost, rtol=1e-6, atol=1e-6)

  @pytest.mark.parametrize("proj_fn", [None, custom_proj])
  def test_diff(self, rng: jax.Array, proj_fn: Optional[Projector]):
    """Gradient w.r.t. ``x`` agrees with a central finite difference."""
    eps = 1e-4
    n, m, dim = 13, 16, 7
    _, x, _, y = gen_data(rng, n, m, dim)

    # Bind the parametrized projector, so that each parametrization actually
    # exercises a different projection function.
    def loss(x: jnp.ndarray, y: jnp.ndarray):
      return sliced.sliced_wasserstein(x, y, proj_fn=proj_fn)

    # Test differentiability. We assume uniform samples because this makes
    # the finite-difference check more accurate (avoiding ties, which make
    # computations a lot more sensitive).
    dx = jax.random.uniform(rng, (n, dim)) - 0.5
    cost_p, _ = loss(x + eps * dx, y)
    cost_m, _ = loss(x - eps * dx, y)
    g, _ = jax.jit(jax.grad(loss, has_aux=True))(x, y)

    np.testing.assert_allclose(
        jnp.sum(g * dx), (cost_p - cost_m) / (2 * eps), atol=1e-3, rtol=1e-3
    )

0 comments on commit 27b639e

Please sign in to comment.