Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RangeXY streams on multi-axes #5826

Merged
merged 12 commits into from
Jul 25, 2023
32 changes: 27 additions & 5 deletions holoviews/plotting/bokeh/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class Callback:
# The plotting handle(s) to attach the JS callback on
models = []

# Additional handles to hash on for uniqueness
extra_handles = []
philippjfr marked this conversation as resolved.
Show resolved Hide resolved

# Conditions when callback should be skipped
skip_events = []
skip_changes = []
Expand Down Expand Up @@ -209,6 +212,9 @@ def _init_plot_handles(self):
if h in self.plot_handles:
requested[h] = handles[h]
self.handle_ids.update(self._get_stream_handle_ids(requested))
for h in self.extra_handles:
if h in self.plot_handles:
requested[h] = handles[h]
return requested

def _get_stream_handle_ids(self, handles):
Expand Down Expand Up @@ -379,20 +385,24 @@ def set_callback(self, handle):

def initialize(self, plot_id=None):
handles = self._init_plot_handles()
cb_handles = []
for handle_name in self.models:
hash_handles, cb_handles = [], []
for handle_name in self.models+self.extra_handles:
if handle_name not in handles:
warn_args = (handle_name, type(self.plot).__name__,
type(self).__name__)
print('{} handle not found on {}, cannot '
'attach {} callback'.format(*warn_args))
continue
cb_handles.append(handles[handle_name])
handle = handles[handle_name]
if handle_name not in self.extra_handles:
cb_handles.append(handle)
hash_handles.append(handle)

# Hash the plot handle with Callback type allowing multiple
# callbacks on one handle to be merged
handle_ids = [id(h) for h in cb_handles]
cb_hash = tuple(handle_ids)+(id(type(self)),)
hash_ids = [id(h) for h in hash_handles]
cb_hash = tuple(hash_ids)+(id(type(self)),)
if cb_hash in self._callbacks:
# Merge callbacks if another callback has already been attached
cb = self._callbacks[cb_hash]
Expand Down Expand Up @@ -599,11 +609,13 @@ class RangeXYCallback(Callback):

models = ['plot']

extra_handles = ['x_range', 'y_range']

attributes = {
'x0': 'cb_obj.x0',
'y0': 'cb_obj.y0',
'x1': 'cb_obj.x1',
'y1': 'cb_obj.y1'
'y1': 'cb_obj.y1',
}

_js_on_event = """
Expand All @@ -624,6 +636,12 @@ def set_callback(self, handle):
handle.js_on_event('rangesupdate', CustomJS(code=self._js_on_event))

def _process_msg(self, msg):
if self.plot.state.x_range is not self.plot.handles['x_range']:
x_range = self.plot.handles['x_range']
msg['x0'], msg['x1'] = x_range.start, x_range.end
if self.plot.state.y_range is not self.plot.handles['y_range']:
y_range = self.plot.handles['y_range']
msg['y0'], msg['y1'] = y_range.start, y_range.end
data = {}
if 'x0' in msg and 'x1' in msg:
x0, x1 = msg['x0'], msg['x1']
Expand Down Expand Up @@ -657,6 +675,8 @@ class RangeXCallback(RangeXYCallback):

models = ['plot']

extra_handles = ['x_range']

attributes = {
'x0': 'cb_obj.x0',
'x1': 'cb_obj.x1',
Expand All @@ -672,6 +692,8 @@ class RangeYCallback(RangeXYCallback):

models = ['plot']

extra_handles = ['y_range']

attributes = {
'y0': 'cb_obj.y0',
'y1': 'cb_obj.y1'
Expand Down
30 changes: 26 additions & 4 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,29 @@ def _init_glyphs(self, plot, element, ranges, source):
with abbreviated_exception():
self._update_glyph(renderer, properties, mapping, glyph, source, source.data)

def _find_axes(self, plot, element):
"""
Looks up the axes and plot ranges given the plot and an element.
"""
axis_dims = self._get_axis_dims(element)[:2]
if self.invert_axes:
axis_dims[0], axis_dims[1] = axis_dims[::-1]
x, y = axis_dims
if x.name in plot.extra_x_ranges:
x_range = plot.extra_x_ranges[x.name]
xaxes = [xaxis for xaxis in plot.xaxis if xaxis.x_range_name == x.name]
x_axis = (xaxes if xaxes else plot.xaxis)[0]
else:
x_range = plot.x_range
x_axis = plot.xaxis[0]
if y.name in plot.extra_y_ranges:
y_range = plot.extra_y_ranges[y.name]
yaxes = [yaxis for yaxis in plot.yaxis if yaxis.y_range_name == y.name]
y_axis = (yaxes if yaxes else plot.yaxis)[0]
else:
y_range = plot.y_range
y_axis = plot.yaxis[0]
return (x_axis, y_axis), (x_range, y_range)

def initialize_plot(self, ranges=None, plot=None, plots=None, source=None):
"""
Expand All @@ -1714,10 +1737,9 @@ def initialize_plot(self, ranges=None, plot=None, plots=None, source=None):
plot = self._init_plot(key, style_element, ranges=ranges, plots=plots)
self._init_axes(plot)
else:
self.handles['xaxis'] = plot.xaxis[0]
self.handles['x_range'] = plot.x_range
self.handles['yaxis'] = plot.yaxis[0]
self.handles['y_range'] = plot.y_range
axes, plot_ranges = self._find_axes(plot, element)
self.handles['xaxis'], self.handles['yaxis'] = axes
self.handles['x_range'], self.handles['y_range'] = plot_ranges
philippjfr marked this conversation as resolved.
Show resolved Hide resolved
self.handles['plot'] = plot

if self.autorange:
Expand Down
12 changes: 12 additions & 0 deletions holoviews/plotting/bokeh/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ def __init__(self, *args, **kwargs):
if self.hmap.type == Raster:
self.invert_yaxis = not self.invert_yaxis

def _init_glyph(self, plot, mapping, properties):
renderer, glyph = super()._init_glyph(plot, mapping, properties)
axis_dims = self._get_axis_dims(self.current_frame)[:2]
if self.invert_axes:
axis_dims[0], axis_dims[1] = axis_dims[::-1]
xdim, ydim = axis_dims
if xdim.name in plot.extra_x_ranges:
renderer.x_range_name = xdim.name
if ydim.name in plot.extra_y_ranges:
renderer.y_range_name = ydim.name
return renderer, glyph

def get_data(self, element, ranges, style):
mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh')
val_dim = element.vdims[0]
Expand Down
2 changes: 1 addition & 1 deletion holoviews/plotting/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,7 @@ def __init__(self, overlay, ranges=None, batched=True, keys=None, group_counter=

if ('multi_y' in self.param) and self.multi_y:
for s in self.streams:
intersection = set(s.param) & {'y', 'y_selection', 'y_range', 'bounds', 'boundsy'}
intersection = set(s.param) & {'y', 'y_selection', 'bounds', 'boundsy'}
if intersection:
self.param.warning(f'{type(s).__name__} stream parameters'
f' {list(intersection)} not yet supported with multi_y=True')
Expand Down
Loading