Skip to content

Commit

Permalink
Renamed some tests to be picked up by pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
daurer committed Jan 24, 2025
1 parent 021709f commit ca35ab4
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def check_engine_output(self, output, plotting=False, debug=False):
RMSE_ob = (np.mean(np.abs(OBJ_ML_serial - OBJ_ML)**2))
RMSE_pr = (np.mean(np.abs(PRB_ML_serial - PRB_ML)**2))
# RMSE_LL = (np.mean(np.abs(LL_ML_serial - LL_ML)**2))
np.testing.assert_allclose(RMSE_ob, 0.0, atol=1e-2,
np.testing.assert_allclose(RMSE_ob, 0.0, atol=1e-2,
err_msg="The object arrays are not matching as expected")
np.testing.assert_allclose(RMSE_pr, 0.0, atol=1e-2,
np.testing.assert_allclose(RMSE_pr, 0.0, atol=1e-2,
err_msg="The object arrays are not matching as expected")
# np.testing.assert_allclose(RMSE_LL, 0.0, atol=1e-7,
# err_msg="The log-likelihood errors are not matching as expected")


def test_ML_serial_base(self):
out = []
Expand Down Expand Up @@ -150,6 +150,25 @@ def test_ML_serial_smoothing_regularizer(self):
scanmodel="BlockFull", autosave=False, verbose_level="critical"))
self.check_engine_output(out, plotting=False, debug=False)

def test_ML_serial_wavefield_preconditioner(self):
out = []
for eng in ["ML", "ML_serial"]:
engine_params = u.Param()
engine_params.name = eng
engine_params.numiter = 100
engine_params.floating_intensities = False
engine_params.reg_del2 = False
engine_params.reg_del2_amplitude = 1.
engine_params.smooth_gradient = 0.
engine_params.smooth_gradient_decay = 0.
engine_params.scale_precond = False
engine_params.wavefield_precond = True
engine_params.wavefield_delta_object = 0.1
engine_params.wavefield_delta_probe = 0.1
out.append(tu.EngineTestRunner(engine_params, output_path=self.outpath, init_correct_probe=True,
scanmodel="BlockFull", autosave=False, verbose_level="critical"))
self.check_engine_output(out, plotting=False, debug=False)

def test_ML_serial_all(self):
out = []
for eng in ["ML", "ML_serial"]:
Expand Down

0 comments on commit ca35ab4

Please sign in to comment.