Skip to content

Commit

Permalink
Merge pull request spacetelescope#2895 from kecnry/viewer-slices-gene…
Browse files Browse the repository at this point in the history
…ralize

Generalize viewer slices logic
  • Loading branch information
bmorris3 authored May 29, 2024
2 parents 609f16e + 904e269 commit 50ebc1e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 29 deletions.
3 changes: 3 additions & 0 deletions jdaviz/configs/cubeviz/plugins/slice/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def _initialize_location(self, *args):
if str(viewer.state.x_att) not in self.valid_slice_att_names:
# avoid setting value to degs, before x_att is changed to wavelength, for example
continue
# ensure the cache is reset (if previous attempts to initialize failed resulting in an
# empty list as the cache)
viewer._clear_cache('slice_values')
slice_values = viewer.slice_values
if len(slice_values):
self.value = slice_values[int(len(slice_values)/2)]
Expand Down
29 changes: 14 additions & 15 deletions jdaviz/configs/cubeviz/plugins/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ def _get_component(layer):
# or slice_component_label is not a component in this layer
# either way, return an empty array and skip this layer
return np.array([])
data_obj = data_comp.data
data_units = data_comp.units
data_spec_axis = np.asarray(data_obj.data, dtype=float) * u.Unit(data_units)

# Convert axis if display units are set and are different
if slice_display_units and slice_display_units != data_units:
return data_spec_axis.to_value(slice_display_units,
equivalencies=u.spectral())
data_units = getattr(data_comp, 'units', None)
if slice_display_units and data_units and slice_display_units != data_units:
data = np.asarray(data_comp.data, dtype=float) * u.Unit(data_units)
return data.to_value(slice_display_units,
equivalencies=u.spectral())
else:
return data_spec_axis
return data_comp.data
try:
return np.asarray(np.unique(np.concatenate([_get_component(layer) for layer in self.layers])), # noqa
dtype=float)
except ValueError:
# NOTE: this will result in caching an empty list
return np.array([])

def _set_slice_indicator_value(self, value):
Expand Down Expand Up @@ -109,23 +109,22 @@ def slice_values(self):

try:
# Retrieve layer data and units using the slice index of the world components ids
data_obj = layer.layer.data.get_component(world_comp_ids[self.slice_index]).data
data_units = layer.layer.data.get_component(world_comp_ids[self.slice_index]).units
data_comp = layer.layer.data.get_component(world_comp_ids[self.slice_index])
except (AttributeError, KeyError):
continue

# Find the spectral axis
data_spec_axis = np.asarray(data_obj.take(0, take_inds[0]).take(0, take_inds[1]), # noqa
dtype=float)
data = np.asarray(data_comp.data.take(0, take_inds[0]).take(0, take_inds[1]), # noqa
dtype=float)

# Convert to display units if applicable
if slice_display_units and slice_display_units != data_units:
converted_axis = (data_spec_axis * u.Unit(data_units)).to_value(
data_units = getattr(data_comp, 'units', None)
if slice_display_units and data_units and slice_display_units != data_units:
converted_axis = (data * u.Unit(data_units)).to_value(
slice_display_units,
equivalencies=u.spectral() + u.pixel_scale(1*u.pix)
)
else:
converted_axis = data_spec_axis
converted_axis = data

return converted_axis

Expand Down
3 changes: 2 additions & 1 deletion jdaviz/configs/default/plugins/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from jdaviz.components.toolbar_nested import NestedJupyterToolbar
from jdaviz.core.astrowidgets_api import AstrowidgetsImageViewerMixin
from jdaviz.core.registries import viewer_registry
from jdaviz.core.template_mixin import WithCache
from jdaviz.core.user_api import ViewerUserApi
from jdaviz.utils import ColorCycler, get_subset_type, _wcs_only_label, layer_is_not_dq

Expand All @@ -20,7 +21,7 @@
viewer_registry.add("g-table-viewer", label="Table", cls=TableViewer)


class JdavizViewerMixin:
class JdavizViewerMixin(WithCache):
toolbar = None
tools_nested = []
_prev_limits = None
Expand Down
28 changes: 15 additions & 13 deletions jdaviz/core/template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
__all__ = ['show_widget', 'TemplateMixin', 'PluginTemplateMixin',
'skip_if_no_updates_since_last_active', 'skip_if_not_tray_instance',
'with_spinner', 'with_temp_disable',
'ViewerPropertiesMixin',
'WithCache', 'ViewerPropertiesMixin',
'BasePluginComponent',
'MultiselectMixin',
'SelectPluginComponent', 'UnitSelectPluginComponent', 'EditableSelectPluginComponent',
Expand Down Expand Up @@ -186,7 +186,19 @@ def flux_viewer(self):
return self.app.get_viewer(viewer_reference)


class TemplateMixin(VuetifyTemplate, HubListener, ViewerPropertiesMixin):
class WithCache:
def _clear_cache(self, *attrs):
"""
provide convenience function to clearing the cache for cached_properties
"""
if not len(attrs):
attrs = getattr(self, '_cached_properties', [])
for attr in attrs:
if attr in self.__dict__:
del self.__dict__[attr]


class TemplateMixin(VuetifyTemplate, HubListener, ViewerPropertiesMixin, WithCache):
config = Unicode("").tag(sync=True)
vdocs = Unicode("").tag(sync=True)
popout_button = Any().tag(sync=True, **widget_serialization)
Expand Down Expand Up @@ -612,7 +624,7 @@ def show(self, loc="inline", title=None): # pragma: no cover
show_widget(self, loc=loc, title=title)


class BasePluginComponent(HubListener, ViewerPropertiesMixin):
class BasePluginComponent(HubListener, ViewerPropertiesMixin, WithCache):
"""
This base class handles attaching traitlets from the plugin itself to logic
handled within the component, support for caching and clearing caches on properties,
Expand All @@ -637,16 +649,6 @@ def __setattr__(self, attr, value, force_super=False):

return setattr(self._plugin, self._plugin_traitlets.get(attr), value)

def _clear_cache(self, *attrs):
"""
provide convenience function to clearing the cache for cached_properties
"""
if not len(attrs):
attrs = self._cached_properties
for attr in attrs:
if attr in self.__dict__:
del self.__dict__[attr]

def add_traitlets(self, **traitlets):
for k, v in traitlets.items():
if v is None:
Expand Down

0 comments on commit 50ebc1e

Please sign in to comment.