From 50f989b0ecd7fbb5afdeae5eca4e1f3f28c399f8 Mon Sep 17 00:00:00 2001 From: Jakub Both Date: Fri, 3 Nov 2023 23:54:46 +0100 Subject: [PATCH] TST: Adapt Wasserstein test to new interface after refactor. --- tests/unit/test_wasserstein.py | 62 ++++++++-------------------------- 1 file changed, 15 insertions(+), 47 deletions(-) diff --git a/tests/unit/test_wasserstein.py b/tests/unit/test_wasserstein.py index 5df63246..54aba03b 100644 --- a/tests/unit/test_wasserstein.py +++ b/tests/unit/test_wasserstein.py @@ -75,28 +75,21 @@ # Linearization newton_options = { # Scheme - "L": 1e-9, + "L": 1e9, } bregman_std_options = { # Scheme "L": 1, } -bregman_reordered_options = { - # Scheme - "L": 1, - "bregman_mode": "reordered", -} bregman_adaptive_options = { # Scheme "L": 1, - "bregman_mode": "adaptive", - "bregman_update_cond": lambda iter: iter % 20 == 0, + "bregman_update": lambda iter: iter % 20 == 0, } linearizations = { "newton": [newton_options], "bregman": [ bregman_std_options, - bregman_reordered_options, bregman_adaptive_options, ], } @@ -117,27 +110,24 @@ # Linear solver lu_options = { # Linear solver - "linear_solver": "lu" + "linear_solver": "direct", } amg_options = { - "linear_solver": "amg-potential", - "linear_solver_tol": 1e-8, + "linear_solver": "amg", + "linear_solver_options": { + "tol": 1e-8, + }, } solvers = [lu_options, amg_options] # General options options = { - # Solver parameters - "regularization": 1e-16, - # Scheme - "lumping": True, # Performance control "num_iter": 400, "tol_residual": 1e-10, "tol_increment": 1e-6, "tol_distance": 1e-10, - # Output - "verbose": False, + "return_info": True, } # ! ---- Tests ---- @@ -151,15 +141,14 @@ def test_newton(a_key, s_key, dim): options.update(newton_options) options.update(accelerations[a_key]) options.update(solvers[s_key]) - distance, _, _, _, status = darsia.wasserstein_distance( + distance, info = darsia.wasserstein_distance( src_image[dim], dst_image[dim], options=options, method="newton", - return_solution=True, ) assert np.isclose(distance, true_distance[dim], atol=1e-5) - assert status["converged"] + assert info["converged"] @pytest.mark.parametrize("a_key", range(len(accelerations))) @@ -170,34 +159,14 @@ def test_std_bregman(a_key, s_key, dim): options.update(bregman_std_options) options.update(accelerations[a_key]) options.update(solvers[s_key]) - distance, _, _, _, status = darsia.wasserstein_distance( - src_image[dim], - dst_image[dim], - options=options, - method="bregman", - return_solution=True, - ) - assert np.isclose(distance, true_distance[dim], atol=1e-2) # TODO - assert status["converged"] - - -@pytest.mark.parametrize("a_key", range(len(accelerations))) -@pytest.mark.parametrize("s_key", range(len(solvers))) -@pytest.mark.parametrize("dim", [2, 3]) -def test_reordered_bregman(a_key, s_key, dim): - """Test all combinations for reordered Bregman.""" - options.update(bregman_reordered_options) - options.update(accelerations[a_key]) - options.update(solvers[s_key]) - distance, _, _, _, status = darsia.wasserstein_distance( + distance, info = darsia.wasserstein_distance( src_image[dim], dst_image[dim], options=options, method="bregman", - return_solution=True, ) - assert np.isclose(distance, true_distance[dim], atol=1e-2) # TODO - assert status["converged"] + assert np.isclose(distance, true_distance[dim], atol=1e-2) + assert info["converged"] @pytest.mark.parametrize("a_key", range(len(accelerations))) @@ -208,12 +177,11 @@ def test_adaptive_bregman(a_key, s_key, dim): options.update(bregman_adaptive_options) options.update(accelerations[a_key]) options.update(solvers[s_key]) - distance, _, _, _, status = darsia.wasserstein_distance( + distance, info = darsia.wasserstein_distance( src_image[dim], dst_image[dim], options=options, method="bregman", - return_solution=True, ) assert np.isclose(distance, true_distance[dim], atol=1e-5) - assert status["converged"] + assert info["converged"]