diff --git a/jdaviz/configs/cubeviz/plugins/slice/slice.py b/jdaviz/configs/cubeviz/plugins/slice/slice.py index 9871c61f70..a1cf014677 100644 --- a/jdaviz/configs/cubeviz/plugins/slice/slice.py +++ b/jdaviz/configs/cubeviz/plugins/slice/slice.py @@ -125,6 +125,9 @@ def _initialize_location(self, *args): if str(viewer.state.x_att) not in self.valid_slice_att_names: # avoid setting value to degs, before x_att is changed to wavelength, for example continue + # ensure the cache is reset (if previous attempts to initialize failed resulting in an + # empty list as the cache) + viewer._clear_cache('slice_values') slice_values = viewer.slice_values if len(slice_values): self.value = slice_values[int(len(slice_values)/2)] diff --git a/jdaviz/configs/cubeviz/plugins/viewers.py b/jdaviz/configs/cubeviz/plugins/viewers.py index 24ee29b41c..0d8b6c6614 100644 --- a/jdaviz/configs/cubeviz/plugins/viewers.py +++ b/jdaviz/configs/cubeviz/plugins/viewers.py @@ -47,20 +47,20 @@ def _get_component(layer): # or slice_component_label is not a component in this layer # either way, return an empty array and skip this layer return np.array([]) - data_obj = data_comp.data - data_units = data_comp.units - data_spec_axis = np.asarray(data_obj.data, dtype=float) * u.Unit(data_units) # Convert axis if display units are set and are different - if slice_display_units and slice_display_units != data_units: - return data_spec_axis.to_value(slice_display_units, - equivalencies=u.spectral()) + data_units = getattr(data_comp, 'units', None) + if slice_display_units and data_units and slice_display_units != data_units: + data = np.asarray(data_comp.data, dtype=float) * u.Unit(data_units) + return data.to_value(slice_display_units, + equivalencies=u.spectral()) else: - return data_spec_axis + return data_comp.data try: return np.asarray(np.unique(np.concatenate([_get_component(layer) for layer in self.layers])), # noqa dtype=float) except ValueError: + # NOTE: this will result in caching an empty list return np.array([]) def _set_slice_indicator_value(self, value): @@ -109,23 +109,22 @@ def slice_values(self): try: # Retrieve layer data and units using the slice index of the world components ids - data_obj = layer.layer.data.get_component(world_comp_ids[self.slice_index]).data - data_units = layer.layer.data.get_component(world_comp_ids[self.slice_index]).units + data_comp = layer.layer.data.get_component(world_comp_ids[self.slice_index]) except (AttributeError, KeyError): continue - # Find the spectral axis - data_spec_axis = np.asarray(data_obj.take(0, take_inds[0]).take(0, take_inds[1]), # noqa - dtype=float) + data = np.asarray(data_comp.data.take(0, take_inds[0]).take(0, take_inds[1]), # noqa + dtype=float) # Convert to display units if applicable - if slice_display_units and slice_display_units != data_units: - converted_axis = (data_spec_axis * u.Unit(data_units)).to_value( + data_units = getattr(data_comp, 'units', None) + if slice_display_units and data_units and slice_display_units != data_units: + converted_axis = (data * u.Unit(data_units)).to_value( slice_display_units, equivalencies=u.spectral() + u.pixel_scale(1*u.pix) ) else: - converted_axis = data_spec_axis + converted_axis = data return converted_axis diff --git a/jdaviz/configs/default/plugins/viewers.py b/jdaviz/configs/default/plugins/viewers.py index 9b1cf6aea3..42090a1022 100644 --- a/jdaviz/configs/default/plugins/viewers.py +++ b/jdaviz/configs/default/plugins/viewers.py @@ -10,6 +10,7 @@ from jdaviz.components.toolbar_nested import NestedJupyterToolbar from jdaviz.core.astrowidgets_api import AstrowidgetsImageViewerMixin from jdaviz.core.registries import viewer_registry +from jdaviz.core.template_mixin import WithCache from jdaviz.core.user_api import ViewerUserApi from jdaviz.utils import ColorCycler, get_subset_type, _wcs_only_label, layer_is_not_dq @@ -20,7 +21,7 @@ viewer_registry.add("g-table-viewer", label="Table", cls=TableViewer) -class JdavizViewerMixin: +class JdavizViewerMixin(WithCache): toolbar = None tools_nested = [] _prev_limits = None diff --git a/jdaviz/core/template_mixin.py b/jdaviz/core/template_mixin.py index c73b33bc39..75252ef549 100644 --- a/jdaviz/core/template_mixin.py +++ b/jdaviz/core/template_mixin.py @@ -69,7 +69,7 @@ __all__ = ['show_widget', 'TemplateMixin', 'PluginTemplateMixin', 'skip_if_no_updates_since_last_active', 'skip_if_not_tray_instance', 'with_spinner', 'with_temp_disable', - 'ViewerPropertiesMixin', + 'WithCache', 'ViewerPropertiesMixin', 'BasePluginComponent', 'MultiselectMixin', 'SelectPluginComponent', 'UnitSelectPluginComponent', 'EditableSelectPluginComponent', @@ -186,7 +186,19 @@ def flux_viewer(self): return self.app.get_viewer(viewer_reference) -class TemplateMixin(VuetifyTemplate, HubListener, ViewerPropertiesMixin): +class WithCache: + def _clear_cache(self, *attrs): + """ + provide convenience function to clearing the cache for cached_properties + """ + if not len(attrs): + attrs = getattr(self, '_cached_properties', []) + for attr in attrs: + if attr in self.__dict__: + del self.__dict__[attr] + + +class TemplateMixin(VuetifyTemplate, HubListener, ViewerPropertiesMixin, WithCache): config = Unicode("").tag(sync=True) vdocs = Unicode("").tag(sync=True) popout_button = Any().tag(sync=True, **widget_serialization) @@ -612,7 +624,7 @@ def show(self, loc="inline", title=None): # pragma: no cover show_widget(self, loc=loc, title=title) -class BasePluginComponent(HubListener, ViewerPropertiesMixin): +class BasePluginComponent(HubListener, ViewerPropertiesMixin, WithCache): """ This base class handles attaching traitlets from the plugin itself to logic handled within the component, support for caching and clearing caches on properties, @@ -637,16 +649,6 @@ def __setattr__(self, attr, value, force_super=False): return setattr(self._plugin, self._plugin_traitlets.get(attr), value) - def _clear_cache(self, *attrs): - """ - provide convenience function to clearing the cache for cached_properties - """ - if not len(attrs): - attrs = self._cached_properties - for attr in attrs: - if attr in self.__dict__: - del self.__dict__[attr] - def add_traitlets(self, **traitlets): for k, v in traitlets.items(): if v is None: