From 458e265a1282d417f87bc0e07a5b6b77ffe9ae50 Mon Sep 17 00:00:00 2001 From: soerenab <36963673+soerenab@users.noreply.github.com> Date: Fri, 12 Apr 2024 18:50:45 +0200 Subject: [PATCH] bug fix: avoid mixing up linear and quadratic in genot (#517) * bug fix: avoid mixing up linear and quadratic part by returning Dict in genot prepare_data() * fix data_match_fn() setup in genot tests * prepare_data() in GENOT now returns a tuple instead of a dict; change order of args in utils.match_quadratic() * Update docs * Fix typo --------- Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com> --- src/ott/neural/methods/flows/genot.py | 38 +++++++++++++++++---------- tests/neural/methods/genot_test.py | 21 +++++---------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index ce200d376..e7ca5c1bc 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -29,12 +29,11 @@ __all__ = ["GENOT"] -# input: (src_lin, tgt_lin, src_quad, tgt_quad), output: (len(src), len(tgt)) -# all are optional because the problem can be linear/quadratic/fused -DataMatchFn_t = Callable[[ - Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray] -], jnp.ndarray] +LinTerm = Tuple[jnp.ndarray, jnp.ndarray] +QuadTerm = Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], + Optional[jnp.ndarray]] +DataMatchFn = Union[Callable[[LinTerm], jnp.ndarray], Callable[[QuadTerm], + jnp.ndarray]] class GENOT: @@ -49,8 +48,14 @@ class GENOT: vf: Vector field parameterized by a neural network. flow: Flow between the latent and the target distributions. data_match_fn: Function to match samples from the source and the target - distributions with a ``(src_lin, tgt_lin, src_quad, tgt_quad) -> matching`` - signature. + distributions. Depending on the data passed in :meth:`__call__`, it has + the following signature: + + - ``(src_lin, tgt_lin) -> matching`` - linear matching. + - ``(src_quad, tgt_quad, src_lin, tgt_lin) -> matching`` - + quadratic (fused) GW matching. In the pure GW setting, both ``src_lin`` + and ``tgt_lin`` will be set to :obj:`None`. + source_dim: Dimensionality of the source distribution. target_dim: Dimensionality of the target distribution. condition_dim: Dimension of the conditions. If :obj:`None`, the underlying @@ -73,7 +78,7 @@ def __init__( self, vf: velocity_field.VelocityField, flow: dynamics.BaseFlow, - data_match_fn: DataMatchFn_t, + data_match_fn: DataMatchFn, *, source_dim: int, target_dim: int, @@ -162,19 +167,24 @@ def __call__( def prepare_data( batch: Dict[str, jnp.ndarray] - ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], Tuple[ - jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]]: + ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], + Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], + Optional[jnp.ndarray]]]: src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad") tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad") - arrs = src_lin, tgt_lin, src_quad, tgt_quad if src_quad is None and tgt_quad is None: # lin src, tgt = src_lin, tgt_lin + arrs = src_lin, tgt_lin elif src_lin is None and tgt_lin is None: # quad src, tgt = src_quad, tgt_quad - elif all(arr is not None for arr in arrs): # fused quad + arrs = src_quad, tgt_quad + elif all( + arr is not None for arr in (src_lin, tgt_lin, src_quad, tgt_quad) + ): # fused quad src = jnp.concatenate([src_lin, src_quad], axis=1) tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1) + arrs = src_quad, tgt_quad, src_lin, tgt_lin else: raise RuntimeError("Cannot infer OT problem type from data.") diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index 2c746596c..d4d8a1399 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Literal, Optional +from typing import Literal import pytest @@ -28,20 +27,14 @@ from ott.solvers import utils as solver_utils -def data_match_fn( - src_lin: Optional[jnp.ndarray], tgt_lin: Optional[jnp.ndarray], - src_quad: Optional[jnp.ndarray], tgt_quad: Optional[jnp.ndarray], *, - typ: Literal["lin", "quad", "fused"] -) -> jnp.ndarray: +def get_match_fn(typ: Literal["lin", "quad", "fused"]): if typ == "lin": - return solver_utils.match_linear(x=src_lin, y=tgt_lin) + return solver_utils.match_linear if typ == "quad": - return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad) + return solver_utils.match_quadratic if typ == "fused": - return solver_utils.match_quadratic( - xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin - ) - raise NotImplementedError(f"Unknown type: {typ}.") + return solver_utils.match_quadratic + raise NotImplementedError(typ) class TestGENOT: @@ -69,7 +62,7 @@ def test_genot(self, rng: jax.Array, dl: str, request): model = genot.GENOT( vf, flow=dynamics.ConstantNoiseFlow(0.0), - data_match_fn=functools.partial(data_match_fn, typ=problem_type), + data_match_fn=get_match_fn(problem_type), source_dim=src_dim, target_dim=tgt_dim, condition_dim=cond_dim,