Skip to content

Commit

Permalink
Merge branch 'main' into inspector_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nvaytet authored Sep 26, 2023
2 parents 9a971a2 + 2760b16 commit 297a0b1
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 11 deletions.
48 changes: 47 additions & 1 deletion docs/basics/line-plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,53 @@
"id": "123c2b07-ad46-485b-9842-8a754d1fc44b",
"metadata": {},
"source": [
"Note that if no coordinate of name `'x'` exists, a dummy one will be generated using `scipp.arange`."
"Note that if no coordinate of name `'x'` exists, a dummy one will be generated using `scipp.arange`.\n",
"\n",
"## Plotting one variable as a function of another\n",
"\n",
"<div class=\"versionadded\" style=\"font-weight: bold;\">\n",
"\n",
"<img src=\"../_static/circle-exclamation.svg\" width=\"16\" height=\"16\" />\n",
"&nbsp;\n",
"New in version <TODO:VERSION>.\n",
"\n",
"</div>\n",
"\n",
"Sometimes it is useful, for quickly inspecting data, to plot one variable as a function of another,\n",
"without having to first explicitly store them both in a `DataArray`.\n",
"\n",
"For this, one can use a small dedicated function called `xyplot`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e66f38b9-3731-4d6f-b017-2ccb5202bf79",
"metadata": {},
"outputs": [],
"source": [
"x = sc.arange('distance', 50.0, unit='m')\n",
"y = x**2\n",
"\n",
"pp.xyplot(x, y)"
]
},
{
"cell_type": "markdown",
"id": "9107b2df-660c-412f-8171-c24484a5bae1",
"metadata": {},
"source": [
"Any additional keyword arguments are forwarded to the `plot` function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "186abd7d-9ed3-4947-848e-8418879ab96f",
"metadata": {},
"outputs": [],
"source": [
"pp.xyplot(x, y, ls='solid', color='purple', marker=None, lw=3)"
]
}
],
Expand Down
1 change: 1 addition & 0 deletions docs/customization/graph-node-tips.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"def add(x, y):\n",
" return x + y\n",
"\n",
"\n",
"c = pp.Node(add, a, b)\n",
"c()"
]
Expand Down
1 change: 1 addition & 0 deletions docs/reference/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
superplot
inspector
scatter3d
xyplot
```

## Core
Expand Down
3 changes: 2 additions & 1 deletion src/plopp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from . import data
from .core import Node, View, node, show_graph, widget_node
from .graphics import Camera, figure1d, figure2d, figure3d, tiled
from .plotting import inspector, plot, scatter3d, slicer, superplot
from .plotting import inspector, plot, scatter3d, slicer, superplot, xyplot


def show():
Expand Down Expand Up @@ -51,4 +51,5 @@ def show():
'superplot',
'tiled',
'widget_node',
'xyplot',
]
1 change: 1 addition & 0 deletions src/plopp/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .scatter3d import scatter3d
from .slicer import slicer
from .superplot import superplot
from .xyplot import xyplot
45 changes: 38 additions & 7 deletions src/plopp/plotting/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,55 @@ def from_compatible_lib(obj: Any) -> Any:
return obj


def _to_data_array(
obj: Union[Plottable, list],
) -> sc.DataArray:
def _maybe_to_variable(obj: Union[Plottable, list]) -> Plottable:
"""
Convert an input to a DataArray, potentially adding fake coordinates if they are
missing.
Attempt to convert the input to a Variable.
If the input is either a list or a numpy array, it will be converted.
Otherwise, the input will be returned unchanged.
"""
out = obj
if isinstance(out, list):
out = np.array(out)
if isinstance(out, np.ndarray):
dims = [f"axis-{i}" for i in range(len(out.shape))]
out = sc.Variable(dims=dims, values=out)
return out


def to_variable(obj) -> sc.Variable:
"""
Convert an input to a Variable. If the object returned by the conversion is not a
Variable, raise an error.
Parameters
----------
obj:
The input object to be converted.
"""
out = _maybe_to_variable(obj)
if not isinstance(out, sc.Variable):
raise TypeError(f"Cannot convert input of type {type(obj)} to a Variable.")
return out


def to_data_array(
obj: Union[Plottable, list],
) -> sc.DataArray:
"""
Convert an input to a DataArray, potentially adding fake coordinates if they are
missing.
Parameters
----------
obj:
The input object to be converted.
"""
out = _maybe_to_variable(obj)
if isinstance(out, sc.Variable):
out = sc.DataArray(data=out)
out = from_compatible_lib(out)
if not isinstance(out, sc.DataArray):
raise ValueError(f"Cannot convert input of type {type(obj)} to a DataArray.")
raise TypeError(f"Cannot convert input of type {type(obj)} to a DataArray.")
out = out.copy(deep=False)
for dim, size in out.sizes.items():
if dim not in out.coords:
Expand Down Expand Up @@ -152,7 +183,7 @@ def preprocess(
coords:
If supplied, use these coords instead of the input's dimension coordinates.
"""
out = _to_data_array(obj)
out = to_data_array(obj)
check_not_binned(out)
check_allowed_dtypes(out)
if not out.name:
Expand Down
54 changes: 54 additions & 0 deletions src/plopp/plotting/xyplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from typing import Union

import scipp as sc
from numpy import ndarray

from ..core import Node
from ..graphics import figure1d
from .common import to_variable


def _make_data_array(x: sc.Variable, y: sc.Variable) -> sc.DataArray:
"""
Make a data array from the supplied variables, using ``x`` as the coordinate and
``y`` as the data.
Parameters
----------
x:
The variable to use as the coordinate.
y:
The variable to use as the data.
"""
return sc.DataArray(data=y, coords={x.dim: x})


def xyplot(
x: Union[sc.Variable, ndarray, list, Node],
y: Union[sc.Variable, ndarray, list, Node],
**kwargs,
):
"""
Make a one-dimensional plot of one variable ``y`` as a function of another ``x``.
.. versionadded:: <TODO:VERSION>
Parameters
----------
x:
The variable to use as the coordinates for the horizontal axis.
Must be one-dimensional.
y:
The variable to use as the data for the vertical axis. Must be one-dimensional.
**kwargs:
See :py:func:`plopp.plot`.
"""
x = Node(to_variable, x)
y = Node(to_variable, y)
dim = x().dim
if dim != y().dim:
raise sc.DimensionError("Dimensions of x and y must match")
return figure1d(Node(_make_data_array, x=x, y=y), **kwargs)
2 changes: 1 addition & 1 deletion tests/plotting/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_from_node(use_ipympl):

def test_multiple_inputs_raises(use_ipympl):
da = pp.data.data3d()
with pytest.raises(ValueError, match='Cannot convert input of type'):
with pytest.raises(TypeError, match='Cannot convert input of type'):
pp.inspector({'a': da, 'b': 2.3 * da})


Expand Down
2 changes: 1 addition & 1 deletion tests/plotting/plot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_raises_ValueError_when_given_unsupported_data_type():
c = a * 3.0
d = a * 4.0
nested_dict = {'group1': {'a': a, 'b': b}, 'group2': {'c': c, 'd': d}}
with pytest.raises(ValueError, match='Cannot convert input of type'):
with pytest.raises(TypeError, match='Cannot convert input of type'):
pp.plot(nested_dict)


Expand Down
79 changes: 79 additions & 0 deletions tests/plotting/xyplot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

import numpy as np
import pytest
import scipp as sc

import plopp as pp


def test_xyplot_variable():
x = sc.arange('time', 20.0, unit='s')
y = sc.arange('time', 100.0, 120.0, unit='K')
fig = pp.xyplot(x, y)
assert fig.canvas.xlabel == 'time [s]'
assert fig.canvas.ylabel == '[K]'


def test_xyplot_ndarray():
N = 50
x = np.arange(float(N))
y = np.linspace(-44.0, 44.0, N)
pp.xyplot(x, y)


def test_xyplot_list():
x = [1, 2, 3, 4, 5]
y = [6, 7, 2, 0, 1]
pp.xyplot(x, y)


def test_xyplot_different_dims_raises():
x = sc.arange('x', 20.0, unit='s')
y = sc.arange('y', 100.0, 120.0, unit='K')
with pytest.raises(sc.DimensionError, match='Dimensions of x and y must match'):
pp.xyplot(x, y)


def test_xyplot_data_array_raises():
x = sc.arange('x', 20.0, unit='s')
y = pp.data.data1d()
with pytest.raises(TypeError, match='Cannot convert input of type'):
pp.xyplot(x, y)
with pytest.raises(TypeError, match='Cannot convert input of type'):
pp.xyplot(y, x)


def test_xyplot_2d_variable_raises():
x = sc.arange('x', 50.0, unit='s')
y = pp.data.data2d().data
with pytest.raises(sc.DimensionError, match='Expected 1 dimensions, got 2'):
pp.xyplot(x, y)
with pytest.raises(sc.DimensionError, match='Expected 1 dimensions, got 2'):
pp.xyplot(y, x)


def test_xyplot_variable_kwargs():
x = sc.arange('time', 20.0, unit='s')
y = sc.arange('time', 100.0, 120.0, unit='K')
fig = pp.xyplot(x, y, color='red', vmin=102.0, vmax=115.0)
assert np.allclose(fig.canvas.yrange, [102.0, 115.0])
line = list(fig.artists.values())[0]
assert line.color == 'red'


def test_xyplot_bin_edges():
x = sc.arange('time', 21.0, unit='s')
y = sc.arange('time', 100.0, 120.0, unit='K')
fig = pp.xyplot(x, y)
line = list(fig.artists.values())[0]
assert len(line._line.get_xdata()) == 21


def test_xyplot_from_nodes():
x = sc.arange('time', 20.0, unit='s')
y = sc.arange('time', 100.0, 120.0, unit='K')
pp.xyplot(pp.Node(x), y)
pp.xyplot(x, pp.Node(y))
pp.xyplot(pp.Node(x), pp.Node(y))

0 comments on commit 297a0b1

Please sign in to comment.