diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 30fe3691635e..bcae072d4fc2 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -64,7 +64,7 @@ from polars._utils.wrap import wrap_expr, wrap_ldf, wrap_s from polars.dataframe._html import NotebookFormatter from polars.dataframe.group_by import DynamicGroupBy, GroupBy, RollingGroupBy -from polars.dataframe.plotting import DataFramePlot +from polars.dataframe.plotting.plotting import DataFramePlot from polars.datatypes import ( N_INFER_DEFAULT, Boolean, @@ -611,7 +611,7 @@ def _replace(self, column: str, new_column: Series) -> DataFrame: @property @unstable() - def plot(self) -> DataFramePlot: + def plot(self, backend: str | None = None) -> DataFramePlot: """ Create a plot namespace. @@ -704,7 +704,7 @@ def plot(self) -> DataFramePlot: if not _ALTAIR_AVAILABLE or parse_version(altair.__version__) < (5, 4, 0): msg = "altair>=5.4.0 is required for `.plot`" raise ModuleUpgradeRequiredError(msg) - return DataFramePlot(self) + return DataFramePlot(self, backend) @property @unstable() diff --git a/py-polars/polars/dataframe/plotting.py b/py-polars/polars/dataframe/plotting/plotting.py similarity index 88% rename from py-polars/polars/dataframe/plotting.py rename to py-polars/polars/dataframe/plotting/plotting.py index b452e50bfe66..1abce7c517d0 100644 --- a/py-polars/polars/dataframe/plotting.py +++ b/py-polars/polars/dataframe/plotting/plotting.py @@ -1,8 +1,8 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Callable, Union - +from typing import TYPE_CHECKING, Any, Callable, Union +import os from polars.dependencies import altair as alt if TYPE_CHECKING: @@ -34,6 +34,38 @@ class DataFramePlot: """DataFrame.plot namespace.""" + def __init__(self, df: DataFrame, backend: str | None = None) -> None: + self._df = df + # TODO: add config for backend + if backend is None and "POLARS_PLOTTING_BACKEND" in os.environ: + backend = os.environ["POLARS_PLOTTING_BACKEND"] + elif backend is None: + backend = ( + "altair" # TODO: change from default to detecting installed library + ) + + if backend == "altair": + self._backend = AltairPlot(df) + + def bar( + self, + x: X | None = None, + y: Y | None = None, + color: Color | None = None, + /, + **kwargs: Any, + ) -> Any: + return self._backend.bar(x, y, color, **kwargs) + + def alt(self) -> AltairPlot: + # going through the extra class makes it so users can get the right static + # typed outputs or + return AltairPlot(self._df) + + +class AltairPlot: + """DataFrame.plot namespace for altair.""" + def __init__(self, df: DataFrame) -> None: self._chart = alt.Chart(df) diff --git a/py-polars/polars/series/plotting.py b/py-polars/polars/series/plotting.py index 8747f1a32352..1bf85d85170c 100644 --- a/py-polars/polars/series/plotting.py +++ b/py-polars/polars/series/plotting.py @@ -10,7 +10,7 @@ from altair.typing import EncodeKwds - from polars.dataframe.plotting import Encodings + from polars.dataframe.plotting.plotting import Encodings if sys.version_info >= (3, 11): from typing import Unpack