-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Start batched vmap * Initial `batched_vmap` impl * Nicer formatting * Fix getting shape * Remove private API usage * Fix new args * Add a TODO * Canonicalize axes * Add `batched_vmap` to docs * Removed batched transport functions * Remove `_norm_{x,y}` from `CostFn` * Implement `apply_lse_kernel` * Implememt `apply_kernel` * Implement `apply_cost` * Remove old functions * Make function private * Refactor `apply_cost` to have consistent shapes * Use `_apply_cost_to_vec` in `PointCloud` * Remoeve TODO * Formatting * Simplify `_apply_sqeucl_cost` * Fix `RecusionError` * Remove docstring of a private method * Fix `apply_lse_kernel` * Squeeze only 1 axis of the cost * Add TODO * Rename function, make a property * Remove unused helper function * Compute mean summary online * Compute mean online * Compute max cost matrix * Update error message * Remove TODO * Flatten out axes * Fix missing cross terms in the costs * Fix geom tests * Fix dtype * Start implementing transport functions * Implement online transport functions * Fix solver tests * Fix Bures test * Don't use `pairwise` in tests * Update notebook that uses `norm` * Fix bug in `UnbalancedBures` * Rename `pairwise -> __call__` * Remove old shape code * Always instantiate the cost for online * Remove old TODO * Extract `_apply_cost_to_vec_fast` * Update max cost in LRCGeom * Fix test, use more `multi_dot` * Remove `batch_size` from `LRCGeometry` * Add better warning error * Reorder properties * Add docs to `batched_vmap` * Start adding tests * Reorder functions in test * Fix axes, add a test * Update test fn * Move out assert * Dont canon out_axes * Check max traces * Test memory of batched vmap * Install `typing_extensions` * Remove `.` from description * Add more `out_axes` tests * Add `in_axes` test * Fix negative axes * Increase memory limit in the test * Add in_axes pytree test * Remove old warnings filters * Update fixtures * Update SqEucl cost. * Update docstrings * Remove unused imports from the docs * Revert test pre-commits * Fix ICNN init notebook Was broken by #551 * Improve error message
- Loading branch information
Showing
40 changed files
with
784 additions
and
774 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,4 @@ function for :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. | |
|
||
default_progress_fn | ||
tqdm_progress_fn | ||
batched_vmap |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.