Feature/batched vmap #588

Merged — 79 commits from feature/batched-vmap into main, Oct 16, 2024

Commits (79)
4472425
Start batched vmap
michalk8 Sep 20, 2024
80b9c60
Initial `batched_vmap` impl
michalk8 Oct 10, 2024
7c58b9a
Nicer formatting
michalk8 Oct 10, 2024
b995d15
Fix getting shape
michalk8 Oct 10, 2024
e5cfa1e
Remove private API usage
michalk8 Oct 10, 2024
f8eee7b
Fix new args
michalk8 Oct 10, 2024
51980ef
Add a TODO
michalk8 Oct 10, 2024
7a510ae
Canonicalize axes
michalk8 Oct 10, 2024
1ca05f1
Add `batched_vmap` to docs
michalk8 Oct 10, 2024
e7ad8a4
Removed batched transport functions
michalk8 Oct 10, 2024
acf221b
Remove `_norm_{x,y}` from `CostFn`
michalk8 Oct 10, 2024
07f4734
Implement `apply_lse_kernel`
michalk8 Oct 10, 2024
98f1de9
Implement `apply_kernel`
michalk8 Oct 10, 2024
ba825f9
Implement `apply_cost`
michalk8 Oct 10, 2024
9036329
Remove old functions
michalk8 Oct 10, 2024
62089bc
Make function private
michalk8 Oct 10, 2024
b9bb64a
Refactor `apply_cost` to have consistent shapes
michalk8 Oct 10, 2024
d8b5ea6
Use `_apply_cost_to_vec` in `PointCloud`
michalk8 Oct 10, 2024
12f923c
Remove TODO
michalk8 Oct 10, 2024
a43dcdb
Formatting
michalk8 Oct 10, 2024
e0d75b0
Simplify `_apply_sqeucl_cost`
michalk8 Oct 10, 2024
f5445ec
Fix `RecursionError`
michalk8 Oct 10, 2024
4922ca9
Remove docstring of a private method
michalk8 Oct 10, 2024
799c108
Fix `apply_lse_kernel`
michalk8 Oct 10, 2024
9a5c1ca
Squeeze only 1 axis of the cost
michalk8 Oct 10, 2024
8543538
Add TODO
michalk8 Oct 10, 2024
317eb02
Rename function, make a property
michalk8 Oct 10, 2024
d31fd4d
Remove unused helper function
michalk8 Oct 10, 2024
4b0f150
Compute mean summary online
michalk8 Oct 10, 2024
8843937
Compute mean online
michalk8 Oct 10, 2024
83c9960
Compute max cost matrix
michalk8 Oct 10, 2024
69a9599
Update error message
michalk8 Oct 11, 2024
6667abd
Remove TODO
michalk8 Oct 11, 2024
ac9b928
Flatten out axes
michalk8 Oct 11, 2024
c113946
Fix missing cross terms in the costs
michalk8 Oct 11, 2024
75d9e7a
Fix geom tests
michalk8 Oct 11, 2024
44eb5a8
Fix dtype
michalk8 Oct 11, 2024
cbb4ea0
Start implementing transport functions
michalk8 Oct 11, 2024
e8bb1b5
Implement online transport functions
michalk8 Oct 11, 2024
7d4001e
Fix solver tests
michalk8 Oct 11, 2024
8941224
Fix Bures test
michalk8 Oct 11, 2024
a565e09
Don't use `pairwise` in tests
michalk8 Oct 11, 2024
1533324
Update notebook that uses `norm`
michalk8 Oct 11, 2024
3e7ff8b
Fix bug in `UnbalancedBures`
michalk8 Oct 11, 2024
b815fbc
Rename `pairwise -> __call__`
michalk8 Oct 11, 2024
739afde
Remove old shape code
michalk8 Oct 11, 2024
0d7f6ae
Always instantiate the cost for online
michalk8 Oct 11, 2024
4863fcf
Remove old TODO
michalk8 Oct 11, 2024
4aa4c6b
Extract `_apply_cost_to_vec_fast`
michalk8 Oct 11, 2024
8511073
Update max cost in LRCGeom
michalk8 Oct 11, 2024
47462d2
Fix test, use more `multi_dot`
michalk8 Oct 11, 2024
05630a8
Remove `batch_size` from `LRCGeometry`
michalk8 Oct 11, 2024
0994d7a
Add better warning error
michalk8 Oct 15, 2024
5d88ad4
Reorder properties
michalk8 Oct 15, 2024
f8143fc
Add docs to `batched_vmap`
michalk8 Oct 15, 2024
a82688c
Start adding tests
michalk8 Oct 15, 2024
1d2d12d
Reorder functions in test
michalk8 Oct 15, 2024
44b1126
Fix axes, add a test
michalk8 Oct 15, 2024
889f81f
Update test fn
michalk8 Oct 15, 2024
b16d5a8
Move out assert
michalk8 Oct 15, 2024
c984a43
Don't canonicalize `out_axes`
michalk8 Oct 15, 2024
4426994
Check max traces
michalk8 Oct 15, 2024
5e5125b
Test memory of batched vmap
michalk8 Oct 15, 2024
cb31db7
Install `typing_extensions`
michalk8 Oct 15, 2024
57bf9ca
Merge branch 'main' into feature/batched-vmap
michalk8 Oct 15, 2024
721eca9
Remove `.` from description
michalk8 Oct 15, 2024
f9a41bd
Add more `out_axes` tests
michalk8 Oct 15, 2024
78003d9
Add `in_axes` test
michalk8 Oct 15, 2024
9e1ae03
Fix negative axes
michalk8 Oct 15, 2024
427b5ec
Increase memory limit in the test
michalk8 Oct 16, 2024
fff0ce6
Add in_axes pytree test
michalk8 Oct 16, 2024
f72abf1
Remove old warnings filters
michalk8 Oct 16, 2024
b19ff4b
Update fixtures
michalk8 Oct 16, 2024
462c630
Update SqEucl cost.
michalk8 Oct 16, 2024
babb095
Update docstrings
michalk8 Oct 16, 2024
87df731
Remove unused imports from the docs
michalk8 Oct 16, 2024
07dff82
Revert test pre-commits
michalk8 Oct 16, 2024
5450808
Fix ICNN init notebook
michalk8 Oct 16, 2024
e390e64
Improve error message
michalk8 Oct 16, 2024
10 changes: 7 additions & 3 deletions docs/tutorials/geometry/100_grid.ipynb
@@ -241,11 +241,15 @@
"class MyCost(costs.CostFn):\n",
" \"\"\"An unusual cost function.\"\"\"\n",
"\n",
" def norm(self, x):\n",
" def norm(self, x: jnp.ndarray) -> jnp.ndarray:\n",
" return jnp.sum(x**3 + jnp.cos(x) ** 2, axis=-1)\n",
"\n",
" def pairwise(self, x, y):\n",
" return -jnp.sum(jnp.sin(x + 1) * jnp.sin(y)) * 2"
" def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n",
" return (\n",
" self.norm(x)\n",
" + self.norm(y)\n",
" - jnp.sum(jnp.sin(x + 1) * jnp.sin(y)) * 2\n",
" )"
]
},
{
1 change: 1 addition & 0 deletions docs/utils.rst
@@ -11,3 +11,4 @@ function for :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.

default_progress_fn
tqdm_progress_fn
+batched_vmap
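
For orientation, `batched_vmap` is the new public utility this PR adds under `ott.utils`. A minimal usage sketch — the exact signature (`batch_size`, plus `in_axes`/`out_axes` mirroring `jax.vmap`) is inferred from the commit history, so treat the details as assumptions:

```python
import jax.numpy as jnp

from ott import utils

x = jnp.ones((8_192, 64))

# Assumed semantics: like jax.vmap over the leading axis, but the mapped
# function is evaluated in chunks of `batch_size` rows to bound peak memory.
row_norms = utils.batched_vmap(
    lambda row: jnp.sum(row ** 2),
    batch_size=1_024,
)(x)

assert row_norms.shape == (8_192,)
```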
11 changes: 3 additions & 8 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ott-jax"
-description = "Optimal Transport Tools in JAX."
+description = "Optimal Transport Tools in JAX"
requires-python = ">=3.9"
dynamic = ["version"]
readme = {file = "README.md", content-type = "text/markdown"}
@@ -17,6 +17,7 @@ dependencies = [
"jaxopt>=0.8",
"lineax>=0.0.5",
"numpy>=1.20.0",
"typing_extensions; python_version <= '3.9'",
]
keywords = [
"optimal transport",
@@ -107,7 +108,7 @@ multi_line_output = 3
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TEST", "NUMERIC", "NEURAL", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"]
# also contains what we import in notebooks/tests
known_neural = ["flax", "optax", "diffrax", "orbax"]
-known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "ot", "torch", "torchvision", "pandas", "sklearn", "tslearn"]
+known_numeric = ["numpy", "scipy", "jax", "chex", "flax", "optax", "jaxopt", "ot", "torch", "torchvision", "pandas", "sklearn", "tslearn"]
known_test = ["_pytest", "pytest"]
known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]

@@ -120,12 +121,6 @@ markers = [
"cpu: Mark tests as CPU only.",
"fast: Mark tests as fast.",
]
-filterwarnings = [
-    "ignore:\\n*.*scipy.sparse array",
-    "ignore:jax.random.KeyArray is deprecated:DeprecationWarning",
-    "ignore:.*jax.config:DeprecationWarning",
-    "ignore:jax.core.Shape is deprecated:DeprecationWarning:chex",
-]

[tool.coverage.run]
branch = true
80 changes: 22 additions & 58 deletions src/ott/geometry/costs.py
@@ -14,7 +14,7 @@
import abc
import functools
import math
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
@@ -39,29 +39,16 @@
"SoftDTW",
]

# TODO(michalk8): norm check
Func = Callable[[jnp.ndarray], float]


@jtu.register_pytree_node_class
class CostFn(abc.ABC):
"""Base class for all costs.

Cost functions evaluate a function on a pair of inputs. For convenience,
that function is split into two norms -- evaluated on each input separately --
followed by a pairwise cost that involves both inputs, as in:

.. math::
c(x, y) = norm(x) + norm(y) + pairwise(x, y)

If the :attr:`norm` function is not implemented, that value is handled as
:math:`0`, and only :func:`pairwise` is used.
"""

# no norm function created by default.
norm: Optional[Callable[[jnp.ndarray], Union[float, jnp.ndarray]]] = None
"""Base class for all costs."""

@abc.abstractmethod
-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute cost between :math:`x` and :math:`y`.

Args:
@@ -99,22 +86,6 @@ def _padder(cls, dim: int) -> jnp.ndarray:
"""
return jnp.zeros((1, dim))

-  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
-    """Compute cost between :math:`x` and :math:`y`.
-
-    Args:
-      x: Array.
-      y: Array.
-
-    Returns:
-      The cost, optionally including the :attr:`norms <norm>` of
-      :math:`x`/:math:`y`.
-    """
-    cost = self.pairwise(x, y)
-    if self.norm is None:
-      return cost
-    return cost + self.norm(x) + self.norm(y)

def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Compute matrix of all pairwise costs, including the :attr:`norms <norm>`.

@@ -127,18 +98,6 @@ def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""
return jax.vmap(lambda x_: jax.vmap(lambda y_: self(x_, y_))(y))(x)

-  def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
-    """Compute matrix of all pairwise costs, excluding the :attr:`norms <norm>`.
-
-    Args:
-      x: Array of shape ``[n, ...]``.
-      y: Array of shape ``[m, ...]``.
-
-    Returns:
-      Array of shape ``[n, m]`` of cost evaluations.
-    """
-    return jax.vmap(lambda x_: jax.vmap(lambda y_: self.pairwise(x_, y_))(y))(x)

def twist_operator(
self, vec: jnp.ndarray, dual_vec: jnp.ndarray, variable: bool
) -> jnp.ndarray:
@@ -200,7 +159,7 @@ def h_legendre(self, z: jnp.ndarray) -> float:
"""Legendre transform of :func:`h` when it is convex."""
raise NotImplementedError("Legendre transform of `h` is not implemented.")

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute cost as evaluation of :func:`h` on :math:`x-y`."""
return self.h(x - y)
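
With this rename, a translation-invariant cost only needs to implement `h`; `__call__` is derived as `h(x - y)`. A hypothetical subclass, sketched for illustration (`L1Cost` is not part of this diff; the pytree registration mirrors the other costs in this file):

```python
import jax.numpy as jnp
import jax.tree_util as jtu

from ott.geometry import costs


@jtu.register_pytree_node_class
class L1Cost(costs.TICost):
  """Hypothetical TI cost: h(z) = ||z||_1, hence c(x, y) = ||x - y||_1."""

  def h(self, z: jnp.ndarray) -> float:
    return jnp.sum(jnp.abs(z))


cost_fn = L1Cost()
assert cost_fn(jnp.zeros(3), jnp.ones(3)) == 3.0  # __call__ -> h(x - y)
```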

@@ -539,7 +498,7 @@ class Euclidean(CostFn):
because the function is not strictly convex (it is linear on rays).
"""

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute Euclidean norm using custom jvp implementation.

Here we use a custom jvp implementation for the norm that does not yield
@@ -556,13 +515,14 @@ class SqEuclidean(TICost):
Implemented as a translation invariant cost, :math:`h(z) = \|z\|^2`.
"""

-  def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:
+  def norm(self, x: jnp.ndarray) -> jnp.ndarray:
"""Compute squared Euclidean norm for vector."""
return jnp.sum(x ** 2, axis=-1)

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Compute minus twice the dot-product between vectors."""
-    return -2.0 * jnp.vdot(x, y)
+    cross_term = -2.0 * jnp.vdot(x, y)
+    return self.norm(x) + self.norm(y) + cross_term

def h(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.sum(z ** 2)
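
Since the norms are now folded directly into `__call__`, evaluating `SqEuclidean` on a pair of points should agree with both `h` and the plain squared distance. A quick consistency check under that reading of the diff:

```python
import jax.numpy as jnp

from ott.geometry import costs

cost_fn = costs.SqEuclidean()
x, y = jnp.arange(3.0), jnp.ones(3)

# norm(x) + norm(y) - 2 * <x, y> == ||x - y||^2 == h(x - y)
assert jnp.allclose(cost_fn(x, y), jnp.sum((x - y) ** 2))
assert jnp.allclose(cost_fn(x, y), cost_fn.h(x - y))

# `all_pairs` (unchanged above) materializes the full [n, m] cost matrix.
cm = cost_fn.all_pairs(x[:, None], y[:, None])
assert cm.shape == (3, 3)
```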
Expand All @@ -588,7 +548,7 @@ def __init__(self, ridge: float = 1e-8):
super().__init__()
self._ridge = ridge

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Cosine distance between vectors, denominator regularized with ridge."""
x_norm = jnp.linalg.norm(x, axis=-1)
y_norm = jnp.linalg.norm(y, axis=-1)
@@ -624,7 +584,7 @@ def __init__(self, n: int, ridge: float = 1e-8):
self.n = n
self._ridge = ridge

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray):  # noqa: D102
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray):  # noqa: D102
x_norm = jnp.linalg.norm(x, axis=-1)
y_norm = jnp.linalg.norm(y, axis=-1)
cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + self._ridge)
@@ -688,7 +648,7 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray:
norm += jnp.trace(cov, axis1=-2, axis2=-1)
return norm

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute - 2 x Bures dot-product."""
mean_x, cov_x = x_to_means_and_covs(x, self._dimension)
mean_y, cov_y = x_to_means_and_covs(y, self._dimension)
Expand All @@ -698,7 +658,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
sq__sq_x_y_sq_x = matrix_square_root.sqrtm(
sq_x_y_sq_x, self._dimension, **self._sqrtm_kw
)[0]
-    return -2 * (mean_dot_prod + jnp.trace(sq__sq_x_y_sq_x, axis1=-2, axis2=-1))
+    cross_term = -2.0 * (
+        mean_dot_prod + jnp.trace(sq__sq_x_y_sq_x, axis1=-2, axis2=-1)
+    )
+    return self.norm(x) + self.norm(y) + cross_term

def covariance_fixpoint_iter(
self,
Expand Down Expand Up @@ -883,7 +846,7 @@ def norm(self, x: jnp.ndarray) -> jnp.ndarray:
"""
return self._gamma * x[..., 0]

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute dot-product for unbalanced Bures.

Args:
@@ -939,12 +902,13 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
log_m_pi += -0.5 * ldet_c_ab

# if all logdet signs are 1, output value, nan otherwise
-    pos_signs = (sldet_c + sldet_c_ab + sldet_t_ab + sldet_t_ab) == 4
+    pos_signs = (sldet_c + sldet_c_ab + sldet_ab + sldet_t_ab) == 4

-    return jax.lax.cond(
+    cross_term = jax.lax.cond(
        pos_signs, lambda: 2 * sig2 * mass_x * mass_y - 2 *
        (sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan
    )
+    return self.norm(x) + self.norm(y) + cross_term

def tree_flatten(self): # noqa: D102
return (), (self._dimension, self._sigma, self._gamma, self._sqrtm_kw)
@@ -977,7 +941,7 @@ def __init__(
self.ground_cost = SqEuclidean() if ground_cost is None else ground_cost
self.debiased = debiased

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:  # noqa: D102
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:  # noqa: D102
c_xy = self._soft_dtw(x, y)
if self.debiased:
return c_xy - 0.5 * (self._soft_dtw(x, x) + self._soft_dtw(y, y))
2 changes: 1 addition & 1 deletion src/ott/geometry/distrib_costs.py
@@ -51,7 +51,7 @@ def __init__(
)
self._solve_fn = solve_fn

-  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
+  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Wasserstein distance between :math:`x` and :math:`y` seen as a 1D dist.

Args:
9 changes: 5 additions & 4 deletions src/ott/geometry/geodesic.py
@@ -16,6 +16,7 @@
import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
+import jax.tree_util as jtu
import numpy as np
from scipy.special import ive

@@ -28,7 +29,7 @@
Array_g = Union[jnp.ndarray, jesp.BCOO]


-@jax.tree_util.register_pytree_node_class
+@jtu.register_pytree_node_class
class Geodesic(geometry.Geometry):
r"""Graph distance approximation using heat kernel :cite:`huguet:2023`.

@@ -134,22 +135,22 @@ def from_graph(

def apply_kernel(
self,
-      scaling: jnp.ndarray,
+      vec: jnp.ndarray,
eps: Optional[float] = None,
axis: int = 0,
) -> jnp.ndarray:
r"""Apply :attr:`kernel_matrix` on positive scaling vector.

Args:
-      scaling: Scaling to apply the kernel to.
+      vec: Scaling to apply the kernel to.
eps: passed for consistency, not used yet.
axis: passed for consistency, not used yet.

Returns:
Kernel applied to ``scaling``.
"""
return expm_multiply(
-        self.scaled_laplacian, scaling, self.chebyshev_coeffs, 0.5 * self.eigval
+        self.scaled_laplacian, vec, self.chebyshev_coeffs, 0.5 * self.eigval
)

@property
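
The renamed argument is only a signature change; usage stays the same. A small sketch — `from_graph` and its `t` parameter exist elsewhere in this class, but the exact call below is an assumption, not part of this diff:

```python
import jax.numpy as jnp

from ott.geometry import geodesic

# Dense adjacency of a 3-node path graph.
g = jnp.array([
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
])

geom = geodesic.Geodesic.from_graph(g, t=1.0)
out = geom.apply_kernel(jnp.ones(3))  # argument is now named `vec`
assert out.shape == (3,)
```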