Skip to content

Commit

Permalink
TST: Adapt Wasserstein test to new interface after refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
jwboth committed Nov 3, 2023
1 parent 1db309b commit 50f989b
Showing 1 changed file with 15 additions and 47 deletions.
62 changes: 15 additions & 47 deletions tests/unit/test_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,21 @@
# Linearization
newton_options = {
# Scheme
"L": 1e-9,
"L": 1e9,
}
bregman_std_options = {
# Scheme
"L": 1,
}
bregman_reordered_options = {
# Scheme
"L": 1,
"bregman_mode": "reordered",
}
bregman_adaptive_options = {
# Scheme
"L": 1,
"bregman_mode": "adaptive",
"bregman_update_cond": lambda iter: iter % 20 == 0,
"bregman_update": lambda iter: iter % 20 == 0,
}
linearizations = {
"newton": [newton_options],
"bregman": [
bregman_std_options,
bregman_reordered_options,
bregman_adaptive_options,
],
}
Expand All @@ -117,27 +110,24 @@
# Linear solver
lu_options = {
# Linear solver
"linear_solver": "lu"
"linear_solver": "direct",
}
amg_options = {
"linear_solver": "amg-potential",
"linear_solver_tol": 1e-8,
"linear_solver": "amg",
"linear_solver_options": {
"tol": 1e-8,
},
}
solvers = [lu_options, amg_options]

# General options
options = {
# Solver parameters
"regularization": 1e-16,
# Scheme
"lumping": True,
# Performance control
"num_iter": 400,
"tol_residual": 1e-10,
"tol_increment": 1e-6,
"tol_distance": 1e-10,
# Output
"verbose": False,
"return_info": True,
}

# ! ---- Tests ----
Expand All @@ -151,15 +141,14 @@ def test_newton(a_key, s_key, dim):
options.update(newton_options)
options.update(accelerations[a_key])
options.update(solvers[s_key])
distance, _, _, _, status = darsia.wasserstein_distance(
distance, info = darsia.wasserstein_distance(
src_image[dim],
dst_image[dim],
options=options,
method="newton",
return_solution=True,
)
assert np.isclose(distance, true_distance[dim], atol=1e-5)
assert status["converged"]
assert info["converged"]


@pytest.mark.parametrize("a_key", range(len(accelerations)))
Expand All @@ -170,34 +159,14 @@ def test_std_bregman(a_key, s_key, dim):
options.update(bregman_std_options)
options.update(accelerations[a_key])
options.update(solvers[s_key])
distance, _, _, _, status = darsia.wasserstein_distance(
src_image[dim],
dst_image[dim],
options=options,
method="bregman",
return_solution=True,
)
assert np.isclose(distance, true_distance[dim], atol=1e-2) # TODO
assert status["converged"]


@pytest.mark.parametrize("a_key", range(len(accelerations)))
@pytest.mark.parametrize("s_key", range(len(solvers)))
@pytest.mark.parametrize("dim", [2, 3])
def test_reordered_bregman(a_key, s_key, dim):
"""Test all combinations for reordered Bregman."""
options.update(bregman_reordered_options)
options.update(accelerations[a_key])
options.update(solvers[s_key])
distance, _, _, _, status = darsia.wasserstein_distance(
distance, info = darsia.wasserstein_distance(
src_image[dim],
dst_image[dim],
options=options,
method="bregman",
return_solution=True,
)
assert np.isclose(distance, true_distance[dim], atol=1e-2) # TODO
assert status["converged"]
assert np.isclose(distance, true_distance[dim], atol=1e-2)
assert info["converged"]


@pytest.mark.parametrize("a_key", range(len(accelerations)))
Expand All @@ -208,12 +177,11 @@ def test_adaptive_bregman(a_key, s_key, dim):
options.update(bregman_adaptive_options)
options.update(accelerations[a_key])
options.update(solvers[s_key])
distance, _, _, _, status = darsia.wasserstein_distance(
distance, info = darsia.wasserstein_distance(
src_image[dim],
dst_image[dim],
options=options,
method="bregman",
return_solution=True,
)
assert np.isclose(distance, true_distance[dim], atol=1e-5)
assert status["converged"]
assert info["converged"]

0 comments on commit 50f989b

Please sign in to comment.