Skip to content

Commit

Permalink
Merge pull request spacetelescope#2687 from pllim/test-fitting-cleanup
Browse files Browse the repository at this point in the history
TST: Minor clean-up of test_fitting
  • Loading branch information
pllim authored Feb 2, 2024
2 parents 4319b64 + 0bf8d2a commit b89d91f
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from jdaviz.configs.default.plugins.model_fitting import fitting_backend as fb
from jdaviz.configs.default.plugins.model_fitting import initializers
from jdaviz.configs.default.plugins.model_fitting.model_fitting import ModelFitting

SPECTRUM_SIZE = 200 # length of spectrum

Expand Down Expand Up @@ -56,14 +55,13 @@ def test_model_params():
assert np.all([p in expected_params for p in params])


@pytest.mark.filterwarnings('ignore')
def test_model_ids(cubeviz_helper, spectral_cube_wcs):
cubeviz_helper.load_data(Spectrum1D(flux=np.ones((3, 4, 5)) * u.nJy, wcs=spectral_cube_wcs),
data_label='test')
plugin = ModelFitting(app=cubeviz_helper.app)
plugin = cubeviz_helper.plugins["Model Fitting"]._obj
plugin.dataset_selected = 'test[FLUX]'
plugin.component_models = [{'id': 'valid_string_already_exists'}]
plugin.comp_selected = 'Linear1D'
plugin.model_comp_selected = 'Linear1D'

with pytest.raises(
ValueError,
Expand All @@ -78,7 +76,6 @@ def test_model_ids(cubeviz_helper, spectral_cube_wcs):
plugin.vue_add_model({})


@pytest.mark.filterwarnings(r"ignore:Model is linear in parameters.*")
def test_parameter_retrieval(cubeviz_helper, spectral_cube_wcs):
flux = np.ones((3, 4, 5))
flux[2, 2, :] = [1, 2, 3, 4, 5]
Expand All @@ -87,7 +84,9 @@ def test_parameter_retrieval(cubeviz_helper, spectral_cube_wcs):
plugin = cubeviz_helper.plugins["Model Fitting"]
plugin.create_model_component("Linear1D", "L")
plugin.cube_fit = True
plugin.calculate_fit()
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
plugin.calculate_fit()

params = cubeviz_helper.get_model_parameters()
slope_res = np.zeros((4, 3))
Expand Down Expand Up @@ -202,7 +201,7 @@ def test_cube_fitting_backend(cubeviz_helper, unc, tmp_path):

# Fit to all spaxels.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=r"The fit may be unsuccessful.*")
warnings.filterwarnings("ignore", message="The fit may be unsuccessful.*")
fitted_parameters, fitted_spectrum = fb.fit_model_to_spectrum(
spectrum, model_list, expression, n_cpu=n_cpu)

Expand Down Expand Up @@ -290,8 +289,6 @@ def test_cube_fitting_backend(cubeviz_helper, unc, tmp_path):
assert_array_equal(flux_mask.data, mask)


@pytest.mark.filterwarnings(r"ignore:Model is linear in parameters.*")
@pytest.mark.filterwarnings(r"ignore:The fit may be unsuccessful.*")
def test_results_table(specviz_helper, spectrum1d):
data_label = 'test'
specviz_helper.load_data(spectrum1d, data_label=data_label)
Expand All @@ -300,7 +297,9 @@ def test_results_table(specviz_helper, spectrum1d):
mf.create_model_component('Linear1D')

mf.add_results.label = 'linear model'
mf.calculate_fit(add_data=True)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
mf.calculate_fit(add_data=True)
mf_table = mf.export_table()
assert len(mf_table) == 1
assert mf_table['equation'][-1] == 'L'
Expand All @@ -312,7 +311,10 @@ def test_results_table(specviz_helper, spectrum1d):

mf.create_model_component('Gaussian1D')
mf.add_results.label = 'composite model'
mf.calculate_fit(add_data=True)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
warnings.filterwarnings('ignore', message='The fit may be unsuccessful.*')
mf.calculate_fit(add_data=True)
mf_table = mf.export_table()
assert len(mf_table) == 2
assert mf_table['equation'][-1] == 'L+G'
Expand Down Expand Up @@ -356,8 +358,6 @@ def test_equation_validation(specviz_helper, spectrum1d):
assert mf._obj.model_equation_invalid_msg == 'model equation is required.'


@pytest.mark.filterwarnings(r"ignore:Model is linear in parameters.*")
@pytest.mark.filterwarnings(r"ignore:The fit may be unsuccessful.*")
def test_incompatible_units(specviz_helper, spectrum1d):
data_label = 'test'
specviz_helper.load_data(spectrum1d, data_label=data_label)
Expand All @@ -366,17 +366,21 @@ def test_incompatible_units(specviz_helper, spectrum1d):
mf.create_model_component('Linear1D')

mf.add_results.label = 'model native units'
mf.calculate_fit(add_data=True)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
mf.calculate_fit(add_data=True)

uc = specviz_helper.plugins['Unit Conversion']
assert uc.spectral_unit.selected == "Angstrom"
uc.spectral_unit = u.Hz

assert 'L is currently disabled' in mf._obj.model_equation_invalid_msg
mf.add_results.label = 'frequency units'
with pytest.raises(ValueError, match=r"model equation is invalid.*"):
with pytest.raises(ValueError, match="model equation is invalid.*"):
mf.calculate_fit(add_data=True)

mf.reestimate_model_parameters()
assert mf._obj.model_equation_invalid_msg == ''
mf.calculate_fit(add_data=True)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
mf.calculate_fit(add_data=True)

0 comments on commit b89d91f

Please sign in to comment.