Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/batched vmap #588

Merged
merged 79 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
4472425
Start batched vmap
michalk8 Sep 20, 2024
80b9c60
Initial `batched_vmap` impl
michalk8 Oct 10, 2024
7c58b9a
Nicer formatting
michalk8 Oct 10, 2024
b995d15
Fix getting shape
michalk8 Oct 10, 2024
e5cfa1e
Remove private API usage
michalk8 Oct 10, 2024
f8eee7b
Fix new args
michalk8 Oct 10, 2024
51980ef
Add a TODO
michalk8 Oct 10, 2024
7a510ae
Canonicalize axes
michalk8 Oct 10, 2024
1ca05f1
Add `batched_vmap` to docs
michalk8 Oct 10, 2024
e7ad8a4
Removed batched transport functions
michalk8 Oct 10, 2024
acf221b
Remove `_norm_{x,y}` from `CostFn`
michalk8 Oct 10, 2024
07f4734
Implement `apply_lse_kernel`
michalk8 Oct 10, 2024
98f1de9
Implememt `apply_kernel`
michalk8 Oct 10, 2024
ba825f9
Implement `apply_cost`
michalk8 Oct 10, 2024
9036329
Remove old functions
michalk8 Oct 10, 2024
62089bc
Make function private
michalk8 Oct 10, 2024
b9bb64a
Refactor `apply_cost` to have consistent shapes
michalk8 Oct 10, 2024
d8b5ea6
Use `_apply_cost_to_vec` in `PointCloud`
michalk8 Oct 10, 2024
12f923c
Remoeve TODO
michalk8 Oct 10, 2024
a43dcdb
Formatting
michalk8 Oct 10, 2024
e0d75b0
Simplify `_apply_sqeucl_cost`
michalk8 Oct 10, 2024
f5445ec
Fix `RecusionError`
michalk8 Oct 10, 2024
4922ca9
Remove docstring of a private method
michalk8 Oct 10, 2024
799c108
Fix `apply_lse_kernel`
michalk8 Oct 10, 2024
9a5c1ca
Squeeze only 1 axis of the cost
michalk8 Oct 10, 2024
8543538
Add TODO
michalk8 Oct 10, 2024
317eb02
Rename function, make a property
michalk8 Oct 10, 2024
d31fd4d
Remove unused helper function
michalk8 Oct 10, 2024
4b0f150
Compute mean summary online
michalk8 Oct 10, 2024
8843937
Compute mean online
michalk8 Oct 10, 2024
83c9960
Compute max cost matrix
michalk8 Oct 10, 2024
69a9599
Update error message
michalk8 Oct 11, 2024
6667abd
Remove TODO
michalk8 Oct 11, 2024
ac9b928
Flatten out axes
michalk8 Oct 11, 2024
c113946
Fix missing cross terms in the costs
michalk8 Oct 11, 2024
75d9e7a
Fix geom tests
michalk8 Oct 11, 2024
44eb5a8
Fix dtype
michalk8 Oct 11, 2024
cbb4ea0
Start implementing transport functions
michalk8 Oct 11, 2024
e8bb1b5
Implement online transport functions
michalk8 Oct 11, 2024
7d4001e
Fix solver tests
michalk8 Oct 11, 2024
8941224
Fix Bures test
michalk8 Oct 11, 2024
a565e09
Don't use `pairwise` in tests
michalk8 Oct 11, 2024
1533324
Update notebook that uses `norm`
michalk8 Oct 11, 2024
3e7ff8b
Fix bug in `UnbalancedBures`
michalk8 Oct 11, 2024
b815fbc
Rename `pairwise -> __call__`
michalk8 Oct 11, 2024
739afde
Remove old shape code
michalk8 Oct 11, 2024
0d7f6ae
Always instantiate the cost for online
michalk8 Oct 11, 2024
4863fcf
Remove old TODO
michalk8 Oct 11, 2024
4aa4c6b
Extract `_apply_cost_to_vec_fast`
michalk8 Oct 11, 2024
8511073
Update max cost in LRCGeom
michalk8 Oct 11, 2024
47462d2
Fix test, use more `multi_dot`
michalk8 Oct 11, 2024
05630a8
Remove `batch_size` from `LRCGeometry`
michalk8 Oct 11, 2024
0994d7a
Add better warning error
michalk8 Oct 15, 2024
5d88ad4
Reorder properties
michalk8 Oct 15, 2024
f8143fc
Add docs to `batched_vmap`
michalk8 Oct 15, 2024
a82688c
Start adding tests
michalk8 Oct 15, 2024
1d2d12d
Reorder functions in test
michalk8 Oct 15, 2024
44b1126
Fix axes, add a test
michalk8 Oct 15, 2024
889f81f
Update test fn
michalk8 Oct 15, 2024
b16d5a8
Move out assert
michalk8 Oct 15, 2024
c984a43
Dont canon out_axes
michalk8 Oct 15, 2024
4426994
Check max traces
michalk8 Oct 15, 2024
5e5125b
Test memory of batched vmap
michalk8 Oct 15, 2024
cb31db7
Install `typing_extensions`
michalk8 Oct 15, 2024
57bf9ca
Merge branch 'main' into feature/batched-vmap
michalk8 Oct 15, 2024
721eca9
Remove `.` from description
michalk8 Oct 15, 2024
f9a41bd
Add more `out_axes` tests
michalk8 Oct 15, 2024
78003d9
Add `in_axes` test
michalk8 Oct 15, 2024
9e1ae03
Fix negative axes
michalk8 Oct 15, 2024
427b5ec
Increase memory limit in the test
michalk8 Oct 16, 2024
fff0ce6
Add in_axes pytree test
michalk8 Oct 16, 2024
f72abf1
Remove old warnings filters
michalk8 Oct 16, 2024
b19ff4b
Update fixtures
michalk8 Oct 16, 2024
462c630
Update SqEucl cost.
michalk8 Oct 16, 2024
babb095
Update docstrings
michalk8 Oct 16, 2024
87df731
Remove unused imports from the docs
michalk8 Oct 16, 2024
07dff82
Revert test pre-commits
michalk8 Oct 16, 2024
5450808
Fix ICNN init notebook
michalk8 Oct 16, 2024
e390e64
Improve error message
michalk8 Oct 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions docs/tutorials/barycenter/000_Sinkhorn_Barycenters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
"metadata": {},
"outputs": [],
"source": [
"import nilearn\n",
"from nilearn import datasets, image, plotting\n",
"from nilearn.image import get_data\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand All @@ -38,7 +36,7 @@
"id": "02VJX2uXYHDX"
},
"source": [
"## Import neuroimaging data using `nilearn`.\n",
"## Import neuroimaging data using `nilearn`\n",
"\n",
"We recover a few MRI data points..."
]
Expand Down Expand Up @@ -157,7 +155,7 @@
}
],
"source": [
"a = jnp.array(get_data(gm_imgs)).transpose((3, 0, 1, 2))\n",
"a = jnp.array(image.get_data(gm_imgs)).transpose((3, 0, 1, 2))\n",
"grid_size = a.shape[1:4]\n",
"a = a.reshape((n_subjects, -1)) + 1e-2\n",
"a = a / np.sum(a, axis=1)[:, np.newaxis]\n",
Expand Down Expand Up @@ -320,9 +318,7 @@
],
"source": [
"def data_to_nii(x):\n",
" return nilearn.image.new_img_like(\n",
" gm_imgs[0], data=np.array(x.reshape(grid_size))\n",
" )\n",
" return image.new_img_like(gm_imgs[0], data=np.array(x.reshape(grid_size)))\n",
"\n",
"\n",
"plotting.plot_epi(data_to_nii(barycenter.histogram))\n",
Expand Down
1 change: 0 additions & 1 deletion docs/tutorials/barycenter/200_gmm_pair_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
" fit_gmm_pair,\n",
" gaussian_mixture,\n",
" gaussian_mixture_pair,\n",
" probabilities,\n",
")"
]
},
Expand Down
10 changes: 7 additions & 3 deletions docs/tutorials/geometry/100_grid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,15 @@
"class MyCost(costs.CostFn):\n",
" \"\"\"An unusual cost function.\"\"\"\n",
"\n",
" def norm(self, x):\n",
" def norm(self, x: jnp.ndarray) -> jnp.ndarray:\n",
" return jnp.sum(x**3 + jnp.cos(x) ** 2, axis=-1)\n",
"\n",
" def pairwise(self, x, y):\n",
" return -jnp.sum(jnp.sin(x + 1) * jnp.sin(y)) * 2"
" def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n",
" return (\n",
" self.norm(x)\n",
" + self.norm(y)\n",
" - jnp.sum(jnp.sin(x + 1) * jnp.sin(y)) * 2\n",
" )"
]
},
{
Expand Down
9 changes: 3 additions & 6 deletions docs/tutorials/linear/000_One_Sinkhorn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@
},
"outputs": [],
"source": [
"import functools\n",
"import time\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import ott\n",
"from ott import problems\n",
"from ott.geometry import geometry, pointcloud\n",
"from ott.solvers import linear\n",
"from ott.solvers.linear import acceleration, sinkhorn"
"from ott.solvers.linear import acceleration"
]
},
{
Expand Down Expand Up @@ -1087,7 +1084,7 @@
" plot_results(\n",
" out_scaling[i],\n",
" leg_scaling[i],\n",
" title=rf\"Decay, $\\varepsilon$=\" + str(epsilon),\n",
" title=r\"Decay, $\\varepsilon$=\" + str(epsilon),\n",
" xlabel=\"iterations\",\n",
" ylabel=\"error\",\n",
" )"
Expand Down Expand Up @@ -1276,7 +1273,7 @@
" plot_results(\n",
" out_mixed[i],\n",
" leg_mixed[i],\n",
" title=rf\"Mixed strategy, $\\varepsilon$=\" + str(epsilon),\n",
" title=r\"Mixed strategy, $\\varepsilon$=\" + str(epsilon),\n",
" xlabel=\"iterations\",\n",
" ylabel=\"error\",\n",
" )"
Expand Down
8 changes: 3 additions & 5 deletions docs/tutorials/linear/100_OTT_&_POT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,19 @@
},
"outputs": [],
"source": [
"import timeit\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import ot\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc(\"font\", size=20)\n",
"import mpl_toolkits.axes_grid1\n",
"\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.linear import linear_problem\n",
"from ott.solvers.linear import sinkhorn"
"from ott.solvers.linear import sinkhorn\n",
"\n",
"plt.rc(\"font\", size=20)"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Callable, Tuple\n",
"from typing import Any, Callable\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down
3 changes: 1 addition & 2 deletions docs/tutorials/linear/600_mmsink.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from typing import List, Optional, Tuple\n",
"from typing import Optional\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down
5 changes: 1 addition & 4 deletions docs/tutorials/misc/000_tracking_progress.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,12 @@
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from ott import utils\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.linear import linear_problem\n",
"from ott.problems.quadratic import quadratic_problem\n",
"from ott.solvers import linear, quadratic\n",
"from ott.solvers import linear\n",
"from ott.solvers.linear import sinkhorn, sinkhorn_lr\n",
"from ott.solvers.quadratic import gromov_wasserstein"
]
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/misc/200_application_biology.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@
"for day in DAYS:\n",
" tmp = cell_distribution_ipsc[day].copy()\n",
" tmp[tmp >= 1e-2] = alpha_bins[0]\n",
" tmp[np.logical_and(1e-2 > tmp, tmp >= 5e-4)] = alpha_bins[1]\n",
" tmp[5e-4 > tmp] = alpha_bins[2]\n",
" tmp[np.logical_and(tmp < 1e-2, tmp >= 5e-4)] = alpha_bins[1]\n",
" tmp[tmp < 5e-4] = alpha_bins[2]\n",
" binned_cell_distribution_ipsc[day] = tmp"
]
},
Expand Down
2 changes: 0 additions & 2 deletions docs/tutorials/neural/000_neural_dual.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from torch.utils.data import DataLoader, IterableDataset\n",
"\n",
"import optax\n",
"\n",
Expand Down
71 changes: 35 additions & 36 deletions docs/tutorials/neural/100_icnn_inits.ipynb

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions docs/tutorials/neural/200_Monge_Gap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"import dataclasses\n",
"from collections.abc import Iterator, Mapping\n",
"from types import MappingProxyType\n",
"from typing import Any, Dict, Literal, Optional, Tuple, Union\n",
"from typing import Any, Literal, Optional\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand All @@ -31,7 +31,6 @@
"from ott.geometry import costs, pointcloud\n",
"from ott.neural.methods import monge_gap\n",
"from ott.neural.networks import potentials\n",
"from ott.solvers.linear import acceleration\n",
"from ott.tools import sinkhorn_divergence"
]
},
Expand Down Expand Up @@ -587,7 +586,7 @@
}
],
"source": [
"plot_fit_map(rf\"with $\\ell_2$ Monge gap\", out_nn_l2, logs_l2)"
"plot_fit_map(r\"with $\\ell_2$ Monge gap\", out_nn_l2, logs_l2)"
]
},
{
Expand Down Expand Up @@ -659,7 +658,7 @@
}
],
"source": [
"plot_fit_map(rf\"with $\\ell_2^2$ Monge gap\", out_nn_l22, logs_l22)"
"plot_fit_map(r\"with $\\ell_2^2$ Monge gap\", out_nn_l22, logs_l22)"
]
}
],
Expand Down
3 changes: 1 addition & 2 deletions docs/tutorials/quadratic/000_gromov_wasserstein.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import mpl_toolkits.mplot3d.axes3d as p3\n",
"from IPython import display\n",
"from matplotlib import animation, cm\n",
"from matplotlib import animation\n",
"\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.quadratic import quadratic_problem\n",
Expand Down
1 change: 1 addition & 0 deletions docs/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ function for :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.

default_progress_fn
tqdm_progress_fn
batched_vmap
11 changes: 3 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ott-jax"
description = "Optimal Transport Tools in JAX."
description = "Optimal Transport Tools in JAX"
requires-python = ">=3.9"
dynamic = ["version"]
readme = {file = "README.md", content-type = "text/markdown"}
Expand All @@ -17,6 +17,7 @@ dependencies = [
"jaxopt>=0.8",
"lineax>=0.0.5",
"numpy>=1.20.0",
"typing_extensions; python_version <= '3.9'",
]
keywords = [
"optimal transport",
Expand Down Expand Up @@ -107,7 +108,7 @@ multi_line_output = 3
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TEST", "NUMERIC", "NEURAL", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"]
# also contains what we import in notebooks/tests
known_neural = ["flax", "optax", "diffrax", "orbax"]
known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "ot", "torch", "torchvision", "pandas", "sklearn", "tslearn"]
known_numeric = ["numpy", "scipy", "jax", "chex", "flax", "optax", "jaxopt", "ot", "torch", "torchvision", "pandas", "sklearn", "tslearn"]
known_test = ["_pytest", "pytest"]
known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]

Expand All @@ -120,12 +121,6 @@ markers = [
"cpu: Mark tests as CPU only.",
"fast: Mark tests as fast.",
]
filterwarnings = [
"ignore:\\n*.*scipy.sparse array",
"ignore:jax.random.KeyArray is deprecated:DeprecationWarning",
"ignore:.*jax.config:DeprecationWarning",
"ignore:jax.core.Shape is deprecated:DeprecationWarning:chex",
]

[tool.coverage.run]
branch = true
Expand Down
Loading