-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* implements sliced W * bug in rng handling in test * typo * incorporate comments * typos in pydocs * adding to docs * reword * small refactoring * fix typos * kwargs * Update phrasing * Clean tests a bit * adding option to pass weights --------- Co-authored-by: Michal Klein <[email protected]>
- Loading branch information
1 parent
4aed3ec
commit 27b639e
Showing
5 changed files
with
249 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,5 +18,6 @@ | |
progot, | ||
segment_sinkhorn, | ||
sinkhorn_divergence, | ||
sliced, | ||
soft_sort, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, Optional, Tuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
from ott import utils | ||
from ott.geometry import costs, pointcloud | ||
from ott.solvers import linear | ||
from ott.solvers.linear import univariate | ||
|
||
# Names exported as the public API of this module.
__all__ = ["random_proj_sphere", "sliced_wasserstein"]

# Type used to annotate the ``proj_fn`` argument of ``sliced_wasserstein``:
# a callable taking an array, an integer and a random key, returning an array.
Projector = Callable[[jnp.ndarray, int, jax.Array], jnp.ndarray]
|
||
|
||
def random_proj_sphere(
    x: jnp.ndarray,
    n_proj: int = 1000,
    rng: Optional[jax.Array] = None
) -> jnp.ndarray:
  """Project points onto random unit-norm directions.

  Directions are sampled from a standard Gaussian and rescaled to unit
  Euclidean norm before being applied to the data.

  Args:
    x: Array of shape ``[n, dim]``.
    n_proj: Number of random directions to sample.
    rng: Random key used to sample the directions.

  Returns:
    Array of shape ``[n, n_proj]`` holding one feature per direction.
  """
  key = utils.default_prng_key(rng)
  directions = jax.random.normal(key, (n_proj, x.shape[-1]))
  # Normalize each row so that every direction lies on the unit sphere.
  row_norms = jnp.linalg.norm(directions, axis=1, keepdims=True)
  unit_directions = directions / row_norms
  return x @ unit_directions.T
|
||
|
||
def sliced_wasserstein(
    x: jnp.ndarray,
    y: jnp.ndarray,
    a: Optional[jnp.ndarray] = None,
    b: Optional[jnp.ndarray] = None,
    cost_fn: Optional[costs.CostFn] = None,
    proj_fn: Optional[Projector] = None,
    weights: Optional[jnp.ndarray] = None,
    return_transport: bool = False,
    return_dual_variables: bool = False,
    **kwargs: Any,
) -> Tuple[jnp.ndarray, univariate.UnivariateOutput]:
  r"""Compute the sliced Wasserstein distance between two weighted point clouds.

  Follows the approach outlined in :cite:`rabin:12` to compute a proxy for OT
  distances: both point clouds are mapped to features (possibly random ones,
  e.g., projections), a 1D optimal transport problem is solved independently
  for each feature, and the resulting 1D costs are combined into a single
  value through a (weighted) average.

  Args:
    x: Array of shape ``[n, dim]`` of source points' coordinates.
    y: Array of shape ``[m, dim]`` of target points' coordinates.
    a: Array of shape ``[n,]`` of source probability weights, forwarded to
      :func:`~ott.solvers.linear.solve_univariate`.
    b: Array of shape ``[m,]`` of target probability weights, forwarded to
      :func:`~ott.solvers.linear.solve_univariate`.
    cost_fn: Cost function. Must be a submodular function of two real
      arguments, i.e. such that
      :math:`\partial^2 c(x,y)/\partial x \partial y < 0`. If :obj:`None`,
      use :class:`~ott.geometry.costs.SqEuclidean`.
    proj_fn: Projection function, mapping any ``[b, dim]`` matrix of
      coordinates to a ``[b, n_proj]`` matrix of features, on which 1D
      transports (one per feature) are subsequently computed independently.
      By default, use :func:`~ott.tools.sliced.random_proj_sphere`.
    weights: Array of shape ``[n_proj,]`` of weights used to average the
      ``n_proj`` 1D Wasserstein contributions (one per feature). Uniform by
      default, resulting in the plain mean of these values.
    return_transport: Whether to store ``n_proj`` transport plans in the
      output.
    return_dual_variables: Whether to store ``n_proj`` pairs of dual vectors
      in the output.
    kwargs: Keyword arguments passed to ``proj_fn``. With the default
      projector these can be, for instance, the number ``n_proj`` of
      projections and the ``rng`` key used to sample the directions.

  Returns:
    The sliced Wasserstein distance, together with the output object of the
    univariate solver from which it was computed.
  """
  projector = random_proj_sphere if proj_fn is None else proj_fn

  # Both clouds go through the projector with the very same keyword arguments.
  x_feat = projector(x, **kwargs)
  y_feat = projector(y, **kwargs)
  geom = pointcloud.PointCloud(x_feat, y_feat, cost_fn=cost_fn)

  out = linear.solve_univariate(
      geom,
      a,
      b,
      return_transport=return_transport,
      return_dual_variables=return_dual_variables,
  )
  distance = jnp.average(out.ot_costs, weights=weights)
  return distance, out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Callable, Optional, Tuple | ||
|
||
import pytest | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from ott.geometry import costs, pointcloud | ||
from ott.solvers import linear | ||
from ott.tools import sliced | ||
|
||
# Type used below to annotate the ``proj_fn`` test parameters.
Projector = Callable[[jnp.ndarray, int, jax.Array], jnp.ndarray]
|
||
|
||
def custom_proj(
    x: jnp.ndarray,
    rng: Optional[jax.Array] = None,
    n_proj: int = 27
) -> jnp.ndarray:
  """Non-linear test projector: squared projections on random directions.

  Args:
    x: Array of shape ``[n, dim]``.
    rng: Random key used to sample the directions. If :obj:`None`, the fixed
      key ``jax.random.PRNGKey(42)`` is used.
    n_proj: Number of directions, i.e., of output features.

  Returns:
    Array of shape ``[n, n_proj]``.
  """
  if rng is None:
    rng = jax.random.PRNGKey(42)
  n_features = x.shape[1]
  directions = jax.random.uniform(rng, (n_proj, n_features))
  projected = x @ directions.T
  return projected ** 2
|
||
|
||
def gen_data(
    rng: jax.Array, n: int, m: int, dim: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Sample two random weighted point clouds.

  Args:
    rng: Random key, split into four sub-keys (one per sampled array).
    n: Number of source points.
    m: Number of target points.
    dim: Dimension of the points.

  Returns:
    The tuple ``(a, x, b, y)``: weights ``a`` of shape ``[n,]`` and points
    ``x`` of shape ``[n, dim]`` for the source, weights ``b`` of shape
    ``[m,]`` and points ``y`` of shape ``[m, dim]`` for the target. Both
    weight vectors are normalized to sum to one.
  """
  rng_x, rng_y, rng_a, rng_b = jax.random.split(rng, 4)
  x = jax.random.uniform(rng_x, (n, dim))
  y = jax.random.uniform(rng_y, (m, dim))
  a_raw = jax.random.uniform(rng_a, (n,))
  b_raw = jax.random.uniform(rng_b, (m,))
  # Rescale the positive weights into probability vectors.
  a = a_raw / jnp.sum(a_raw)
  b = b_raw / jnp.sum(b_raw)
  return a, x, b, y
|
||
|
||
class TestSliced:
  """Tests for :func:`ott.tools.sliced.sliced_wasserstein`."""

  @pytest.mark.parametrize("proj_fn", [None, custom_proj])
  @pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
  def test_random_projs(
      self, rng: jax.Array, cost_fn: Optional[costs.CostFn],
      proj_fn: Optional[Projector]
  ):
    """Cost is positive and equals the weighted average of the 1D costs."""
    n, m, dim, n_proj = 12, 17, 5, 13
    # One independent key per random quantity, so that no key is used twice.
    rng_data, rng_weights, rng_proj = jax.random.split(rng, 3)
    a, x, b, y = gen_data(rng_data, n, m, dim)
    # ``shape`` is documented as a tuple, hence ``(n_proj,)`` and not ``n_proj``.
    weights = jax.random.uniform(rng_weights, (n_proj,))

    # Test non-negative and returns output as needed.
    cost, out = sliced.sliced_wasserstein(
        x,
        y,
        a,
        b,
        cost_fn=cost_fn,
        proj_fn=proj_fn,
        n_proj=n_proj,
        rng=rng_proj,
        weights=weights
    )
    assert cost > 0.0
    # One 1D cost per feature: also checks that ``n_proj`` reached ``proj_fn``.
    assert out.ot_costs.shape == (n_proj,)
    np.testing.assert_array_equal(
        cost, jnp.average(out.ot_costs, weights=weights)
    )

  @pytest.mark.parametrize("cost_fn", [costs.SqPNorm(1.4), None])
  def test_consistency_with_id(
      self, rng: jax.Array, cost_fn: Optional[costs.CostFn]
  ):
    """With an identity projector, match the coordinate-wise 1D solver."""
    n, m, dim = 11, 12, 4
    # Marginal weights are not used here: both solvers get default weights.
    _, x, _, y = gen_data(rng, n, m, dim)

    # Test matches standard implementation when using identity.
    cost, _ = sliced.sliced_wasserstein(
        x, y, proj_fn=lambda x: x, cost_fn=cost_fn
    )
    geom = pointcloud.PointCloud(x=x, y=y, cost_fn=cost_fn)
    out_lin = jnp.mean(linear.solve_univariate(geom).ot_costs)
    np.testing.assert_allclose(out_lin, cost, rtol=1e-6, atol=1e-6)

  @pytest.mark.parametrize("proj_fn", [None, custom_proj])
  def test_diff(self, rng: jax.Array, proj_fn: Optional[Projector]):
    """Gradient w.r.t. ``x`` agrees with a central finite difference."""
    eps = 1e-4
    n, m, dim = 13, 16, 7
    _, x, _, y = gen_data(rng, n, m, dim)

    def loss(x_in: jnp.ndarray, y_in: jnp.ndarray):
      # Forward the parametrized projector, so that each parametrization
      # differentiates through the projector it is named after.
      return sliced.sliced_wasserstein(x_in, y_in, proj_fn=proj_fn)

    # Test differentiability. Marginal weights are left to their defaults
    # because it makes the finite difference more accurate (avoiding ties,
    # which make computations a lot more sensitive).
    dx = jax.random.uniform(rng, (n, dim)) - 0.5
    cost_p, _ = loss(x + eps * dx, y)
    cost_m, _ = loss(x - eps * dx, y)
    g, _ = jax.jit(jax.grad(loss, has_aux=True))(x, y)

    np.testing.assert_allclose(
        jnp.sum(g * dx), (cost_p - cost_m) / (2 * eps), atol=1e-3, rtol=1e-3
    )