Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
francisco-dlp committed Oct 9, 2023
1 parent 5c3f352 commit 5b06c0a
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 27 deletions.
2 changes: 1 addition & 1 deletion exspy/exspy/test/models/test_eelsmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,4 +605,4 @@ def test_model_store_restore(self):
m = self.m
m.store()
mc = m.signal.models.a.restore()
assert np.array_equal(m(), mc())
assert np.array_equal(m._get_current_data(), mc._get_current_data())
2 changes: 1 addition & 1 deletion hyperspy/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ def _component2plot(self, axes_manager, out_of_range2nans=True):
old_axes_manager = self.model.axes_manager
self.model.axes_manager = axes_manager
self.fetch_stored_values()
s = self.model.__call__(component_list=[self])
s = self.model._get_current_data(component_list=[self])
if not self.active:
s.fill(np.nan)
if old_axes_manager is not None:
Expand Down
2 changes: 1 addition & 1 deletion hyperspy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,7 @@ def _get_variance(self, only_current=True):

def _calculate_chisq(self):
variance = self._get_variance()
d = self(onlyactive=True, binned=self._binned).ravel() - self.signal._get_current_data(as_numpy=True)[
d = self._get_current_data(onlyactive=True, binned=self._binned).ravel() - self.signal._get_current_data(as_numpy=True)[
np.where(self._channel_switches)]
d *= d / (1. * variance) # d = difference^2 / variance.
self.chisq.data[self.signal.axes_manager.indices[::-1]] = d.sum()
Expand Down
8 changes: 4 additions & 4 deletions hyperspy/tests/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def setup_method(self, method):
self.c = Component(["one", "two"])
c = self.c
c.model = mock.MagicMock()
c.model.__call__ = mock.MagicMock()
c.model._get_current_data = mock.MagicMock()
c.model._channel_switches = np.array([True, False, True])
c.model.axis.axis = np.array([0.1, 0.2, 0.3])
c.function = mock.MagicMock()
Expand All @@ -249,23 +249,23 @@ def test_plotting_active_component_notbinned(self):
c = self.c
c.active = True
c.model.signal.axes_manager[-1].is_binned = False
c.model.__call__.return_value = np.array([1.3])
c.model._get_current_data.return_value = np.array([1.3])
res = c._component2plot(c.model.axes_manager, out_of_range2nans=False)
np.testing.assert_array_equal(res, np.array([1.3, ]))

def test_plotting_active_component_binned(self):
c = self.c
c.active = True
c.model.signal.axes_manager[-1].is_binned = True
c.model.__call__.return_value = np.array([1.3])
c.model._get_current_data.return_value = np.array([1.3])
res = c._component2plot(c.model.axes_manager, out_of_range2nans=False)
np.testing.assert_array_equal(res, np.array([1.3, ]))

def test_plotting_active_component_out_of_range(self):
c = self.c
c.active = True
c.model.signal.axes_manager[-1].is_binned = False
c.model.__call__.return_value = np.array([1.1, 1.3])
c.model._get_current_data.return_value = np.array([1.1, 1.3])
res = c._component2plot(c.model.axes_manager, out_of_range2nans=True)
np.testing.assert_array_equal(res, np.array([1.1, np.nan, 1.3]))

Expand Down
2 changes: 1 addition & 1 deletion hyperspy/tests/model/test_fit_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_fit_multiple_component(self):
m.fit_component(g2, signal_range=(1500, 2200))
m.fit_component(g3, signal_range=(5800, 6150))
np.testing.assert_allclose(self.model.signal.data,
m(),
m._get_current_data(),
rtol=self.rtol,
atol=10e-3)

Expand Down
4 changes: 2 additions & 2 deletions hyperspy/tests/model/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def test_without_twins(self):
np.testing.assert_allclose(gs[0].A.value, 20)
np.testing.assert_allclose(gs[1].A.value, -10)
np.testing.assert_allclose(gs[2].A.value, 5)
np.testing.assert_allclose(s.data, m())
np.testing.assert_allclose(s.data, m._get_current_data())

def test_with_twins(self):
gs = self.gs
Expand All @@ -536,7 +536,7 @@ def test_with_twins(self):
np.testing.assert_allclose(gs[0].A.value, 20)
np.testing.assert_allclose(gs[1].A.value, -10)
np.testing.assert_allclose(gs[2].A.value, 5)
np.testing.assert_allclose(s.data, m())
np.testing.assert_allclose(s.data, m._get_current_data())


def test_compute_constant_term():
Expand Down
28 changes: 14 additions & 14 deletions hyperspy/tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def test_call_method_no_convolutions(self):
m.convolved = False

m[1].active = False
r1 = m()
r2 = m(onlyactive=True)
r1 = m._get_current_data()
r2 = m._get_current_data(onlyactive=True)
np.testing.assert_allclose(m[0].function(0) * 2, r1)
np.testing.assert_allclose(m[0].function(0), r2)

m.convolved = True
r1 = m(non_convolved=True)
r2 = m(non_convolved=True, onlyactive=True)
r1 = m._get_current_data(non_convolved=True)
r2 = m._get_current_data(non_convolved=True, onlyactive=True)
np.testing.assert_allclose(m[0].function(0) * 2, r1)
np.testing.assert_allclose(m[0].function(0), r2)

Expand All @@ -124,8 +124,8 @@ def test_call_method_with_convolutions(self):
m[2].convolved = False
m.convolution_axis = np.array([0.0])

r1 = m()
r2 = m(onlyactive=True)
r1 = m._get_current_data()
r2 = m._get_current_data(onlyactive=True)
np.testing.assert_allclose(m[0].function(0) * 2.3, r1)
np.testing.assert_allclose(m[0].function(0) * 1.3, r2)

Expand All @@ -135,16 +135,16 @@ def test_call_method_binned(self):
m.remove(1)
m.signal.axes_manager[-1].is_binned = True
m.signal.axes_manager[-1].scale = 0.3
r1 = m()
r1 = m._get_current_data()
np.testing.assert_allclose(m[0].function(0) * 0.3, r1)


class TestModelPlotCall:
def setup_method(self, method):
s = hs.signals.Signal1D(np.empty(1))
m = s.create_model()
m.__call__ = mock.MagicMock()
m.__call__.return_value = np.array([0.5, 0.25])
m._get_current_data = mock.MagicMock()
m._get_current_data.return_value = np.array([0.5, 0.25])
m.axis = mock.MagicMock()
m.fetch_stored_values = mock.MagicMock()
m._channel_switches = np.array([0, 1, 1, 0, 0], dtype=bool)
Expand All @@ -157,16 +157,16 @@ def test_model2plot_own_am(self):
np.testing.assert_array_equal(
res, np.array([np.nan, 0.5, 0.25, np.nan, np.nan])
)
assert m.__call__.called
assert m.__call__.call_args[1] == {"non_convolved": False, "onlyactive": True}
assert m._get_current_data.called
assert m._get_current_data.call_args[1] == {"non_convolved": False, "onlyactive": True}
assert not m.fetch_stored_values.called

def test_model2plot_other_am(self):
m = self.model
res = m._model2plot(m.axes_manager.deepcopy(), out_of_range2nans=False)
np.testing.assert_array_equal(res, np.array([0.5, 0.25]))
assert m.__call__.called
assert m.__call__.call_args[1] == {"non_convolved": False, "onlyactive": True}
assert m._get_current_data.called
assert m._get_current_data.call_args[1] == {"non_convolved": False, "onlyactive": True}
assert 2 == m.fetch_stored_values.call_count


Expand Down Expand Up @@ -627,7 +627,7 @@ def test_binned_uniform(self, binned, uniform):
m.signal.axes_manager[-1].scale = 0.3
if uniform:
m.signal.axes_manager[-1].convert_to_non_uniform_axis()
np.testing.assert_allclose(m[0].function(0) * 0.3, m())
np.testing.assert_allclose(m[0].function(0) * 0.3, m._get_current_data())
self.m.print_current_values()


Expand Down
2 changes: 1 addition & 1 deletion hyperspy/tests/model/test_model2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_fit_no_odr_error(self):

def test_call(self):
with pytest.raises(ValueError):
self.m(component_list=0)
self.m._get_current_data(component_list=0)


def test_Model2D_NotImplementedError_fitting():
Expand Down
4 changes: 2 additions & 2 deletions hyperspy/tests/signals/test_binned.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def setup_method(self, method):

def test_unbinned(self):
self.m.signal.axes_manager[-1].is_binned = False
assert self.m() == 1
assert self.m._get_current_data() == 1

def test_binned(self):
self.m.signal.axes_manager[-1].is_binned = True
assert self.m() == 0.1
assert self.m._get_current_data() == 0.1

0 comments on commit 5b06c0a

Please sign in to comment.