Skip to content

Commit

Permalink
Support ds.where and ds.summary and add selector (#5805)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoxbro authored Jul 24, 2023
1 parent 4fc8357 commit 6e59170
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 25 deletions.
113 changes: 88 additions & 25 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __getattr__(name):


ds_version = Version(ds.__version__)
ds15 = ds_version >= Version('0.15.1')


class AggregationOperation(ResampleOperation2D):
Expand All @@ -58,13 +59,19 @@ class AggregationOperation(ResampleOperation2D):
aggregator parameter used to define a datashader Reduction.
"""

aggregator = param.ClassSelector(class_=(ds.reductions.Reduction, str),
default=ds.count(), doc="""
aggregator = param.ClassSelector(class_=(rd.Reduction, rd.summary, str),
default=rd.count(), doc="""
Datashader reduction function used for aggregating the data.
The aggregator may also define a column to aggregate; if
no column is defined the first value dimension of the element
will be used. May also be defined as a string.""")

selector = param.ClassSelector(class_=(rd.min, rd.max, rd.first, rd.last),
default=None, doc="""
Selector is a datashader reduction function used for selecting data.
The selector only works with aggregators which selects an item from
the original data. These selectors are min, max, first and last.""")

vdim_prefix = param.String(default='{kdims} ', allow_None=True, doc="""
Prefix to prepend to value dimension name where {kdims}
templates in the names of the input element key dimensions.""")
Expand All @@ -86,6 +93,11 @@ class AggregationOperation(ResampleOperation2D):

@classmethod
def _get_aggregator(cls, element, agg, add_field=True):
if ds15:
agg_types = (rd.count, rd.any, rd.where)
else:
agg_types = (rd.count, rd.any)

if isinstance(agg, str):
if agg not in cls._agg_methods:
agg_methods = sorted(cls._agg_methods)
Expand All @@ -98,7 +110,7 @@ def _get_aggregator(cls, element, agg, add_field=True):

elements = element.traverse(lambda x: x, [Element])
if (add_field and getattr(agg, 'column', False) in ('__temp__', None) and
not isinstance(agg, (rd.count, rd.any))):
not isinstance(agg, agg_types)):
if not elements:
raise ValueError('Could not find any elements to apply '
'%s operation to.' % cls.__name__)
Expand Down Expand Up @@ -147,8 +159,19 @@ def _get_agg_params(self, element, x, y, agg_fn, bounds):
if hasattr(agg_fn, 'reduction'):
category = agg_fn.cat_column
agg_fn = agg_fn.reduction
column = agg_fn.column if agg_fn else None
if column:
if isinstance(agg_fn, rd.summary):
column = None
else:
column = agg_fn.column if agg_fn else None
agg_name = type(agg_fn).__name__.title()
if agg_name == "Where":
# Set the first item to be the selector column.
col = agg_fn.column if not isinstance(agg_fn.column, rd.SpecialColumn) else agg_fn.selector.column
vdims = sorted(params["vdims"], key=lambda x: x != col)
# TODO: Should we add prefix to all of the where columns.
elif agg_name == "Summary":
vdims = list(agg_fn.keys)
elif column:
dims = [d for d in element.dimensions('ranges') if d == column]
if not dims:
raise ValueError("Aggregation column '{}' not found on '{}' element. "
Expand All @@ -163,13 +186,11 @@ def _get_agg_params(self, element, x, y, agg_fn, bounds):
else:
vdims = dims[0].clone(vdim_prefix + column)
elif category:
agg_name = type(agg_fn).__name__.title()
agg_label = f'{category} {agg_name}'
vdims = Dimension(f'{vdim_prefix}{agg_label}', label=agg_label)
if agg_name in ('Count', 'Any'):
vdims.nodata = 0
else:
agg_name = type(agg_fn).__name__.title()
vdims = Dimension(f'{vdim_prefix}{agg_name}', label=agg_name, nodata=0)
params['vdims'] = vdims
return params
Expand Down Expand Up @@ -298,6 +319,7 @@ def get_agg_data(cls, obj, category=None):

def _process(self, element, key=None):
agg_fn = self._get_aggregator(element, self.p.aggregator)
sel_fn = getattr(self.p, "selector", None)
if hasattr(agg_fn, 'cat_column'):
category = agg_fn.cat_column
else:
Expand Down Expand Up @@ -338,22 +360,27 @@ def _process(self, element, key=None):
agg_kwargs['line_width'] = self.p.line_width

dfdata = PandasInterface.as_dframe(data)
# Suppress numpy warning emitted by dask:
# https://github.com/dask/dask/issues/8439
with warnings.catch_warnings():
warnings.filterwarnings(
action='ignore', message='casting datetime64',
category=FutureWarning
)
agg = getattr(cvs, glyph)(dfdata, x.name, y.name, agg_fn, **agg_kwargs)
cvs_fn = getattr(cvs, glyph)

if sel_fn:
if isinstance(params["vdims"], (Dimension, str)):
params["vdims"] = [params["vdims"]]
sum_agg = ds.summary(**{str(params["vdims"][0]): agg_fn, "index": ds.where(sel_fn)})
agg = self._apply_datashader(dfdata, cvs_fn, sum_agg, agg_kwargs, x, y)
_ignore = [*params["vdims"], "index"]
sel_vdims = [s for s in agg if s not in _ignore]
params["vdims"] = [*params["vdims"], *sel_vdims]
else:
agg = self._apply_datashader(dfdata, cvs_fn, agg_fn, agg_kwargs, x, y)

if 'x_axis' in agg.coords and 'y_axis' in agg.coords:
agg = agg.rename({'x_axis': x, 'y_axis': y})
if xtype == 'datetime':
agg[x.name] = agg[x.name].astype('datetime64[ns]')
if ytype == 'datetime':
agg[y.name] = agg[y.name].astype('datetime64[ns]')

if agg.ndim == 2:
if isinstance(agg, xr.Dataset) or agg.ndim == 2:
# Replacing x and y coordinates to avoid numerical precision issues
eldata = agg if ds_version > Version('0.5.0') else (xs, ys, agg.data)
return self.p.element_type(eldata, **params)
Expand All @@ -365,6 +392,42 @@ def _process(self, element, key=None):
layers[c] = self.p.element_type(eldata, **params)
return NdOverlay(layers, kdims=[data.get_dimension(agg_fn.column)])

def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y):
    """Run the datashader canvas aggregation and post-process the result.

    Calls ``cvs_fn`` (a bound ``ds.Canvas`` glyph method) on the dataframe
    with the supplied aggregator. When the aggregation produces a grid of
    row indices (a bare ``ds.where`` over a special column, or a
    ``ds.summary`` containing an ``"index"`` entry), the remaining columns
    of ``dfdata`` are looked up via that index and attached to the returned
    xarray object as additional data variables.
    """
    # Suppress numpy warning emitted by dask:
    # https://github.com/dask/dask/issues/8439
    with warnings.catch_warnings():
        warnings.filterwarnings(
            action='ignore', message='casting datetime64',
            category=FutureWarning
        )
        agg = cvs_fn(dfdata, x.name, y.name, agg_fn, **agg_kwargs)

    # ds.where with a SpecialColumn is only supported on datashader >= 0.15.1
    # (the ds15 flag); a ds.summary may carry an "index" variable produced by
    # a nested ds.where.
    is_where_index = ds15 and isinstance(agg_fn, ds.where) and isinstance(agg_fn.column, rd.SpecialColumn)
    is_summary_index = isinstance(agg_fn, ds.summary) and "index" in agg
    if is_where_index or is_summary_index:
        if is_where_index:
            data = agg.data
            # Promote the bare index DataArray to a Dataset so the looked-up
            # source columns can be stored alongside it.
            agg = agg.to_dataset(name="index")
        else:  # summary index
            data = agg.index.data
        # NOTE(review): -1 appears to mark pixels with no contributing row;
        # such entries are masked to NaN / "-" below — confirm against the
        # datashader where() contract.
        neg1 = data == -1
        for col in dfdata.columns:
            if col in agg.coords:
                # Skip the x/y coordinate columns already present on the grid.
                continue
            # Fancy-index the original column by the per-pixel row index.
            val = dfdata[col].values[data]
            if val.dtype.kind == 'f':
                val[neg1] = np.nan
            elif isinstance(val.dtype, pd.CategoricalDtype):
                # Categoricals cannot hold arbitrary fill values; densify first.
                val = val.to_numpy()
                val[neg1] = "-"
            elif val.dtype.kind == "O":
                val[neg1] = "-"
            else:
                # Integer (and other) dtypes cannot represent NaN; upcast.
                val = val.astype(np.float64)
                val[neg1] = np.nan
            agg[col] = ((y.name, x.name), val)
    return agg



class overlay_aggregate(aggregate):
Expand Down Expand Up @@ -719,8 +782,8 @@ class regrid(AggregationOperation):
with NaN values.
"""

aggregator = param.ClassSelector(default=ds.mean(),
class_=(ds.reductions.Reduction, str))
aggregator = param.ClassSelector(default=rd.mean(),
class_=(rd.Reduction, rd.summary, str))

expand = param.Boolean(default=False, doc="""
Whether the x_range and y_range should be allowed to expand
Expand Down Expand Up @@ -857,8 +920,8 @@ class contours_rasterize(aggregate):
default to any aggregator.
"""

aggregator = param.ClassSelector(default=ds.mean(),
class_=(ds.reductions.Reduction, str))
aggregator = param.ClassSelector(default=rd.mean(),
class_=(rd.Reduction, rd.summary, str))

@classmethod
def _get_aggregator(cls, element, agg, add_field=True):
Expand All @@ -876,8 +939,8 @@ class trimesh_rasterize(aggregate):
data.
"""

aggregator = param.ClassSelector(default=ds.mean(),
class_=(ds.reductions.Reduction, str))
aggregator = param.ClassSelector(default=rd.mean(),
class_=(rd.Reduction, rd.summary, str))

interpolation = param.ObjectSelector(default='bilinear',
objects=['bilinear', 'linear', None, False], doc="""
Expand Down Expand Up @@ -1257,8 +1320,8 @@ class geometry_rasterize(LineAggregationOperation):
Rasterizes geometries by converting them to spatialpandas.
"""

aggregator = param.ClassSelector(default=ds.mean(),
class_=(ds.reductions.Reduction, str))
aggregator = param.ClassSelector(default=rd.mean(),
class_=(rd.Reduction, rd.summary, str))

@classmethod
def _get_aggregator(cls, element, agg, add_field=True):
Expand Down Expand Up @@ -1342,7 +1405,7 @@ class rasterize(AggregationOperation):
dimensions of the linked plot and the ranges of the axes.
"""

aggregator = param.ClassSelector(class_=(ds.reductions.Reduction, str),
aggregator = param.ClassSelector(class_=(rd.Reduction, rd.summary, str),
default='default')

interpolation = param.ObjectSelector(
Expand Down
104 changes: 104 additions & 0 deletions holoviews/tests/operation/test_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,38 @@

AggregationOperation.vdim_prefix = ''

@pytest.fixture()
def point_data():
    """Build a synthetic table of five clustered 2D normal distributions.

    Each cluster carries a constant spread ``s``, value ``val`` and category
    label ``cat`` so tests can verify column selection after rasterization.
    """
    num = 100
    np.random.seed(1)

    # (center_x, center_y, spread, value, category) per cluster.
    cluster_specs = [
        (2, 2, 0.03, 0, "d1"),
        (2, -2, 0.10, 1, "d2"),
        (-2, -2, 0.50, 2, "d3"),
        (-2, 2, 1.00, 3, "d4"),
        (0, 0, 3.00, 4, "d5"),
    ]
    frames = {}
    for cx, cy, spread, value, label in cluster_specs:
        frames[label] = pd.DataFrame(
            {
                "x": np.random.normal(cx, spread, num),
                "y": np.random.normal(cy, spread, num),
                "s": spread,
                "val": value,
                "cat": label,
            }
        )
    return pd.concat(frames, ignore_index=True)


@pytest.fixture()
def point_plot(point_data):
    # Wrap the synthetic point table in a Points element for the rasterize tests.
    return Points(point_data)


class DatashaderAggregateTests(ComparisonTestCase):
"""
Tests for datashader aggregation
Expand Down Expand Up @@ -1191,6 +1223,78 @@ def test_rasterize_apply_when_instance_with_line_width(self):
assert isinstance(overlay, Overlay)
assert len(overlay) == 2


@pytest.mark.parametrize("agg_input_fn,index_col",
    (
        [ds.first, [311, 433, 309, 482]],
        [ds.last, [491, 483, 417, 482]],
        [ds.min, [311, 433, 309, 482]],
        [ds.max, [404, 433, 417, 482]],
    )
)
def test_rasterize_where_agg_no_column(point_plot, agg_input_fn, index_col):
    """Rasterizing with a bare ds.where attaches all source columns."""
    where_agg = ds.where(agg_input_fn("val"))
    rast_kwargs = dict(dynamic=False, x_range=(-1, 1), y_range=(-1, 1), width=2, height=2)
    rasterized = rasterize(point_plot, aggregator=where_agg, **rast_kwargs)

    assert list(rasterized.data) == ["index", "s", "val", "cat"]
    # "val" is first and "index" is excluded from the value dimensions.
    assert list(rasterized.vdims) == ["val", "s", "cat"]

    # num=100 in point_data gives a sample large enough that the selected
    # row indices differ between the parametrized agg_input_fn choices.
    np.testing.assert_array_equal(rasterized.data["index"].data.flatten(), index_col)

    # The "val" grid must match a plain (non-where) aggregation of the same column.
    plain = rasterize(point_plot, aggregator=agg_input_fn("val"), **rast_kwargs)
    np.testing.assert_array_equal(plain["val"], rasterized["val"])


@pytest.mark.parametrize("agg_input_fn", (ds.first, ds.last, ds.min, ds.max))
def test_rasterize_where_agg_with_column(point_plot, agg_input_fn):
    """A ds.where with an explicit column only returns that column."""
    rast_kwargs = dict(dynamic=False, x_range=(-1, 1), y_range=(-1, 1), width=2, height=2)
    with_col = rasterize(
        point_plot, aggregator=ds.where(agg_input_fn("val"), "s"), **rast_kwargs
    )

    assert list(with_col.data) == ["s"]
    # The selected "s" values agree with the no-column variant of the same where.
    without_col = rasterize(
        point_plot, aggregator=ds.where(agg_input_fn("val")), **rast_kwargs
    )
    np.testing.assert_array_equal(with_col["s"], without_col["s"])


def test_rasterize_summerize(point_plot):
    """A ds.summary aggregator matches the equivalent single aggregations.

    NOTE(review): the function name has a typo ("summerize" -> "summarize");
    kept as-is to preserve the public test id.
    """
    count_agg, first_agg = ds.count(), ds.first("val")
    combined_agg = ds.summary(count=count_agg, first=first_agg)
    rast_kwargs = dict(dynamic=False, x_range=(-1, 1), y_range=(-1, 1), width=2, height=2)
    img_sum = rasterize(point_plot, aggregator=combined_agg, **rast_kwargs)
    img_count = rasterize(point_plot, aggregator=count_agg, **rast_kwargs)
    img_first = rasterize(point_plot, aggregator=first_agg, **rast_kwargs)

    np.testing.assert_array_equal(img_sum["first"], img_first["val"])

    # Count has special handling in AggregationOperation which sets nodata=0;
    # this is not done for a count nested inside a summary, hence nan_to_num.
    np.testing.assert_array_equal(img_sum["count"], np.nan_to_num(img_count["Count"]))


@pytest.mark.parametrize("sel_fn", (ds.first, ds.last, ds.min, ds.max))
def test_rasterize_selector(point_plot, sel_fn):
    """The selector parameter behaves like aggregator=ds.where(sel_fn)."""
    rast_kwargs = dict(dynamic=False, x_range=(-1, 1), y_range=(-1, 1), width=2, height=2)
    with_selector = rasterize(point_plot, selector=sel_fn("val"), **rast_kwargs)

    # "Count" comes from the (default) aggregator.
    assert list(with_selector.data) == ["Count", "index", "s", "val", "cat"]
    # "index" is excluded from the value dimensions.
    assert list(with_selector.vdims) == ["Count", "s", "val", "cat"]

    # The selector output must equal an explicit ds.where aggregation.
    with_where = rasterize(point_plot, aggregator=ds.where(sel_fn("val")), **rast_kwargs)
    for column in ["s", "val", "cat"]:
        np.testing.assert_array_equal(with_selector[column], with_where[column])

    # The count grid is unchanged by adding a selector.
    plain_count = rasterize(point_plot, **rast_kwargs)
    np.testing.assert_array_equal(with_selector["Count"], plain_count["Count"])



class DatashaderSpreadTests(ComparisonTestCase):

def test_spread_rgb_1px(self):
Expand Down

0 comments on commit 6e59170

Please sign in to comment.