Skip to content

Commit

Permalink
TYP: plotting (pandas-dev#55887)
Browse files Browse the repository at this point in the history
* TYP: _iter_data

* TYP: plotting

* TYP: plotting

* TYP: plotting

* Improve check

* TYP: plotting

* lint fixup

* mypy fixup

* pyright fixup
  • Loading branch information
jbrockmendel authored Nov 9, 2023
1 parent 6755b81 commit d734496
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 65 deletions.
16 changes: 10 additions & 6 deletions pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@

from pandas._typing import IndexLabel

from pandas import DataFrame
from pandas import (
DataFrame,
Series,
)
from pandas.core.groupby.generic import DataFrameGroupBy


def hist_series(
self,
self: Series,
by=None,
ax=None,
grid: bool = True,
Expand Down Expand Up @@ -512,7 +516,7 @@ def boxplot(
@Substitution(data="", backend=_backend_doc)
@Appender(_boxplot_doc)
def boxplot_frame(
self,
self: DataFrame,
column=None,
by=None,
ax=None,
Expand Down Expand Up @@ -542,7 +546,7 @@ def boxplot_frame(


def boxplot_frame_groupby(
grouped,
grouped: DataFrameGroupBy,
subplots: bool = True,
column=None,
fontsize: int | None = None,
Expand Down Expand Up @@ -843,11 +847,11 @@ class PlotAccessor(PandasObject):
_kind_aliases = {"density": "kde"}
_all_kinds = _common_kinds + _series_kinds + _dataframe_kinds

def __init__(self, data) -> None:
def __init__(self, data: Series | DataFrame) -> None:
self._parent = data

@staticmethod
def _get_call_args(backend_name: str, data, args, kwargs):
def _get_call_args(backend_name: str, data: Series | DataFrame, args, kwargs):
"""
This function makes calls to this accessor's `__call__` method compatible
with the previous `SeriesPlotMethods.__call__` and
Expand Down
16 changes: 8 additions & 8 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,18 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
def _plot( # type: ignore[override]
cls, ax: Axes, y, column_num=None, return_type: str = "axes", **kwds
cls, ax: Axes, y: np.ndarray, column_num=None, return_type: str = "axes", **kwds
):
ys: np.ndarray | list[np.ndarray]
if y.ndim == 2:
y = [remove_na_arraylike(v) for v in y]
ys = [remove_na_arraylike(v) for v in y]
# Boxplot fails with empty arrays, so need to add a NaN
# if any cols are empty
# GH 8181
y = [v if v.size > 0 else np.array([np.nan]) for v in y]
ys = [v if v.size > 0 else np.array([np.nan]) for v in ys]
else:
y = remove_na_arraylike(y)
bp = ax.boxplot(y, **kwds)
ys = remove_na_arraylike(y)
bp = ax.boxplot(ys, **kwds)

if return_type == "dict":
return bp, bp
Expand Down Expand Up @@ -240,8 +241,7 @@ def _make_plot(self, fig: Figure) -> None:
self.maybe_color_bp(bp)
self._return_obj = ret

labels = [left for left, _ in self._iter_data()]
labels = [pprint_thing(left) for left in labels]
labels = [pprint_thing(left) for left in self.data.columns]
if not self.use_index:
labels = [pprint_thing(key) for key in range(len(labels))]
_set_ticklabels(
Expand All @@ -251,7 +251,7 @@ def _make_plot(self, fig: Figure) -> None:
def _make_legend(self) -> None:
pass

def _post_plot_logic(self, ax, data) -> None:
def _post_plot_logic(self, ax: Axes, data) -> None:
# GH 45465: make sure that the boxplot doesn't ignore xlabel/ylabel
if self.xlabel:
ax.set_xlabel(pprint_thing(self.xlabel))
Expand Down
97 changes: 62 additions & 35 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import (
Hashable,
Iterable,
Iterator,
Sequence,
)
from typing import (
Expand Down Expand Up @@ -431,17 +432,15 @@ def _validate_color_args(self):
)

@final
def _iter_data(self, data=None, keep_index: bool = False, fillna=None):
if data is None:
data = self.data
if fillna is not None:
data = data.fillna(fillna)

@staticmethod
def _iter_data(
data: DataFrame | dict[Hashable, Series | DataFrame]
) -> Iterator[tuple[Hashable, np.ndarray]]:
for col, values in data.items():
if keep_index is True:
yield col, values
else:
yield col, values.values
# This was originally written to use values.values before
# ExtensionArrays (EAs) were implemented; np.asarray(...) is added so
# that the yielded values are consistently typed as np.ndarray.
yield col, np.asarray(values.values)

@property
def nseries(self) -> int:
Expand Down Expand Up @@ -480,7 +479,7 @@ def _has_plotted_object(ax: Axes) -> bool:
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0

@final
def _maybe_right_yaxis(self, ax: Axes, axes_num: int):
def _maybe_right_yaxis(self, ax: Axes, axes_num: int) -> Axes:
if not self.on_right(axes_num):
# secondary axes may be passed via ax kw
return self._get_ax_layer(ax)
Expand Down Expand Up @@ -656,11 +655,7 @@ def _compute_plot_data(self):

numeric_data = data.select_dtypes(include=include_type, exclude=exclude_type)

try:
is_empty = numeric_data.columns.empty
except AttributeError:
is_empty = not len(numeric_data)

is_empty = numeric_data.shape[-1] == 0
# no non-numeric frames or series allowed
if is_empty:
raise TypeError("no numeric data to plot")
Expand All @@ -682,7 +677,7 @@ def _add_table(self) -> None:
tools.table(ax, data)

@final
def _post_plot_logic_common(self, ax, data):
def _post_plot_logic_common(self, ax: Axes, data) -> None:
"""Common post process for each axes"""
if self.orientation == "vertical" or self.orientation is None:
self._apply_axis_properties(ax.xaxis, rot=self.rot, fontsize=self.fontsize)
Expand All @@ -701,7 +696,7 @@ def _post_plot_logic_common(self, ax, data):
raise ValueError

@abstractmethod
def _post_plot_logic(self, ax, data) -> None:
def _post_plot_logic(self, ax: Axes, data) -> None:
"""Post process for each axes. Overridden in child classes"""

@final
Expand Down Expand Up @@ -1056,7 +1051,7 @@ def _get_colors(
)

@final
def _parse_errorbars(self, label, err):
def _parse_errorbars(self, label: str, err):
"""
Look for error keyword arguments and return the actual errorbar data
or return the error DataFrame/dict
Expand Down Expand Up @@ -1137,7 +1132,10 @@ def match_labels(data, e):
err = np.tile(err, (self.nseries, 1))

elif is_number(err):
err = np.tile([err], (self.nseries, len(self.data)))
err = np.tile(
[err], # pyright: ignore[reportGeneralTypeIssues]
(self.nseries, len(self.data)),
)

else:
msg = f"No valid {label} detected"
Expand Down Expand Up @@ -1418,14 +1416,14 @@ def _make_plot(self, fig: Figure) -> None:

x = data.index # dummy, not used
plotf = self._ts_plot
it = self._iter_data(data=data, keep_index=True)
it = data.items()
else:
x = self._get_xticks(convert_period=True)
# error: Incompatible types in assignment (expression has type
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
plotf = self._plot # type: ignore[assignment]
it = self._iter_data()
it = self._iter_data(data=self.data)

stacking_id = self._get_stacking_id()
is_errorbar = com.any_not_none(*self.errors.values())
Expand All @@ -1434,7 +1432,12 @@ def _make_plot(self, fig: Figure) -> None:
for i, (label, y) in enumerate(it):
ax = self._get_ax(i)
kwds = self.kwds.copy()
style, kwds = self._apply_style_colors(colors, kwds, i, label)
style, kwds = self._apply_style_colors(
colors,
kwds,
i,
label, # pyright: ignore[reportGeneralTypeIssues]
)

errors = self._get_errorbars(label=label, index=i)
kwds = dict(kwds, **errors)
Expand All @@ -1446,7 +1449,7 @@ def _make_plot(self, fig: Figure) -> None:
newlines = plotf(
ax,
x,
y,
y, # pyright: ignore[reportGeneralTypeIssues]
style=style,
column_num=i,
stacking_id=stacking_id,
Expand All @@ -1465,7 +1468,14 @@ def _make_plot(self, fig: Figure) -> None:
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
def _plot( # type: ignore[override]
cls, ax: Axes, x, y, style=None, column_num=None, stacking_id=None, **kwds
cls,
ax: Axes,
x,
y: np.ndarray,
style=None,
column_num=None,
stacking_id=None,
**kwds,
):
# column_num is used to get the target column from plotf in line and
# area plots
Expand All @@ -1492,7 +1502,7 @@ def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
decorate_axes(ax.right_ax, freq, kwds)
ax._plot_data.append((data, self._kind, kwds))

lines = self._plot(ax, data.index, data.values, style=style, **kwds)
lines = self._plot(ax, data.index, np.asarray(data.values), style=style, **kwds)
# set date formatter, locators and rescale limits
# error: Argument 3 to "format_dateaxis" has incompatible type "Index";
# expected "DatetimeIndex | PeriodIndex"
Expand Down Expand Up @@ -1520,7 +1530,9 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:

@final
@classmethod
def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
def _get_stacked_values(
cls, ax: Axes, stacking_id: int | None, values: np.ndarray, label
) -> np.ndarray:
if stacking_id is None:
return values
if not hasattr(ax, "_stacker_pos_prior"):
Expand All @@ -1540,7 +1552,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):

@final
@classmethod
def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None:
if stacking_id is None:
return
if (values >= 0).all():
Expand Down Expand Up @@ -1618,7 +1630,7 @@ def _plot( # type: ignore[override]
cls,
ax: Axes,
x,
y,
y: np.ndarray,
style=None,
column_num=None,
stacking_id=None,
Expand Down Expand Up @@ -1744,7 +1756,7 @@ def _plot( # type: ignore[override]
cls,
ax: Axes,
x,
y,
y: np.ndarray,
w,
start: int | npt.NDArray[np.intp] = 0,
log: bool = False,
Expand All @@ -1763,7 +1775,8 @@ def _make_plot(self, fig: Figure) -> None:
pos_prior = neg_prior = np.zeros(len(self.data))
K = self.nseries

for i, (label, y) in enumerate(self._iter_data(fillna=0)):
data = self.data.fillna(0)
for i, (label, y) in enumerate(self._iter_data(data=data)):
ax = self._get_ax(i)
kwds = self.kwds.copy()
if self._is_series:
Expand Down Expand Up @@ -1842,7 +1855,14 @@ def _post_plot_logic(self, ax: Axes, data) -> None:

self._decorate_ticks(ax, self._get_index_name(), str_index, s_edge, e_edge)

def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
def _decorate_ticks(
self,
ax: Axes,
name: str | None,
ticklabels: list[str],
start_edge: float,
end_edge: float,
) -> None:
ax.set_xlim((start_edge, end_edge))

if self.xticks is not None:
Expand Down Expand Up @@ -1876,7 +1896,7 @@ def _plot( # type: ignore[override]
cls,
ax: Axes,
x,
y,
y: np.ndarray,
w,
start: int | npt.NDArray[np.intp] = 0,
log: bool = False,
Expand All @@ -1887,7 +1907,14 @@ def _plot( # type: ignore[override]
def _get_custom_index_name(self):
return self.ylabel

def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
def _decorate_ticks(
self,
ax: Axes,
name: str | None,
ticklabels: list[str],
start_edge: float,
end_edge: float,
) -> None:
# horizontal bars
ax.set_ylim((start_edge, end_edge))
ax.set_yticks(self.tick_pos)
Expand Down Expand Up @@ -1921,7 +1948,7 @@ def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
self.kwds.setdefault("colors", colors)

for i, (label, y) in enumerate(self._iter_data()):
for i, (label, y) in enumerate(self._iter_data(data=self.data)):
ax = self._get_ax(i)
if label is not None:
label = pprint_thing(label)
Expand Down
8 changes: 4 additions & 4 deletions pandas/plotting/_matplotlib/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
from pandas.plotting._matplotlib.misc import unpack_single_str_list

if TYPE_CHECKING:
from collections.abc import Hashable

from pandas._typing import IndexLabel


def create_iter_data_given_by(
data: DataFrame, kind: str = "hist"
) -> dict[str, DataFrame | Series]:
) -> dict[Hashable, DataFrame | Series]:
"""
Create data for iteration, depending on whether `by` is assigned; this is
only used by the hist and boxplot plots.
Expand Down Expand Up @@ -126,9 +128,7 @@ def reconstruct_data_with_by(
return data


def reformat_hist_y_given_by(
y: Series | np.ndarray, by: IndexLabel | None
) -> Series | np.ndarray:
def reformat_hist_y_given_by(y: np.ndarray, by: IndexLabel | None) -> np.ndarray:
"""Internal function to reformat y for hist plots, depending on whether `by` is applied.
If by is None, input y is 1-d with NaN removed; and if by is not None, groupby
Expand Down
Loading

0 comments on commit d734496

Please sign in to comment.