Meta OT Initializer (#145)
* add meta ot notebook

* tidy up notebook, add plot

* add the Meta OT nb to the advanced applications

* Pass on Meta OT initializers

+ Move to ott/core/initializers
+ Plot MNIST images and interpolations in nb
+ Add test and docs

* init_key -> rng

* call into ent_reg_cost for the dual objective

* MetaOTInitializer -> FixedGeometryMetaOTInitializer

* update nb to latest interface

* address comments from @michalk8

* address comments from @michalk8

* phrasing

Co-authored-by: "JTT94" <"[email protected]">
bamos and JTT94 authored Oct 12, 2022
1 parent 8a8e406 commit 9d76c04
Showing 6 changed files with 1,049 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/core.rst
@@ -42,6 +42,8 @@ Sinkhorn Dual Initializers
initializers.DefaultInitializer
initializers.GaussianInitializer
initializers.SortingInitializer
initializers.MetaInitializer
initializers.MetaMLP

Low-Rank Sinkhorn
-----------------
1 change: 1 addition & 0 deletions docs/index.rst
@@ -64,6 +64,7 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin
notebooks/icnn_inits.ipynb
notebooks/wasserstein_barycenters_gmms.ipynb
notebooks/gmm_pair_demo.ipynb
notebooks/MetaOT.ipynb

.. toctree::
:maxdepth: 1
775 changes: 775 additions & 0 deletions docs/notebooks/MetaOT.ipynb

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions docs/references.bib
@@ -588,3 +588,10 @@ @misc{pooladian:21
  year = {2021},
  copyright = {arXiv.org perpetual, non-exclusive license}
}

@article{amos:22,
  title = {Meta Optimal Transport},
  author = {Amos, Brandon and Cohen, Samuel and Luise, Giulia and Redko, Ievgen},
  journal = {arXiv preprint arXiv:2206.05262},
  year = {2022}
}
223 changes: 218 additions & 5 deletions ott/core/initializers.py
@@ -12,16 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sinkhorn initializers."""
import functools
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

from ott.core import linear_problems
from ott.geometry import pointcloud
from ott.core import linear_problems, sinkhorn
from ott.geometry import geometry, pointcloud

__all__ = ["DefaultInitializer", "GaussianInitializer", "SortingInitializer"]
__all__ = [
    "DefaultInitializer", "GaussianInitializer", "SortingInitializer",
    "MetaInitializer"
]


@jax.tree_util.register_pytree_node_class
@@ -51,9 +58,9 @@ def __call__(
Args:
ot_prob: Linear OT problem.
a: Initial potential/scaling f_u. If `None`, it will be initialized using
a: Initial potential/scaling f_u. If ``None``, it will be initialized using
:meth:`init_dual_a`.
b: Initial potential/scaling g_v. If `None`, it will be initialized using
b: Initial potential/scaling g_v. If ``None``, it will be initialized using
:meth:`init_dual_b`.
lse_mode: Return potentials if true, scalings otherwise.
@@ -278,6 +285,212 @@ def init_dual_a(
return f_u


@jax.tree_util.register_pytree_node_class
class MetaInitializer(DefaultInitializer):
"""Meta OT Initializer with a fixed geometry :cite:`amos:22`.
This initializer consists of a predictive model that outputs the
:math:`f` duals to solve the entropy-regularized OT problem given
input probability weights ``a`` and ``b``, and a given (assumed to be
fixed) geometry ``geom``.
The model's parameters are learned using a training set of OT
instances (multiple pairs of probability weights), that assume the
**same** geometry ``geom`` is used throughout, both for training and
evaluation. The meta model defaults to the MLP in
:class:`~ott.core.initializers.MetaMLP` and, with batched problem
instances passed into :meth:`update`.
**Sample training usage.** The following code shows a simple
example of using ``update`` to train the model, where
``a`` and ``b`` are the weights of the measures and
``geom`` is the fixed geometry.
.. code-block:: python
meta_initializer = init_lib.MetaInitializer(geom=geom)
while training():
a, b = sample_batch()
loss, init_f, meta_initializer.state = meta_initializer.update(
meta_initializer.state, a=a, b=b)
Args:
geom: The fixed geometry of the problem instances.
meta_model: The model to predict the potential :math:`f` from the measures.
opt: The optimizer to update the parameters.
rng: The PRNG key to use for initializing the model.
state: The training state of the model to start from.
"""

  def __init__(
      self,
      geom: geometry.Geometry,
      meta_model: Optional[nn.Module] = None,
      opt: optax.GradientTransformation = optax.adam(learning_rate=1e-3),
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
      state: Optional[train_state.TrainState] = None
  ):
    self.geom = geom
    self.dtype = geom.x.dtype
    self.opt = opt
    self.rng = rng

    na, nb = geom.shape
    self.meta_model = MetaMLP(
        potential_size=na
    ) if meta_model is None else meta_model

    if state is None:
      # Initialize the model's training state.
      a_placeholder = jnp.zeros(na, dtype=self.dtype)
      b_placeholder = jnp.zeros(nb, dtype=self.dtype)
      params = self.meta_model.init(rng, a_placeholder, b_placeholder)['params']
      self.state = train_state.TrainState.create(
          apply_fn=self.meta_model.apply, params=params, tx=opt
      )
    else:
      self.state = state

    self.update_impl = self._get_update_fn()

  def update(
      self, state: train_state.TrainState, a: jnp.ndarray, b: jnp.ndarray
  ) -> Tuple[jnp.ndarray, jnp.ndarray, train_state.TrainState]:
    r"""Update the meta model with the dual objective.

    The goal is for the model to match the optimal duals, i.e.,
    :math:`\hat f_\theta \approx f^\star`.
    This can be done by training the predictions of :math:`\hat f_\theta`
    to optimize the dual objective, which :math:`f^\star` also optimizes for.
    The overall learning setup can thus be written as:

    .. math::
      \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\;
        J(\hat f_\theta(a, b); \alpha, \beta),

    where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta`,
    :math:`\mathcal{D}` is a meta distribution of optimal transport problems,

    .. math::
      -J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle -
        \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\}\right\rangle

    is the entropic dual objective,
    and :math:`K_{i,j} := \exp(-C_{i,j}/\varepsilon)` is the *Gibbs kernel*.

    Args:
      state: Optimizer state of the meta model.
      a: Probabilities of the :math:`\alpha` measure's atoms.
      b: Probabilities of the :math:`\beta` measure's atoms.

    Returns:
      The training loss, :math:`f`, and updated state.
    """
    return self.update_impl(state, a, b)

  def init_dual_a(
      self, ot_prob: linear_problems.LinearProblem, lse_mode: bool
  ) -> jnp.ndarray:
    # Detect if the problem is batched.
    assert ot_prob.a.ndim in (1, 2) and ot_prob.b.ndim in (1, 2)
    vmap_a_val = 0 if ot_prob.a.ndim == 2 else None
    vmap_b_val = 0 if ot_prob.b.ndim == 2 else None

    if vmap_a_val is not None or vmap_b_val is not None:
      compute_f_maybe_batch = jax.vmap(
          self._compute_f, in_axes=(vmap_a_val, vmap_b_val, None)
      )
    else:
      compute_f_maybe_batch = self._compute_f

    init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params)
    f_u = init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f)
    return f_u

  def _get_update_fn(self):
    """Return the implementation (and jitted) update function."""

    def dual_obj_loss_single(params, a, b):
      f_pred = self._compute_f(a, b, params)
      g_pred = self.geom.update_potential(
          f_pred, jnp.zeros_like(b), jnp.log(b), 0, axis=0
      )
      g_pred = jnp.where(jnp.isfinite(g_pred), g_pred, 0.)

      ot_prob = linear_problems.LinearProblem(geom=self.geom, a=a, b=b)
      dual_obj = sinkhorn.ent_reg_cost(f_pred, g_pred, ot_prob, lse_mode=True)
      loss = -dual_obj
      return loss, f_pred

    def loss_batch(params, a, b):
      loss_fn = functools.partial(dual_obj_loss_single, params=params)
      loss, f_pred = jax.vmap(loss_fn)(a=a, b=b)
      return jnp.mean(loss), f_pred

    @jax.jit
    def update(state, a, b):
      a = jnp.atleast_2d(a)
      b = jnp.atleast_2d(b)
      grad_fn = jax.value_and_grad(loss_batch, has_aux=True)
      (loss, init_f), grads = grad_fn(state.params, a, b)
      return loss, init_f, state.apply_gradients(grads=grads)

    return update

  def _compute_f(self, a, b, params):
    r"""Predict the optimal :math:`f` potential.

    Args:
      a: Probabilities of the :math:`\alpha` measure's atoms.
      b: Probabilities of the :math:`\beta` measure's atoms.
      params: The parameters of the Meta model.

    Returns:
      The :math:`f` potential.
    """
    return self.meta_model.apply({'params': params}, a, b)

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    return [self.geom, self.meta_model, self.opt], {
        'rng': self.rng,
        'state': self.state
    }


class MetaMLP(nn.Module):
r"""A Meta MLP potential for :class:`~ott.core.initializers.MetaInitializer`.
This provides an MLP :math:`\hat f_\theta(a, b)` that maps from the probabilities
of the measures to the optimal dual potentials :math:`f`.
Args:
potential_size: The dimensionality of :math:`f`.
num_hidden_units: The number of hidden units in each layer.
num_hidden_layers: The number of hidden layers.
"""

potential_size: int
num_hidden_units: int = 512
num_hidden_layers: int = 3

@nn.compact
def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
r"""Make a prediction.
Args:
a: Probabilites of the :math:`\alpha` measure's atoms.
b: Probabilites of the :math:`\beta` measure's atoms.
Returns:
The :math:`f` potential.
"""
dtype = a.dtype
z = jnp.concatenate((a, b))
for _ in range(self.num_hidden_layers):
z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z))
f = nn.Dense(self.potential_size, dtype=dtype)(z)
return f


def _vectorized_update(
    f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
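
A minimal end-to-end sketch of the new initializer, following the class docstring above and the test added below. The point-cloud geometry, the problem sizes, and the ``sample_batch`` helper are illustrative assumptions, not part of the library:

    import jax
    import jax.numpy as jnp

    from ott.core import initializers as init_lib
    from ott.core import sinkhorn
    from ott.geometry import pointcloud

    n, m, d = 64, 64, 2
    key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)

    # Fixed geometry shared by every problem instance.
    x = jax.random.normal(key_x, (n, d))
    y = jax.random.normal(key_y, (m, d)) + 1.0
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)


    def sample_batch(key, batch_size=8):
      """Sample random probability weights (a, b) on the fixed geometry."""
      key_a, key_b = jax.random.split(key)
      a = jax.random.uniform(key_a, (batch_size, n))
      b = jax.random.uniform(key_b, (batch_size, m))
      return a / a.sum(-1, keepdims=True), b / b.sum(-1, keepdims=True)


    # Train the meta model on batches of problem instances.
    meta_initializer = init_lib.MetaInitializer(geom=geom)
    for _ in range(100):
      key_w, subkey = jax.random.split(key_w)
      a, b = sample_batch(subkey)
      loss, init_f, meta_initializer.state = meta_initializer.update(
          meta_initializer.state, a=a, b=b
      )

    # Warm-start Sinkhorn on a fresh problem over the same geometry.
    a_eval, b_eval = sample_batch(jax.random.PRNGKey(1), batch_size=1)
    out = sinkhorn.sinkhorn(
        geom, a=a_eval[0], b=b_eval[0],
        initializer=meta_initializer, lse_mode=True
    )
    print(out.reg_ot_cost)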
46 changes: 46 additions & 0 deletions tests/core/initializers_test.py
@@ -290,6 +290,52 @@ def test_gauss_initializer(self, lse_mode, rng: jnp.ndarray):
    if lse_mode:
      assert base_num_iter >= gaus_num_iter

  @pytest.mark.parametrize('lse_mode', [True, False])
  def test_meta_initializer(self, lse_mode, rng: jnp.ndarray):
    """Tests Meta initializer."""
    # define OT problem
    n = 200
    m = 200
    d = 2
    epsilon = 0.01

    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
    a = ot_problem.a
    b = ot_problem.b
    geom = ot_problem.geom

    # run sinkhorn
    sink_out = run_sinkhorn(
        x=ot_problem.geom.x,
        y=ot_problem.geom.y,
        a=ot_problem.a,
        b=ot_problem.b,
        epsilon=epsilon,
        lse_mode=lse_mode
    )
    base_num_iter = jnp.sum(sink_out.errors > -1)

    # Overfit the initializer to the problem.
    meta_initializer = init_lib.MetaInitializer(geom)
    for _ in range(100):
      _, _, meta_initializer.state = meta_initializer.update(
          meta_initializer.state, a=a, b=b
      )

    sink_out = sinkhorn.sinkhorn(
        geom,
        a=a,
        b=b,
        jit=True,
        initializer=meta_initializer,
        lse_mode=lse_mode
    )
    meta_num_iter = jnp.sum(sink_out.errors > -1)

    # check initializer is better
    if lse_mode:
      assert base_num_iter >= meta_num_iter


class TestLRInitializers:

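
The constructor also accepts a custom ``flax.linen`` module in place of the default ``MetaMLP``, as long as it maps ``(a, b)`` to an ``f`` potential of size ``geom.shape[0]``. A hypothetical sketch of such a drop-in (the ``TinyMetaModel`` name, layer sizes, and random point clouds are made up for illustration):

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    from ott.core import initializers as init_lib
    from ott.geometry import pointcloud


    class TinyMetaModel(nn.Module):
      """A smaller drop-in for MetaMLP: maps (a, b) to an f potential."""
      potential_size: int
      num_hidden_units: int = 64

      @nn.compact
      def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
        z = jnp.concatenate((a, b))
        z = nn.relu(nn.Dense(self.num_hidden_units)(z))
        return nn.Dense(self.potential_size)(z)


    # A fixed geometry; the point clouds here are random placeholders.
    x = jax.random.normal(jax.random.PRNGKey(0), (32, 2))
    y = jax.random.normal(jax.random.PRNGKey(1), (32, 2))
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)

    meta_initializer = init_lib.MetaInitializer(
        geom=geom, meta_model=TinyMetaModel(potential_size=geom.shape[0])
    )
    # Training then proceeds exactly as with the default model, via
    # meta_initializer.update(meta_initializer.state, a=a, b=b).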

0 comments on commit 9d76c04
