diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py index e3f6302d78..722bc685c2 100644 --- a/Wrappers/Python/test/test_algorithms.py +++ b/Wrappers/Python/test/test_algorithms.py @@ -245,6 +245,27 @@ def test_gd_armijo_rosen(self): gd.solution.array[0], self.scipy_opt_high.x[0], atol=1e-2) np.testing.assert_allclose( gd.solution.array[1], self.scipy_opt_high.x[1], atol=1e-2) + + def test_gd_run_no_iterations(self): + gd = GD(initial=self.initial, objective_function=self.f, step_size=0.002) + with self.assertRaises(ValueError): + gd.run() + + def test_gd_run_infinite(self): + gd = GD(initial=self.initial, objective_function=self.f, step_size=0.002) + with self.assertRaises(ValueError): + gd.run(np.inf) + + class StopCallback(callbacks.Callback): + def __init__(self): + self.count = 0 + def __call__(self, algorithm): + self.count += 1 + if self.count == 10: + raise StopIteration + with self.assertWarns(UserWarning): + gd.run(np.inf, callbacks=[StopCallback()]) + self.assertEqual(gd.iteration, 9) class TestFISTA(CCPiTestClass):