Skip to content

Commit

Permalink
Add rasterize support for ds.by and selector
Browse files Browse the repository at this point in the history
  • Loading branch information
hoxbro committed Nov 6, 2024
1 parent fb4ddd1 commit 10d4272
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import warnings
from collections.abc import Callable, Iterable
from functools import partial
Expand Down Expand Up @@ -238,6 +239,13 @@ class LineAggregationOperation(AggregationOperation):



class AggState(enum.Enum):
    """Which combination of aggregator and selector an aggregation uses.

    The state records whether a selector is present alongside the
    aggregator, and whether the aggregator is a ``ds.by`` reduction.
    """

    AGG = 0  # Only aggregator
    AGG_BY = 1  # Only aggregator, where the aggregator is ds.by
    SEL = 2  # Selector and aggregator
    SEL_BY = 3  # Selector and aggregator, where the aggregator is ds.by


class aggregate(LineAggregationOperation):
"""
aggregate implements 2D binning for any valid HoloViews Element
Expand Down Expand Up @@ -391,13 +399,18 @@ def _process(self, element, key=None):
dfdata = PandasInterface.as_dframe(data)
cvs_fn = getattr(cvs, glyph)

if sel_fn:
if isinstance(agg_fn, ds.by):
expr_state = AggState.SEL_BY if sel_fn else AggState.AGG_BY
else:
expr_state = AggState.SEL if sel_fn else AggState.AGG

if expr_state in (AggState.SEL, AggState.SEL_BY):
if isinstance(params["vdims"], (Dimension, str)):
params["vdims"] = [params["vdims"]]
sum_agg = ds.summary(**{str(params["vdims"][0]): agg_fn, "__index__": ds.where(sel_fn)})
agg = self._apply_datashader(dfdata, cvs_fn, sum_agg, agg_kwargs, x, y)
agg = self._apply_datashader(dfdata, cvs_fn, sum_agg, agg_kwargs, x, y, expr_state)
else:
agg = self._apply_datashader(dfdata, cvs_fn, agg_fn, agg_kwargs, x, y)
agg = self._apply_datashader(dfdata, cvs_fn, agg_fn, agg_kwargs, x, y, expr_state)

if 'x_axis' in agg.coords and 'y_axis' in agg.coords:
agg = agg.rename({'x_axis': x, 'y_axis': y})
Expand All @@ -406,13 +419,16 @@ def _process(self, element, key=None):
if ytype == 'datetime':
agg[y.name] = agg[y.name].astype('datetime64[ns]')

if isinstance(agg, xr.Dataset) or agg.ndim == 2:
if expr_state in (AggState.AGG, AggState.SEL):
return self.p.element_type(agg, **params)
else:
elif expr_state == AggState.AGG_BY:
params['vdims'] = list(map(str, agg.coords[agg_fn.column].data))
return ImageStack(agg, **params)
elif expr_state == AggState.SEL_BY:
params['vdims'] = [d for d in agg.data_vars if d not in agg.attrs["selector_columns"]]
return ImageStack(agg, **params)

def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y):
def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y, agg_state: AggState):
# Suppress numpy warning emitted by dask:
# https://github.com/dask/dask/issues/8439
with warnings.catch_warnings():
Expand All @@ -423,19 +439,26 @@ def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y):
agg = cvs_fn(dfdata, x.name, y.name, agg_fn, **agg_kwargs)

is_where_index = DATASHADER_GE_0_15_1 and isinstance(agg_fn, ds.where) and isinstance(agg_fn.column, rd.SpecialColumn)
is_summary_index = isinstance(agg_fn, ds.summary) and "__index__" in agg
is_summary_index = agg_state in (AggState.SEL, AggState.SEL_BY)
if is_where_index or is_summary_index:
if is_where_index:
data = agg.data
index = agg.data
agg = agg.to_dataset(name="__index__")
else: # summary index
data = agg["__index__"].data
neg1 = data == -1
index = agg["__index__"].data
if agg_state == AggState.SEL_BY:
main_dim = next(k for k in agg if k != "__index__")
# Taking values from the main dimension expanding it to
# a new dataset
agg = agg[main_dim].to_dataset(dim=list(agg.dims)[2])
agg["__index__"] = ((y.name, x.name), index)

neg1 = index == -1
agg.attrs["selector_columns"] = sel_cols = ["__index__"]
for col in dfdata.columns:
if col in agg.coords:
continue
val = dfdata[col].values[data]
val = dfdata[col].values[index]
if val.dtype.kind == 'f':
val[neg1] = np.nan
elif isinstance(val.dtype, pd.CategoricalDtype):
Expand All @@ -451,7 +474,7 @@ def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y):
agg[col] = ((y.name, x.name), val)
sel_cols.append(col)

if isinstance(agg_fn, ds.by):
if agg_state == AggState.AGG_BY:
col = agg_fn.column
if '' in agg.coords[col]:
agg = agg.drop_sel(**{col: ''})
Expand Down

0 comments on commit 10d4272

Please sign in to comment.