Skip to content

Commit

Permalink
Add unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
hoxbro committed Jul 19, 2023
1 parent aee0dcc commit 4c4498f
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions holoviews/tests/operation/test_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,38 @@

AggregationOperation.vdim_prefix = ''

@pytest.fixture()
def point_data():
num = 100
np.random.seed(1)

dists = {
cat: pd.DataFrame(
{
"x": np.random.normal(x, s, num),
"y": np.random.normal(y, s, num),
"s": s,
"val": val,
"cat": cat,
}
)
for x, y, s, val, cat in [
(2, 2, 0.03, 0, "d1"),
(2, -2, 0.10, 1, "d2"),
(-2, -2, 0.50, 2, "d3"),
(-2, 2, 1.00, 3, "d4"),
(0, 0, 3.00, 4, "d5"),
]
}
df = pd.concat(dists, ignore_index=True)
return df


@pytest.fixture()
def point_plot(point_data):
return Points(point_data)


class DatashaderAggregateTests(ComparisonTestCase):
"""
Tests for datashader aggregation
Expand Down Expand Up @@ -1169,6 +1201,78 @@ def test_rasterize_image_expand_default(self):
output = img.data["z"].to_numpy()
assert np.isnan(output).any()


@pytest.mark.parametrize("agg_input_fn,index_col",
(
[ds.first, [311, 433, 309, 482]],
[ds.last, [491, 483, 417, 482]],
[ds.min, [311, 433, 309, 482]],
[ds.max, [404, 433, 417, 482]],
)
)
def test_rasterize_where_agg_no_column(point_plot, agg_input_fn, index_col):
agg_fn = ds.where(agg_input_fn("val"))
rast_input = dict(dynamic=False, x_range=(-1, 1), y_range=(-1, 1), width=2, height=2)
img = rasterize(point_plot, aggregator=agg_fn, **rast_input)

assert list(img.data) == ["index", "s", "val", "cat"]
assert list(img.vdims) == ["val", "s", "cat"] # val first and no index

# N=100 in point_data is chosen to have a big enough sample size
# so that the index are not the same for the different agg_input_fn
np.testing.assert_array_equal(img.data["index"].data.flatten(), index_col)

img_simple = rasterize(point_plot, aggregator=agg_input_fn("val"), **rast_input)
np.testing.assert_array_equal(img_simple["val"], img["val"])


@pytest.mark.parametrize("agg_input_fn", (ds.first, ds.last, ds.min, ds.max))
def test_rasterize_where_agg_with_column(point_plot, agg_input_fn):
agg_fn = ds.where(agg_input_fn("val"), "s")
rast_input = dict(dynamic=False, x_range=(-1, 1), y_range=(-1, 1), width=2, height=2)
img = rasterize(point_plot, aggregator=agg_fn, **rast_input)

assert list(img.data) == ["s"]
img_no_column = rasterize(point_plot, aggregator=ds.where(agg_input_fn("val")), **rast_input)
np.testing.assert_array_equal(img["s"], img_no_column["s"])


def test_rasterize_summerize(point_plot):
agg_fn_count, agg_fn_first = ds.count(), ds.first("val")
agg_fn = ds.summary(count=agg_fn_count, first=agg_fn_first)
rast_input = dict(dynamic=False, x_range=(-1, 1), y_range=(-1, 1), width=2, height=2)
img_sum = rasterize(point_plot, aggregator=agg_fn, **rast_input)
img_count = rasterize(point_plot, aggregator=agg_fn_count, **rast_input)
img_first = rasterize(point_plot, aggregator=agg_fn_first, **rast_input)

np.testing.assert_array_equal(img_sum["first"], img_first["val"])

# Count has special handling in AggregationOperation which sets nodata=0
# this is not done for count in summary.
np.testing.assert_array_equal(img_sum["count"], np.nan_to_num(img_count["Count"]))


@pytest.mark.parametrize("sel_fn", (ds.first, ds.last, ds.min, ds.max))
def test_rasterize_selector(point_plot, sel_fn):
rast_input = dict(dynamic=False, x_range=(-1, 1), y_range=(-1, 1), width=2, height=2)
img = rasterize(point_plot, selector=sel_fn("val"), **rast_input)

# Count is from the aggregator
assert list(img.data) == ["Count", "index", "s", "val", "cat"]
assert list(img.vdims) == ["Count", "s", "val", "cat"] # no index

# The output for the selector should be equal to the output for the aggregator using
# ds.where
img_agg = rasterize(point_plot, aggregator=ds.where(sel_fn("val")), **rast_input)
for c in ["s", "val", "cat"]:
np.testing.assert_array_equal(img[c], img_agg[c])

# Checking the count is also the same
img_count = rasterize(point_plot, **rast_input)
np.testing.assert_array_equal(img["Count"], img_count["Count"])



class DatashaderSpreadTests(ComparisonTestCase):

def test_spread_rgb_1px(self):
Expand Down

0 comments on commit 4c4498f

Please sign in to comment.