Feature/batched vmap (#588)
* Start batched vmap

* Initial `batched_vmap` impl

* Nicer formatting

* Fix getting shape

* Remove private API usage

* Fix new args

* Add a TODO

* Canonicalize axes

* Add `batched_vmap` to docs

* Removed batched transport functions

* Remove `_norm_{x,y}` from `CostFn`

* Implement `apply_lse_kernel`

* Implement `apply_kernel`

* Implement `apply_cost`

* Remove old functions

* Make function private

* Refactor `apply_cost` to have consistent shapes

* Use `_apply_cost_to_vec` in `PointCloud`

* Remove TODO

* Formatting

* Simplify `_apply_sqeucl_cost`

* Fix `RecursionError`

* Remove docstring of a private method

* Fix `apply_lse_kernel`

* Squeeze only 1 axis of the cost

* Add TODO

* Rename function, make a property

* Remove unused helper function

* Compute mean summary online

* Compute mean online

* Compute max cost matrix

* Update error message

* Remove TODO

* Flatten out axes

* Fix missing cross terms in the costs

* Fix geom tests

* Fix dtype

* Start implementing transport functions

* Implement online transport functions

* Fix solver tests

* Fix Bures test

* Don't use `pairwise` in tests

* Update notebook that uses `norm`

* Fix bug in `UnbalancedBures`

* Rename `pairwise -> __call__`

* Remove old shape code

* Always instantiate the cost for online

* Remove old TODO

* Extract `_apply_cost_to_vec_fast`

* Update max cost in LRCGeom

* Fix test, use more `multi_dot`

* Remove `batch_size` from `LRCGeometry`

* Add a better warning message

* Reorder properties

* Add docs to `batched_vmap`

* Start adding tests

* Reorder functions in test

* Fix axes, add a test

* Update test fn

* Move out assert

* Don't canonicalize `out_axes`

* Check max traces

* Test memory of batched vmap

* Install `typing_extensions`

* Remove `.` from description

* Add more `out_axes` tests

* Add `in_axes` test

* Fix negative axes

* Increase memory limit in the test

* Add in_axes pytree test

* Remove old warnings filters

* Update fixtures

* Update `SqEucl` cost

* Update docstrings

* Remove unused imports from the docs

* Revert test pre-commits

* Fix ICNN init notebook

Was broken by #551

* Improve error message
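Several of the commits above ("Canonicalize axes", "Fix negative axes") are about axis bookkeeping: `in_axes`/`out_axes` may be negative and must be mapped to non-negative positions before batching. A sketch of the standard idiom, not the PR's exact helper:

```python
def canonicalize_axis(axis: int, ndim: int) -> int:
    """Map a possibly negative axis to the range [0, ndim)."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for an array of dimension {ndim}")
    return axis % ndim


assert canonicalize_axis(-1, 3) == 2  # last axis of a 3D array
assert canonicalize_axis(1, 3) == 1
```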
michalk8 authored Oct 16, 2024
1 parent 706cef7 commit c9d3a49
Showing 40 changed files with 784 additions and 774 deletions.
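The centerpiece of the PR is `ott.utils.batched_vmap`, a memory-bounded substitute for `jax.vmap`: the mapped axis is processed in fixed-size chunks, so only `batch_size` rows of any intermediate exist at a time (see the `docs/utils.rst` entry below). A minimal usage sketch — the `jax.vmap`-style `in_axes`/`out_axes` and the `batch_size` keyword are inferred from the commit history, so treat the exact signature as an assumption:

```python
import jax.numpy as jnp

from ott import utils

x = jnp.ones((1_000, 3))
y = jnp.ones((1_000, 3))


def pairwise_sq_dists(xi: jnp.ndarray) -> jnp.ndarray:
    # cost of one point against all of y; a plain jax.vmap over this
    # would materialize the (1000, 1000, 3) broadcast difference at once
    return jnp.sum((xi[None, :] - y) ** 2, axis=-1)


batched_fn = utils.batched_vmap(pairwise_sq_dists, batch_size=100, in_axes=0)
cost_matrix = batched_fn(x)  # full (1000, 1000) result, computed 100 rows at a time
assert cost_matrix.shape == (1_000, 1_000)
```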
10 changes: 3 additions & 7 deletions docs/tutorials/barycenter/000_Sinkhorn_Barycenters.ipynb
@@ -17,9 +17,7 @@
"metadata": {},
"outputs": [],
"source": [
"import nilearn\n",
"from nilearn import datasets, image, plotting\n",
"from nilearn.image import get_data\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
@@ -38,7 +36,7 @@
"id": "02VJX2uXYHDX"
},
"source": [
"## Import neuroimaging data using `nilearn`.\n",
"## Import neuroimaging data using `nilearn`\n",
"\n",
"We recover a few MRI data points..."
]
@@ -157,7 +155,7 @@
}
],
"source": [
"a = jnp.array(get_data(gm_imgs)).transpose((3, 0, 1, 2))\n",
"a = jnp.array(image.get_data(gm_imgs)).transpose((3, 0, 1, 2))\n",
"grid_size = a.shape[1:4]\n",
"a = a.reshape((n_subjects, -1)) + 1e-2\n",
"a = a / np.sum(a, axis=1)[:, np.newaxis]\n",
@@ -320,9 +318,7 @@
],
"source": [
"def data_to_nii(x):\n",
" return nilearn.image.new_img_like(\n",
" gm_imgs[0], data=np.array(x.reshape(grid_size))\n",
" )\n",
" return image.new_img_like(gm_imgs[0], data=np.array(x.reshape(grid_size)))\n",
"\n",
"\n",
"plotting.plot_epi(data_to_nii(barycenter.histogram))\n",
1 change: 0 additions & 1 deletion docs/tutorials/barycenter/200_gmm_pair_demo.ipynb
@@ -69,7 +69,6 @@
" fit_gmm_pair,\n",
" gaussian_mixture,\n",
" gaussian_mixture_pair,\n",
" probabilities,\n",
")"
]
},
10 changes: 7 additions & 3 deletions docs/tutorials/geometry/100_grid.ipynb
@@ -241,11 +241,15 @@
"class MyCost(costs.CostFn):\n",
" \"\"\"An unusual cost function.\"\"\"\n",
"\n",
" def norm(self, x):\n",
" def norm(self, x: jnp.ndarray) -> jnp.ndarray:\n",
" return jnp.sum(x**3 + jnp.cos(x) ** 2, axis=-1)\n",
"\n",
" def pairwise(self, x, y):\n",
" return -jnp.sum(jnp.sin(x + 1) * jnp.sin(y)) * 2"
" def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n",
" return (\n",
" self.norm(x)\n",
" + self.norm(y)\n",
" - jnp.sum(jnp.sin(x + 1) * jnp.sin(y)) * 2\n",
" )"
]
},
{
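The hunk above reflects two changes made throughout this PR: `CostFn.pairwise` is renamed to `CostFn.__call__`, and the norm terms are folded into the cost itself rather than added separately by the geometry. A hedged usage sketch with the tutorial's cost — `pointcloud.PointCloud` and `linear.solve` follow the library's documented pattern, but the data here is illustrative:

```python
import jax.numpy as jnp

from ott.geometry import costs, pointcloud
from ott.solvers import linear


class MyCost(costs.CostFn):
    """The tutorial's cost, written against the new `__call__` convention."""

    def norm(self, x: jnp.ndarray) -> jnp.ndarray:
        return jnp.sum(x**3 + jnp.cos(x) ** 2, axis=-1)

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        # the norm terms are now part of the returned cost
        return self.norm(x) + self.norm(y) - 2.0 * jnp.sum(jnp.sin(x + 1) * jnp.sin(y))


x = jnp.linspace(0.0, 1.0, 20).reshape(10, 2)
y = jnp.linspace(1.0, 2.0, 20).reshape(10, 2)
geom = pointcloud.PointCloud(x, y, cost_fn=MyCost())
out = linear.solve(geom)  # Sinkhorn with the custom cost
print(out.reg_ot_cost)
```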
9 changes: 3 additions & 6 deletions docs/tutorials/linear/000_One_Sinkhorn.ipynb
@@ -19,19 +19,16 @@
},
"outputs": [],
"source": [
"import functools\n",
"import time\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import ott\n",
"from ott import problems\n",
"from ott.geometry import geometry, pointcloud\n",
"from ott.solvers import linear\n",
"from ott.solvers.linear import acceleration, sinkhorn"
"from ott.solvers.linear import acceleration"
]
},
{
@@ -1087,7 +1084,7 @@
" plot_results(\n",
" out_scaling[i],\n",
" leg_scaling[i],\n",
" title=rf\"Decay, $\\varepsilon$=\" + str(epsilon),\n",
" title=r\"Decay, $\\varepsilon$=\" + str(epsilon),\n",
" xlabel=\"iterations\",\n",
" ylabel=\"error\",\n",
" )"
@@ -1276,7 +1273,7 @@
" plot_results(\n",
" out_mixed[i],\n",
" leg_mixed[i],\n",
" title=rf\"Mixed strategy, $\\varepsilon$=\" + str(epsilon),\n",
" title=r\"Mixed strategy, $\\varepsilon$=\" + str(epsilon),\n",
" xlabel=\"iterations\",\n",
" ylabel=\"error\",\n",
" )"
8 changes: 3 additions & 5 deletions docs/tutorials/linear/100_OTT_&_POT.ipynb
@@ -23,21 +23,19 @@
},
"outputs": [],
"source": [
"import timeit\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import ot\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc(\"font\", size=20)\n",
"import mpl_toolkits.axes_grid1\n",
"\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.linear import linear_problem\n",
"from ott.solvers.linear import sinkhorn"
"from ott.solvers.linear import sinkhorn\n",
"\n",
"plt.rc(\"font\", size=20)"
]
},
{
@@ -29,7 +29,7 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Callable, Tuple\n",
"from typing import Any, Callable\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
3 changes: 1 addition & 2 deletions docs/tutorials/linear/600_mmsink.ipynb
@@ -15,8 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from typing import List, Optional, Tuple\n",
"from typing import Optional\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
5 changes: 1 addition & 4 deletions docs/tutorials/misc/000_tracking_progress.ipynb
@@ -29,15 +29,12 @@
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from ott import utils\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.linear import linear_problem\n",
"from ott.problems.quadratic import quadratic_problem\n",
"from ott.solvers import linear, quadratic\n",
"from ott.solvers import linear\n",
"from ott.solvers.linear import sinkhorn, sinkhorn_lr\n",
"from ott.solvers.quadratic import gromov_wasserstein"
]
4 changes: 2 additions & 2 deletions docs/tutorials/misc/200_application_biology.ipynb
@@ -262,8 +262,8 @@
"for day in DAYS:\n",
" tmp = cell_distribution_ipsc[day].copy()\n",
" tmp[tmp >= 1e-2] = alpha_bins[0]\n",
" tmp[np.logical_and(1e-2 > tmp, tmp >= 5e-4)] = alpha_bins[1]\n",
" tmp[5e-4 > tmp] = alpha_bins[2]\n",
" tmp[np.logical_and(tmp < 1e-2, tmp >= 5e-4)] = alpha_bins[1]\n",
" tmp[tmp < 5e-4] = alpha_bins[2]\n",
" binned_cell_distribution_ipsc[day] = tmp"
]
},
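The rewritten comparisons above read left to right (`tmp < 1e-2` rather than `1e-2 > tmp`). The same three-way binning can also be expressed in a single `np.select`, which avoids a pitfall of chained in-place masking: every condition is evaluated on the original array, so an assigned bin value can never be re-captured by a later mask. A sketch with illustrative values (the real `alpha_bins` come from the notebook):

```python
import numpy as np

alpha_bins = np.array([1.0, 0.5, 0.1])  # assumption: placeholder bin values
tmp = np.array([2e-2, 1e-3, 1e-5])

binned = np.select(
    [tmp >= 1e-2, tmp >= 5e-4],  # checked in order; first match wins
    [alpha_bins[0], alpha_bins[1]],
    default=alpha_bins[2],
)
# binned == [1.0, 0.5, 0.1]
```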
2 changes: 0 additions & 2 deletions docs/tutorials/neural/000_neural_dual.ipynb
@@ -25,8 +25,6 @@
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from torch.utils.data import DataLoader, IterableDataset\n",
"\n",
"import optax\n",
"\n",
71 changes: 35 additions & 36 deletions docs/tutorials/neural/100_icnn_inits.ipynb

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions docs/tutorials/neural/200_Monge_Gap.ipynb
@@ -16,7 +16,7 @@
"import dataclasses\n",
"from collections.abc import Iterator, Mapping\n",
"from types import MappingProxyType\n",
"from typing import Any, Dict, Literal, Optional, Tuple, Union\n",
"from typing import Any, Literal, Optional\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
@@ -31,7 +31,6 @@
"from ott.geometry import costs, pointcloud\n",
"from ott.neural.methods import monge_gap\n",
"from ott.neural.networks import potentials\n",
"from ott.solvers.linear import acceleration\n",
"from ott.tools import sinkhorn_divergence"
]
},
@@ -587,7 +586,7 @@
}
],
"source": [
"plot_fit_map(rf\"with $\\ell_2$ Monge gap\", out_nn_l2, logs_l2)"
"plot_fit_map(r\"with $\\ell_2$ Monge gap\", out_nn_l2, logs_l2)"
]
},
{
@@ -659,7 +658,7 @@
}
],
"source": [
"plot_fit_map(rf\"with $\\ell_2^2$ Monge gap\", out_nn_l22, logs_l22)"
"plot_fit_map(r\"with $\\ell_2^2$ Monge gap\", out_nn_l22, logs_l22)"
]
}
],
3 changes: 1 addition & 2 deletions docs/tutorials/quadratic/000_gromov_wasserstein.ipynb
@@ -24,9 +24,8 @@
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import mpl_toolkits.mplot3d.axes3d as p3\n",
"from IPython import display\n",
"from matplotlib import animation, cm\n",
"from matplotlib import animation\n",
"\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.quadratic import quadratic_problem\n",
1 change: 1 addition & 0 deletions docs/utils.rst
@@ -11,3 +11,4 @@ function for :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.

default_progress_fn
tqdm_progress_fn
batched_vmap
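Several commits above ("Compute mean online", "Compute max cost matrix") replace materialized cost-matrix summaries with streaming ones, which is exactly what `batched_vmap` enables. A hedged sketch of the pattern — the geometry classes use their own internal helpers, this only illustrates the idea:

```python
import jax.numpy as jnp

from ott import utils

x = jnp.ones((5_000, 2))
y = jnp.ones((4_000, 2))


def row_stats(xi: jnp.ndarray):
    # one row of the squared-Euclidean cost; at most batch_size rows
    # of the full (5000, 4000) matrix are ever alive in memory
    c = jnp.sum((xi[None, :] - y) ** 2, axis=-1)
    return c.mean(), c.max()


means, maxs = utils.batched_vmap(row_stats, batch_size=100)(x)
mean_cost = means.mean()  # rows have equal length, so this is the global mean
max_cost = maxs.max()
```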
11 changes: 3 additions & 8 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ott-jax"
description = "Optimal Transport Tools in JAX."
description = "Optimal Transport Tools in JAX"
requires-python = ">=3.9"
dynamic = ["version"]
readme = {file = "README.md", content-type = "text/markdown"}
@@ -17,6 +17,7 @@ dependencies = [
"jaxopt>=0.8",
"lineax>=0.0.5",
"numpy>=1.20.0",
"typing_extensions; python_version <= '3.9'",
]
keywords = [
"optimal transport",
@@ -107,7 +108,7 @@ multi_line_output = 3
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TEST", "NUMERIC", "NEURAL", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"]
# also contains what we import in notebooks/tests
known_neural = ["flax", "optax", "diffrax", "orbax"]
known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "ot", "torch", "torchvision", "pandas", "sklearn", "tslearn"]
known_numeric = ["numpy", "scipy", "jax", "chex", "flax", "optax", "jaxopt", "ot", "torch", "torchvision", "pandas", "sklearn", "tslearn"]
known_test = ["_pytest", "pytest"]
known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]

@@ -120,12 +121,6 @@ markers = [
"cpu: Mark tests as CPU only.",
"fast: Mark tests as fast.",
]
filterwarnings = [
"ignore:\\n*.*scipy.sparse array",
"ignore:jax.random.KeyArray is deprecated:DeprecationWarning",
"ignore:.*jax.config:DeprecationWarning",
"ignore:jax.core.Shape is deprecated:DeprecationWarning:chex",
]

[tool.coverage.run]
branch = true