From eeaea89f8685303a58e7ff9e43b552542367985b Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Mon, 8 Jul 2024 13:19:08 +0100 Subject: [PATCH 01/19] initial dask integration for test_common --- narwhals/_pandas_like/dataframe.py | 8 +++++++- narwhals/_pandas_like/series.py | 5 +++++ narwhals/_pandas_like/utils.py | 15 +++++++++++++-- narwhals/dataframe.py | 5 +++++ narwhals/dependencies.py | 5 +++++ narwhals/translate.py | 4 ++++ requirements-dev.txt | 2 +- tests/conftest.py | 8 +++++++- tests/frame/test_common.py | 18 ++++++++++-------- tests/utils.py | 11 +++++++++++ 10 files changed, 68 insertions(+), 13 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index a38f1bd4b..f6aeb42aa 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -18,6 +18,7 @@ from narwhals._pandas_like.utils import validate_dataframe_comparand from narwhals._pandas_like.utils import validate_indices from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas @@ -64,6 +65,8 @@ def __native_namespace__(self) -> Any: return get_modin() if self._implementation == "cudf": # pragma: no cover return get_cudf() + if self._implementation == "dask": # pragma: no cover + return get_dask() msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) @@ -104,7 +107,6 @@ def __getitem__(self, item: str | slice) -> PandasSeries | PandasDataFrame: elif isinstance(item, (slice, Sequence)): from narwhals._pandas_like.dataframe import PandasDataFrame - return PandasDataFrame( self._dataframe.iloc[item], implementation=self._implementation ) @@ -397,6 +399,8 @@ def to_numpy(self) -> Any: import numpy as np return np.hstack([self[col].to_numpy()[:, None] for col in self.columns]) + if self._implementation == "dask": + return self._dataframe.compute().to_numpy() return self._dataframe.to_numpy() def to_pandas(self) -> Any: @@ -404,6 +408,8 @@ def to_pandas(self) -> Any: return self._dataframe if self._implementation == "modin": # pragma: no cover return self._dataframe._to_pandas() + if self._implementation == "dask": # pragma: no cover + return self._dataframe.compute() return self._dataframe.to_pandas() # pragma: no cover def write_parquet(self, file: Any) -> Any: diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 3b92a5aa2..b233321f2 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -13,6 +13,7 @@ from narwhals._pandas_like.utils import translate_dtype from narwhals._pandas_like.utils import validate_column_comparand from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.utils import parse_version @@ -106,6 +107,8 @@ def __native_namespace__(self) -> Any: return get_modin() if self._implementation == "cudf": # pragma: no cover return get_cudf() + if self._implementation == "dask": # pragma: no cover + return get_dask() msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) @@ -138,6 +141,8 @@ def _from_iterable( ) def __len__(self) -> int: + if self._implementation == "dask": + return len(self._series) return self.shape[0] @property diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index add99a3cd..40433d397 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -8,6 +8,7 @@ from typing import TypeVar from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas @@ -54,7 +55,7 @@ def validate_column_comparand(index: Any, other: Any) -> Any: if other.len() == 1: # broadcast return other.item() - if other._series.index is not index: + if other._series.index is not index and other._implementation != "dask": return set_axis(other._series, index, implementation=other._implementation) return other._series return other @@ -75,7 +76,7 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: if other.len() == 1: # broadcast return other._series.iloc[0] - if other._series.index is not index: + if other._series.index is not index and other._implementation != "dask": return set_axis(other._series, index, implementation=other._implementation) return other._series raise AssertionError("Please report a bug") @@ -314,6 +315,10 @@ def horizontal_concat(dfs: list[Any], implementation: str) -> Any: mpd = get_modin() return mpd.concat(dfs, axis=1) + if implementation == "dask": # pragma: no cover + dd = get_dask() + + return dd.concat(dfs, axis=1) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -347,6 +352,10 @@ def vertical_concat(dfs: list[Any], implementation: str) -> Any: mpd = get_modin() return mpd.concat(dfs, axis=0) + if implementation == "dask": # pragma: no cover + dd = get_dask() + + return dd.concat(dfs, axis=0) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -387,6 +396,8 @@ def set_axis(obj: T, index: Any, implementation: str) -> T: kwargs["copy"] = False else: # pragma: no cover pass + if implementation == "dask": + raise NotImplementedError("figuring this bit out still!") return obj.set_axis(index, axis=0, **kwargs) # type: ignore[no-any-return, attr-defined] diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index d0cc15ee4..a9594186d 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -14,6 +14,7 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._pandas_like.dataframe import PandasDataFrame from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas @@ -248,6 +249,10 @@ def __init__( df, pa.Table ): # pragma: no cover self._dataframe = ArrowDataFrame(df) + elif (dd := get_dask()) is not None and isinstance( + df, dd.DataFrame + ): # pragma: no cover + self._dataframe = PandasDataFrame(df, implementation="dask") else: msg = f"Expected pandas-like dataframe, Polars dataframe, or Polars lazyframe, got: {type(df)}" raise TypeError(msg) diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 555407e59..03465f4fb 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -32,6 +32,11 @@ def get_modin() -> Any: # pragma: no cover return None +def get_dask() -> Any: + """Get dask.dataframe module (if already improted - else return None).""" + return sys.modules.get("dask.dataframe", None) + + def get_cudf() -> Any: """Get cudf module (if already imported - else return None).""" return sys.modules.get("cudf", None) diff --git a/narwhals/translate.py b/narwhals/translate.py index ab3f6caf5..75a16ffea 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -9,6 +9,7 @@ from typing import overload from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars @@ -216,6 +217,7 @@ def from_native( - pandas.DataFrame - polars.DataFrame - polars.LazyFrame + - dask.dataframe.DataFrame - anything with a `__narwhals_dataframe__` or `__narwhals_lazyframe__` method - pandas.Series - polars.Series @@ -254,6 +256,8 @@ def from_native( and isinstance(native_dataframe, mpd.DataFrame) or (cudf := get_cudf()) is not None and isinstance(native_dataframe, cudf.DataFrame) + or (dd := get_dask()) is not None + and isinstance(native_dataframe, dd.DataFrame) ): if series_only: # pragma: no cover (todo) raise TypeError("Cannot only use `series_only` with dataframe") diff --git a/requirements-dev.txt b/requirements-dev.txt index 0586d00c6..ad28f0763 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,4 +7,4 @@ pytest pytest-cov hypothesis scikit-learn - +dask[dataframe] diff --git a/tests/conftest.py b/tests/conftest.py index 5b71cdc0e..faa4e8c47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import pyarrow as pa import pytest -from narwhals.dependencies import get_modin +from narwhals.dependencies import get_modin, get_dask from narwhals.typing import IntoDataFrame from narwhals.utils import parse_version @@ -47,6 +47,10 @@ def modin_constructor(obj: Any) -> IntoDataFrame: # pragma: no cover mpd = get_modin() return mpd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return] +def dask_contructor(obj: Any) -> IntoDataFrame: + dd = get_dask() + return dd.DataFrame(obj) # type: ignore[no-any-return] + def polars_constructor(obj: Any) -> IntoDataFrame: return pl.DataFrame(obj) @@ -63,6 +67,8 @@ def polars_lazy_constructor(obj: Any) -> pl.LazyFrame: params.append(polars_constructor) if get_modin() is not None: # pragma: no cover params.append(modin_constructor) +if get_dask() is not None: + params.append(dask_contructor) @pytest.fixture(params=params) diff --git a/tests/frame/test_common.py b/tests/frame/test_common.py index e7ed846cc..4f1808fbd 100644 --- a/tests/frame/test_common.py +++ b/tests/frame/test_common.py @@ -18,7 +18,7 @@ from narwhals.functions import show_versions from narwhals.utils import parse_version from tests.utils import compare_dicts -from tests.utils import maybe_get_modin_df +from tests.utils import maybe_get_modin_df, maybe_get_dask_df if TYPE_CHECKING: from narwhals.dtypes import DType @@ -53,6 +53,7 @@ df_right_pandas = pd.DataFrame({"c": [6, 12, -1], "d": [0, -4, 2]}) df_right_lazy = pl.LazyFrame({"c": [6, 12, -1], "d": [0, -4, 2]}) df_mpd = maybe_get_modin_df(df_pandas) +df_dd = maybe_get_dask_df(df_pandas) df_pa = pa.table({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) df_pa_na = pa.table({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) @@ -301,6 +302,7 @@ def test_cross_join_non_pandas() -> None: df_lazy, df_pandas, # df_mpd, (todo: understand the difference between ipython/jupyter and pytest runs) + df_dd, ], ) @pytest.mark.parametrize( @@ -319,7 +321,6 @@ def test_anti_join( result = df.join(other, how="anti", left_on=join_key, right_on=join_key) # type: ignore[arg-type] compare_dicts(result, expected) - @pytest.mark.parametrize( "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] ) @@ -352,7 +353,7 @@ def test_columns(df_raw: Any) -> None: assert result == expected -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_dd, df_lazy]) def test_lazy_instantiation(df_raw: Any) -> None: result = nw.from_native(df_raw) result_native = nw.to_native(result) @@ -368,7 +369,7 @@ def test_lazy_instantiation_error(df_raw: Any) -> None: _ = nw.DataFrame(df_raw).shape -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd]) +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_dd]) def test_eager_instantiation(df_raw: Any) -> None: result = nw.from_native(df_raw, eager_only=True) result_native = nw.to_native(result) @@ -390,7 +391,7 @@ def test_accepted_dataframes() -> None: nw.LazyFrame(array) -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_pa]) +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_dd, df_pa]) @pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning") @pytest.mark.skipif( parse_version(pd.__version__) < parse_version("2.0.0"), @@ -403,7 +404,7 @@ def test_convert_pandas(df_raw: Any) -> None: @pytest.mark.parametrize( - "df_raw", [df_polars, df_pandas, df_mpd, df_pandas_nullable, df_pandas_pyarrow] + "df_raw", [df_polars, df_pandas, df_mpd, df_dd, df_pandas_nullable, df_pandas_pyarrow] ) @pytest.mark.filterwarnings( r"ignore:np\.find_common_type is deprecated\.:DeprecationWarning" @@ -418,7 +419,7 @@ def test_convert_numpy(df_raw: Any) -> None: assert result.dtype == "float64" -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_dd, df_lazy]) def test_expr_binary(df_raw: Any) -> None: result = nw.from_native(df_raw).with_columns( a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")), @@ -481,7 +482,7 @@ def test_expr_unary(df_raw: Any) -> None: compare_dicts(result_native, expected) -@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_dd, df_lazy]) def test_expr_transform(df_raw: Any) -> None: result = nw.from_native(df_raw).with_columns( a=nw.col("a").is_between(-1, 1), b=nw.col("b").is_in([4, 5]) @@ -574,6 +575,7 @@ def test_drop_nulls(df_raw: Any) -> None: df_pandas, df_polars, df_mpd, + df_dd, df_pa, ], ) diff --git a/tests/utils.py b/tests/utils.py index 202c332e0..da2a00d00 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,6 +23,8 @@ def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]: def compare_dicts(result: Any, expected: dict[str, Any]) -> None: if hasattr(result, "collect"): result = result.collect() + if hasattr(result, "_dataframe") and hasattr(result._dataframe, "_implementation") and result._dataframe._implementation == "dask": + result = result.to_pandas() if hasattr(result, "columns"): for key in result.columns: assert key in expected @@ -51,6 +53,15 @@ def maybe_get_modin_df(df_pandas: pd.DataFrame) -> Any: warnings.filterwarnings("ignore", category=UserWarning) return mpd.DataFrame(df_pandas.to_dict(orient="list")) +def maybe_get_dask_df(df_pandas: pd.DataFrame) -> Any: + """Convert a pandas DataFrame to a Dask Dataframe if Dask is availabile.""" + try: + import dask.dataframe as dd + + except ImportError: + return df_pandas.copy() + else: + return dd.from_pandas(df_pandas, npartitions=1) def is_windows() -> bool: """Check if the current platform is Windows.""" From a05ee0677e554d5aeb8aaa9c6339dddad7866277 Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Mon, 8 Jul 2024 15:47:12 +0100 Subject: [PATCH 02/19] series support added in --- narwhals/_pandas_like/dataframe.py | 10 ++++-- narwhals/_pandas_like/series.py | 27 ++++++++++++++++ narwhals/series.py | 6 ++++ narwhals/translate.py | 2 ++ tests/conftest.py | 8 +++++ tests/series/test_common.py | 52 ++++++++++++++++++++++-------- 6 files changed, 89 insertions(+), 16 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index f6aeb42aa..9c3cdeb37 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -271,9 +271,15 @@ def sort( # --- convert --- def collect(self) -> PandasDataFrame: + if self._implementation == "dask": + return_df = self._dataframe.compute() + return_implementation = "pandas" + else: + return_df = self._dataframe + return_implementation = self._implementation return PandasDataFrame( - self._dataframe, - implementation=self._implementation, + return_df, + implementation=return_implementation, ) # --- actions --- diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index b233321f2..87deed731 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -166,6 +166,11 @@ def cast( return self._from_series(ser.astype(dtype)) def item(self: Self, index: int | None = None) -> Any: + if self._implementation == "dask": + msg = ( + "Positional indexing is not available in Dask" + ) + raise NotImplementedError(msg) # cuDF doesn't have Series.item(). if index is None: if len(self) != 1: @@ -480,11 +485,19 @@ def to_pandas(self) -> Any: return self._series.to_pandas() elif self._implementation == "modin": # pragma: no cover return self._series._to_pandas() + elif self._implementation == "dask": # pragma: no cover + return self._series.compute() msg = f"Unknown implementation: {self._implementation}" # pragma: no cover raise AssertionError(msg) # --- descriptive --- def is_duplicated(self: Self) -> Self: + if self._implementation == "dask": + msg = ( + "Checking for duplication requires 'duplicated' method " + "which is not currently implemented in dask" + ) + raise NotImplementedError(msg) return self._from_series(self._series.duplicated(keep=False)) def is_empty(self: Self) -> bool: @@ -497,9 +510,13 @@ def null_count(self: Self) -> int: return self._series.isnull().sum() # type: ignore[no-any-return] def is_first_distinct(self: Self) -> Self: + if self._implementation == "dask": + raise NotImplementedError("Not currently implemented in dask") return self._from_series(~self._series.duplicated(keep="first")) def is_last_distinct(self: Self) -> Self: + if self._implementation == "dask": + raise NotImplementedError("Not currently implemented in dask") return self._from_series(~self._series.duplicated(keep="last")) def is_sorted(self: Self, *, descending: bool = False) -> bool: @@ -532,6 +549,16 @@ def quantile( quantile: float, interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], ) -> Any: + if self._implementation == "dask": + if interpolation == "linear": + return self._series.quantile(q=quantile) + message = ( + "Dask performs approximate quantile calculations " + + "and does not support specific interpolations methods. " + "Interpolation keywords other than 'linear' are not supported" + ) + raise NotImplementedError(message) return self._series.quantile(q=quantile, interpolation=interpolation) def zip_with(self: Self, mask: Any, other: Any) -> PandasSeries: diff --git a/narwhals/series.py b/narwhals/series.py index 09d9d0d2b..cf1247041 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -6,6 +6,7 @@ from narwhals._arrow.series import ArrowSeries from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars @@ -62,6 +63,11 @@ def __init__( ): # pragma: no cover self._series = PandasSeries(series, implementation="cudf") return + if (dd := get_dask()) is not None and isinstance( + series, dd.Series + ): #pragma: no cover + self._series = PandasSeries(series, implementation="dask") + return if (pa := get_pyarrow()) is not None and isinstance( series, pa.ChunkedArray ): # pragma: no cover diff --git a/narwhals/translate.py b/narwhals/translate.py index 75a16ffea..dd43953e6 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -290,6 +290,8 @@ def from_native( and isinstance(native_dataframe, cudf.Series) or (pa := get_pyarrow()) is not None and isinstance(native_dataframe, pa.ChunkedArray) + or (dd := get_dask()) is not None + and isinstance(native_dataframe, dd.Series) ) ): if not allow_series: # pragma: no cover (todo) diff --git a/tests/conftest.py b/tests/conftest.py index faa4e8c47..5d280671a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -105,10 +105,16 @@ def modin_series_constructor(obj: Any) -> Any: # pragma: no cover return mpd.Series(obj).convert_dtypes(dtype_backend="pyarrow") +def dask_series_constructor(obj: Any) -> Any: # pragma: no cover + dd = get_dask() + return dd.Series(obj) + def polars_series_constructor(obj: Any) -> Any: return pl.Series(obj) + + if parse_version(pd.__version__) >= parse_version("2.0.0"): params_series = [ pandas_series_constructor, @@ -120,6 +126,8 @@ def polars_series_constructor(obj: Any) -> Any: params_series.append(polars_series_constructor) if get_modin() is not None: # pragma: no cover params_series.append(modin_series_constructor) +if get_dask() is not None: # pragma: no cover + params_series.append(dask_series_constructor) @pytest.fixture(params=params_series) diff --git a/tests/series/test_common.py b/tests/series/test_common.py index 300f0c69a..1f9c2e3a8 100644 --- a/tests/series/test_common.py +++ b/tests/series/test_common.py @@ -12,7 +12,9 @@ from pandas.testing import assert_series_equal import narwhals as nw +from narwhals.dependencies import get_dask from narwhals.utils import parse_version +from tests.utils import maybe_get_dask_df df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) if parse_version(pd.__version__) >= parse_version("1.5.0"): @@ -40,13 +42,21 @@ df_pandas_nullable = df_pandas df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) df_lazy = pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) +df_dask = maybe_get_dask_df(df_pandas) + +def compute_if_dask(result: Any) -> Any: + if hasattr(result, "_series") and hasattr(result._series, "_implementation") and result._series._implementation == "dask": + return result.to_pandas() + return result + @pytest.mark.parametrize( - "df_raw", [df_pandas, df_polars, df_pandas_nullable, df_pandas_pyarrow] + "df_raw", [df_pandas, df_polars, df_pandas_nullable, df_pandas_pyarrow, df_dask] ) def test_len(df_raw: Any) -> None: result = len(nw.from_native(df_raw["a"], series_only=True)) + result = compute_if_dask(result) assert result == 3 result = nw.from_native(df_raw["a"], series_only=True).len() assert result == 3 @@ -54,16 +64,17 @@ def test_len(df_raw: Any) -> None: assert result == 3 -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) @pytest.mark.filterwarnings("ignore:np.find_common_type is deprecated:DeprecationWarning") def test_is_in(df_raw: Any) -> None: result = nw.from_native(df_raw["a"], series_only=True).is_in([1, 2]) + result = compute_if_dask(result) assert result[0] assert not result[1] assert result[2] -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) @pytest.mark.filterwarnings("ignore:np.find_common_type is deprecated:DeprecationWarning") def test_is_in_other(df_raw: Any) -> None: with pytest.raises( @@ -75,37 +86,42 @@ def test_is_in_other(df_raw: Any) -> None: nw.from_native(df_raw).with_columns(contains=nw.col("c").is_in("sets")) -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) @pytest.mark.filterwarnings("ignore:np.find_common_type is deprecated:DeprecationWarning") def test_filter(df_raw: Any) -> None: result = nw.from_native(df_raw["a"], series_only=True).filter(df_raw["a"] > 1) + result = compute_if_dask(result) expected = np.array([3, 2]) assert (result.to_numpy() == expected).all() result = nw.from_native(df_raw, eager_only=True).select( nw.col("a").filter(nw.col("a") > 1) )["a"] + result = compute_if_dask(result) expected = np.array([3, 2]) assert (result.to_numpy() == expected).all() -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) def test_gt(df_raw: Any) -> None: s = nw.from_native(df_raw["a"], series_only=True) result = s > s # noqa: PLR0124 + result = compute_if_dask(result) assert not result[0] assert not result[1] assert not result[2] result = s > 1 + result = compute_if_dask(result) assert not result[0] assert result[1] assert result[2] @pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow, df_dask] ) def test_dtype(df_raw: Any) -> None: result = nw.from_native(df_raw).lazy().collect()["a"].dtype + result = compute_if_dask(result) assert result == nw.Int64 assert result.is_numeric() @@ -132,7 +148,7 @@ def test_reductions(df_raw: Any) -> None: @pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] + "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow, df_dask] ) def test_boolean_reductions(df_raw: Any) -> None: df = nw.from_native(df_raw).lazy().select(nw.col("a") > 1) @@ -140,7 +156,7 @@ def test_boolean_reductions(df_raw: Any) -> None: assert df.collect()["a"].any() -@pytest.mark.parametrize("df_raw", [df_pandas, df_lazy]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_lazy, df_dask]) @pytest.mark.skipif( parse_version(pd.__version__) < parse_version("2.0.0"), reason="too old for pyarrow" ) @@ -163,6 +179,7 @@ def test_to_numpy() -> None: def test_is_duplicated(df_raw: Any) -> None: series = nw.from_native(df_raw["b"], series_only=True) result = series.is_duplicated() + result = compute_if_dask(result) expected = np.array([True, True, False]) assert (result.to_numpy() == expected).all() @@ -198,7 +215,7 @@ def test_is_last_distinct(df_raw: Any) -> None: assert (result.to_numpy() == expected).all() -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) def test_value_counts(df_raw: Any) -> None: series = nw.from_native(df_raw["b"], series_only=True) sorted_result = series.value_counts(sort=True) @@ -215,7 +232,7 @@ def test_value_counts(df_raw: Any) -> None: assert (a[a[:, 0].argsort()] == expected).all() -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) @pytest.mark.parametrize( ("col", "descending", "expected"), [("a", False, False), ("z", False, True), ("z", True, False)], @@ -223,10 +240,12 @@ def test_value_counts(df_raw: Any) -> None: def test_is_sorted(df_raw: Any, col: str, descending: bool, expected: bool) -> None: # noqa: FBT001 series = nw.from_native(df_raw[col], series_only=True) result = series.is_sorted(descending=descending) + if (dd := get_dask()) is not None and isinstance(df_raw, dd.DataFrame): + result = result.compute() assert result == expected -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) def test_is_sorted_invalid(df_raw: Any) -> None: series = nw.from_native(df_raw["z"], series_only=True) @@ -234,7 +253,7 @@ def test_is_sorted_invalid(df_raw: Any) -> None: series.is_sorted(descending="invalid_type") # type: ignore[arg-type] -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) @pytest.mark.parametrize( ("interpolation", "expected"), [ @@ -254,7 +273,12 @@ def test_quantile( q = 0.3 series = nw.from_native(df_raw["z"], allow_series=True) + if (dd := get_dask()) and (is_dask_test := isinstance(df_raw, dd.DataFrame)): + interpolation = "linear" # other interpolation methods not supported + expected = 7.6 result = series.quantile(quantile=q, interpolation=interpolation) # type: ignore[union-attr] + if is_dask_test: + result = result.compute() assert result == expected @@ -304,7 +328,7 @@ def test_item(df_raw: Any, index: int, expected: int) -> None: s.item(None) -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) @pytest.mark.parametrize("n", [1, 2, 3, 10]) def test_head(df_raw: Any, n: int) -> None: s_raw = df_raw["z"] @@ -313,7 +337,7 @@ def test_head(df_raw: Any, n: int) -> None: assert s.head(n) == nw.from_native(s_raw.head(n), series_only=True) -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) @pytest.mark.parametrize("n", [1, 2, 3, 10]) def test_tail(df_raw: Any, n: int) -> None: s_raw = df_raw["z"] From 70a4144f0c304d4d7e78a83b09e82b0249d6a2dd Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Tue, 9 Jul 2024 11:35:23 +0100 Subject: [PATCH 03/19] individual workarounds for dask differences --- narwhals/_pandas_like/dataframe.py | 3 ++- narwhals/_pandas_like/group_by.py | 8 ++++++-- narwhals/_pandas_like/utils.py | 6 ++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 9c3cdeb37..0a38568a7 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -167,7 +167,8 @@ def select( if not new_series: # return empty dataframe, like Polars does return self._from_dataframe(self._dataframe.__class__()) - new_series = validate_indices(new_series) + if self._implementation != "dask": + new_series = validate_indices(new_series) df = horizontal_concat( new_series, implementation=self._implementation, diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index e2ac36c6f..26f7fa1cb 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -30,10 +30,13 @@ class PandasGroupBy: def __init__(self, df: PandasDataFrame, keys: list[str]) -> None: self._df = df self._keys = list(keys) + keywords = {} + if df._implementation != "dask": + keywords |= {"as_index": True} self._grouped = self._df._dataframe.groupby( list(self._keys), sort=False, - as_index=True, + **keywords, ) def agg( @@ -58,6 +61,7 @@ def agg( raise ValueError(msg) output_names.extend(expr._output_names) + dataframe_is_empty=self._df._dataframe.empty if self._df._implementation != "dask" else len(self._df) == 0 return agg_pandas( self._grouped, exprs, @@ -65,7 +69,7 @@ def agg( output_names, self._from_dataframe, implementation, - dataframe_is_empty=self._df._dataframe.empty, + dataframe_is_empty=dataframe_is_empty, ) def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame: diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 40433d397..265a19da1 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -317,7 +317,8 @@ def horizontal_concat(dfs: list[Any], implementation: str) -> Any: return mpd.concat(dfs, axis=1) if implementation == "dask": # pragma: no cover dd = get_dask() - + if hasattr(dfs[0], "_series"): # TODO: sort out this hack + return dd.concat([i._series for i in dfs], axis=1) return dd.concat(dfs, axis=1) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -397,7 +398,8 @@ def set_axis(obj: T, index: Any, implementation: str) -> T: else: # pragma: no cover pass if implementation == "dask": - raise NotImplementedError("figuring this bit out still!") + msg = "Setting axis on columns is not currently supported for dask" + raise NotImplementedError(msg) return obj.set_axis(index, axis=0, **kwargs) # type: ignore[no-any-return, attr-defined] From 638cfeee056527a94bc1be6f38e63584a76abb59 Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Fri, 12 Jul 2024 16:00:22 +0100 Subject: [PATCH 04/19] initial tidy up --- narwhals/_pandas_like/dataframe.py | 3 +- narwhals/_pandas_like/namespace.py | 4 ++- narwhals/_pandas_like/series.py | 24 ++++------------ narwhals/_pandas_like/utils.py | 46 +++++++++++++++++++++++++++++- 4 files changed, 55 insertions(+), 22 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 0a38568a7..9c3cdeb37 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -167,8 +167,7 @@ def select( if not new_series: # return empty dataframe, like Polars does return self._from_dataframe(self._dataframe.__class__()) - if self._implementation != "dask": - new_series = validate_indices(new_series) + new_series = validate_indices(new_series) df = horizontal_concat( new_series, implementation=self._implementation, diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 5227f22c3..2273c3947 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -70,10 +70,12 @@ def _create_expr_from_callable( # noqa: PLR0913 def _create_series_from_scalar( self, value: Any, series: PandasSeries ) -> PandasSeries: + + index = series._series.index[0:1] if self.implementation != "dask" else None return PandasSeries._from_iterable( [value], name=series._series.name, - index=series._series.index[0:1], + index=index, implementation=self._implementation, ) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 87deed731..f5507c40a 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -8,6 +8,7 @@ from narwhals._pandas_like.utils import int_dtype_mapper from narwhals._pandas_like.utils import native_series_from_iterable +from narwhals._pandas_like.utils import not_implemented_in from narwhals._pandas_like.utils import reverse_translate_dtype from narwhals._pandas_like.utils import to_datetime from narwhals._pandas_like.utils import translate_dtype @@ -141,9 +142,7 @@ def _from_iterable( ) def __len__(self) -> int: - if self._implementation == "dask": - return len(self._series) - return self.shape[0] + return len(self._series) @property def name(self) -> str: @@ -165,12 +164,8 @@ def cast( dtype = reverse_translate_dtype(dtype, ser.dtype, self._implementation) return self._from_series(ser.astype(dtype)) + @not_implemented_in("dask") def item(self: Self, index: int | None = None) -> Any: - if self._implementation == "dask": - msg = ( - "Positional indexing is not available in Dask" - ) - raise NotImplementedError(msg) # cuDF doesn't have Series.item(). if index is None: if len(self) != 1: @@ -491,13 +486,8 @@ def to_pandas(self) -> Any: raise AssertionError(msg) # --- descriptive --- + @not_implemented_in("dask") def is_duplicated(self: Self) -> Self: - if self._implementation == "dask": - msg = ( - "Checking for duplication requires 'duplicated' method " - "which is not currently implemented in dask" - ) - raise NotImplementedError(msg) return self._from_series(self._series.duplicated(keep=False)) def is_empty(self: Self) -> bool: @@ -509,14 +499,12 @@ def is_unique(self: Self) -> Self: def null_count(self: Self) -> int: return self._series.isnull().sum() # type: ignore[no-any-return] + @not_implemented_in("dask") def is_first_distinct(self: Self) -> Self: - if self._implementation == "dask": - raise NotImplementedError("Not currently implemented in dask") return self._from_series(~self._series.duplicated(keep="first")) + @not_implemented_in("dask") def is_last_distinct(self: Self) -> Self: - if self._implementation == "dask": - raise NotImplementedError("Not currently implemented in dask") return self._from_series(~self._series.duplicated(keep="last")) def is_sorted(self: Self, *, descending: bool = False) -> bool: diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 265a19da1..ebe56dad7 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -2,8 +2,10 @@ import secrets from copy import copy +from functools import wraps from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import TypeVar @@ -317,7 +319,7 @@ def horizontal_concat(dfs: list[Any], implementation: str) -> Any: return mpd.concat(dfs, axis=1) if implementation == "dask": # pragma: no cover dd = get_dask() - if hasattr(dfs[0], "_series"): # TODO: sort out this hack + if hasattr(dfs[0], "_series"): return dd.concat([i._series for i in dfs], axis=1) return dd.concat(dfs, axis=1) msg = f"Unknown implementation: {implementation}" # pragma: no cover @@ -380,6 +382,33 @@ def native_series_from_iterable( if implementation == "arrow": pa = get_pyarrow() return pa.chunked_array([data]) + if implementation == "dask": # pragma: no cover + dd = get_dask() + pd = get_pandas() + if not hasattr(data[0], "compute"): + return ( + pd.Series( + data, + name=name, + index=index, + copy=False, + ) + .pipe(dd.from_pandas) + ) + # TODO: This is a current workaround, but needs more logic to avoid + # computing everything + breakpoint() + return ( + pd.Series( + [i.compute() for i in data], + name=name, + copy=False, + ) + .pipe(dd.from_pandas) + ) + + breakpoint() + raise NotImplementedError msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -661,3 +690,18 @@ def generate_unique_token(n_bytes: int, columns: list[str]) -> str: # pragma: n "join operation" ) raise AssertionError(msg) + + +def not_implemented_in(*implementations: list[str]) -> Callable: + """ + Produces method decorator to raise not implemented warnings for given implementations + """ + def check_implementation_wrapper(func: Callable) -> Callable: + """Wraps function to return same function + implementation check""" + @wraps(func) + def wrapped_func(self, *args, **kwargs): + if (implementation := self._implementation) in implementations: + raise NotImplementedError(f"Not implemented in {implementation}") + return func(self, *args, **kwargs) + return wrapped_func + return check_implementation_wrapper From 1d349ef450832bbf9e7abf3578889c925246fd3b Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Fri, 12 Jul 2024 16:15:38 +0100 Subject: [PATCH 05/19] fixes --- narwhals/_pandas_like/namespace.py | 2 +- narwhals/_pandas_like/utils.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 2273c3947..213d29037 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -71,7 +71,7 @@ def _create_series_from_scalar( self, value: Any, series: PandasSeries ) -> PandasSeries: - index = series._series.index[0:1] if self.implementation != "dask" else None + index = series._series.index[0:1] if self._implementation != "dask" else None return PandasSeries._from_iterable( [value], name=series._series.name, diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index ebe56dad7..36d415e5b 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -397,7 +397,6 @@ def native_series_from_iterable( ) # TODO: This is a current workaround, but needs more logic to avoid # computing everything - breakpoint() return ( pd.Series( [i.compute() for i in data], @@ -406,9 +405,6 @@ def native_series_from_iterable( ) .pipe(dd.from_pandas) ) - - breakpoint() - raise NotImplementedError msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -653,6 +649,8 @@ def to_datetime(implementation: str) -> Any: return get_modin().to_datetime if implementation == "cudf": return get_cudf().to_datetime + if implementation == "dask": + return get_dask().to_datetime raise AssertionError From b2ab1107eb727ea9691b096b0a53adca03f48a11 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Jul 2024 15:44:55 +0000 Subject: [PATCH 06/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_pandas_like/dataframe.py | 1 - narwhals/_pandas_like/namespace.py | 6 +++++- narwhals/dataframe.py | 5 ----- narwhals/series.py | 6 ------ tests/conftest.py | 7 ++++--- tests/frame/test_common.py | 4 +++- tests/series/test_common.py | 13 ++++++++----- tests/utils.py | 8 +++++++- 8 files changed, 27 insertions(+), 23 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 24fda3191..2f3f59754 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -508,7 +508,6 @@ def to_pandas(self) -> Any: return self._dataframe.compute() return self._native_dataframe.to_pandas() # pragma: no cover - def write_parquet(self, file: Any) -> Any: self._native_dataframe.to_parquet(file) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index c43688587..42252f0fc 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -78,7 +78,11 @@ def _create_expr_from_callable( def _create_series_from_scalar( self, value: Any, series: PandasLikeSeries ) -> PandasLikeSeries: - index = series._series.index[0:1] if self._implementation != Implementation.DASK else None + index = ( + series._series.index[0:1] + if self._implementation != Implementation.DASK + else None + ) return PandasLikeSeries._from_iterable( [value], name=series._native_series.name, diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 439493138..d5cec6a06 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -11,11 +11,6 @@ from typing import TypeVar from typing import overload -from narwhals._arrow.dataframe import ArrowDataFrame -from narwhals._pandas_like.dataframe import PandasDataFrame -from narwhals.dependencies import get_cudf -from narwhals.dependencies import get_dask -from narwhals.dependencies import get_modin from narwhals.dependencies import get_numpy from narwhals.dependencies import get_polars from narwhals.dtypes import to_narwhals_dtype diff --git a/narwhals/series.py b/narwhals/series.py index dc3e89168..c7e2c2dcc 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -4,12 +4,6 @@ from typing import Any from typing import Literal - -from narwhals._arrow.series import ArrowSeries -from narwhals.dependencies import get_cudf -from narwhals.dependencies import get_dask -from narwhals.dependencies import get_modin -from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dtypes import to_narwhals_dtype from narwhals.dtypes import translate_dtype diff --git a/tests/conftest.py b/tests/conftest.py index cbc3e01e7..e61f4dde9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,8 @@ import pyarrow as pa import pytest -from narwhals.dependencies import get_modin, get_dask +from narwhals.dependencies import get_dask +from narwhals.dependencies import get_modin from narwhals.typing import IntoDataFrame from narwhals.utils import parse_version @@ -47,6 +48,7 @@ def modin_constructor(obj: Any) -> IntoDataFrame: # pragma: no cover mpd = get_modin() return mpd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return] + def dask_contructor(obj: Any) -> IntoDataFrame: dd = get_dask() return dd.DataFrame(obj) # type: ignore[no-any-return] @@ -109,12 +111,11 @@ def dask_series_constructor(obj: Any) -> Any: # pragma: no cover dd = get_dask() return dd.Series(obj) + def polars_series_constructor(obj: Any) -> Any: return pl.Series(obj) - - if parse_version(pd.__version__) >= parse_version("2.0.0"): params_series = [ pandas_series_constructor, diff --git a/tests/frame/test_common.py b/tests/frame/test_common.py index 280f01f1d..4021eeefb 100644 --- a/tests/frame/test_common.py +++ b/tests/frame/test_common.py @@ -17,7 +17,8 @@ from narwhals.functions import show_versions from narwhals.utils import parse_version from tests.utils import compare_dicts -from tests.utils import maybe_get_modin_df, maybe_get_dask_df +from tests.utils import maybe_get_dask_df +from tests.utils import maybe_get_modin_df if TYPE_CHECKING: from narwhals.typing import IntoFrameT @@ -89,6 +90,7 @@ def test_std(df_raw: Any) -> None: } compare_dicts(result_native, expected) + @pytest.mark.parametrize( "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] ) diff --git a/tests/series/test_common.py b/tests/series/test_common.py index 5b3b953db..b482e20ad 100644 --- a/tests/series/test_common.py +++ b/tests/series/test_common.py @@ -11,8 +11,8 @@ from numpy.testing import assert_array_equal from pandas.testing import assert_series_equal -from narwhals.dependencies import get_dask import narwhals.stable.v1 as nw +from narwhals.dependencies import get_dask from narwhals.utils import parse_version from tests.utils import maybe_get_dask_df @@ -44,11 +44,15 @@ df_lazy = pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) df_dask = maybe_get_dask_df(df_pandas) + def compute_if_dask(result: Any) -> Any: - if hasattr(result, "_series") and hasattr(result._series, "_implementation") and result._series._implementation == "dask": - return result.to_pandas() + if ( + hasattr(result, "_series") + and hasattr(result._series, "_implementation") + and result._series._implementation == "dask" + ): + return result.to_pandas() return result - @pytest.mark.parametrize( @@ -64,7 +68,6 @@ def test_len(df_raw: Any) -> None: assert result == 3 - @pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) def test_is_in(df_raw: Any) -> None: result = nw.from_native(df_raw["a"], series_only=True).is_in([1, 2]) diff --git a/tests/utils.py b/tests/utils.py index 0d74937e0..2c7d49b2d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,7 +22,11 @@ def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]: def compare_dicts(result: Any, expected: dict[str, Any]) -> None: if hasattr(result, "collect"): result = result.collect() - if hasattr(result, "_dataframe") and hasattr(result._dataframe, "_implementation") and result._dataframe._implementation == "dask": + if ( + hasattr(result, "_dataframe") + and hasattr(result._dataframe, "_implementation") + and result._dataframe._implementation == "dask" + ): result = result.to_pandas() if hasattr(result, "columns"): for key in result.columns: @@ -52,6 +56,7 @@ def maybe_get_modin_df(df_pandas: pd.DataFrame) -> Any: warnings.filterwarnings("ignore", category=UserWarning) return mpd.DataFrame(df_pandas.to_dict(orient="list")) + def maybe_get_dask_df(df_pandas: pd.DataFrame) -> Any: """Convert a pandas DataFrame to a Dask Dataframe if Dask is availabile.""" try: @@ -62,6 +67,7 @@ def maybe_get_dask_df(df_pandas: pd.DataFrame) -> Any: else: return dd.from_pandas(df_pandas, npartitions=1) + def is_windows() -> bool: """Check if the current platform is Windows.""" return sys.platform in ["win32", "cygwin"] From f10c279a8d29ad5ad1f38ba8df094004c0797440 Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Fri, 12 Jul 2024 17:17:19 +0100 Subject: [PATCH 07/19] some fixes --- narwhals/_pandas_like/dataframe.py | 16 ++++++++-------- narwhals/_pandas_like/namespace.py | 5 +++-- narwhals/_pandas_like/series.py | 8 +++++--- narwhals/_pandas_like/utils.py | 9 ++++----- narwhals/translate.py | 17 ++++++++++++++++- tests/series/test_common.py | 9 +++++---- 6 files changed, 41 insertions(+), 23 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 2f3f59754..50ef7c729 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -314,11 +314,11 @@ def sort( # --- convert --- def collect(self) -> PandasLikeDataFrame: - if self._implementation == "dask": - return_df = self._dataframe.compute() - return_implementation = "pandas" + if self._implementation is Implementation.DASK: + return_df = self._native_dataframe.compute() + return_implementation = Implementation.PANDAS else: - return_df = self._dataframe + return_df = self._native_dataframe return_implementation = self._implementation return PandasLikeDataFrame( return_df, @@ -495,9 +495,9 @@ def to_numpy(self) -> Any: import numpy as np return np.hstack([self[col].to_numpy()[:, None] for col in self.columns]) - if self._implementation == Implementation.DASK: - return self._dataframe.compute().to_numpy() - return self._dataframe.to_numpy() + if self._implementation is Implementation.DASK: + return self._native_dataframe.compute().to_numpy() + return self._native_dataframe.to_numpy() def to_pandas(self) -> Any: if self._implementation is Implementation.PANDAS: @@ -505,7 +505,7 @@ def to_pandas(self) -> Any: if self._implementation is Implementation.MODIN: # pragma: no cover return self._native_dataframe._to_pandas() if self._implementation is Implementation.DASK: # pragma: no cover - return self._dataframe.compute() + return self._native_dataframe.compute() return self._native_dataframe.to_pandas() # pragma: no cover def write_parquet(self, file: Any) -> Any: diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 42252f0fc..f065557df 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -14,6 +14,7 @@ from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.utils import create_native_series from narwhals._pandas_like.utils import horizontal_concat +from narwhals._pandas_like.utils import Implementation from narwhals._pandas_like.utils import vertical_concat from narwhals.utils import flatten @@ -79,8 +80,8 @@ def _create_series_from_scalar( self, value: Any, series: PandasLikeSeries ) -> PandasLikeSeries: index = ( - series._series.index[0:1] - if self._implementation != Implementation.DASK + series._native_series.index[0:1] + if self._implementation is not Implementation.DASK else None ) return PandasLikeSeries._from_iterable( diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 42cb36538..92739cd7e 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -117,6 +117,7 @@ def __native_namespace__(self) -> Any: def __narwhals_series__(self) -> Self: return self + @not_implemented_in(Implementation.DASK) def __getitem__(self, idx: int | slice | Sequence[int]) -> Any: if isinstance(idx, int): return self._native_series.iloc[idx] @@ -156,7 +157,7 @@ def _from_iterable( ) def __len__(self) -> int: - return len(self._series) + return len(self._native_series) @property def name(self) -> str: @@ -509,7 +510,8 @@ def to_pandas(self) -> Any: return self._native_series.to_pandas() elif self._implementation is Implementation.MODIN: # pragma: no cover return self._native_series._to_pandas() - elif self._implementation is Iplmementation.DASK: # pragma: no cover + elif self._implementation is Implementation.DASK: # pragma: no cover + return self._native_series.compute() msg = f"Unknown implementation: {self._implementation}" # pragma: no cover raise AssertionError(msg) @@ -568,7 +570,7 @@ def quantile( quantile: float, interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], ) -> Any: - if self._implementation == "dask": + if self._implementation is Implementation.DASK: if interpolation == "linear": return self._native_series.quantile(q=quantile) message = ( diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 63340ff56..9b6c02f84 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -83,8 +83,8 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: if isinstance(other, PandasLikeSeries): if other.len() == 1: # broadcast - return other._series.iloc[0] - if other._native_series.index is not index and other._implementation != Implementation.DASK: + return other._native_series.iloc[0] + if other._native_series.index is not index and other._implementation is not Implementation.DASK: return set_axis( other._native_series, index, @@ -92,7 +92,6 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: backend_version=other._backend_version, ) return other._native_series - msg = "Please report a bug" # pragma: no cover return other._series raise AssertionError("Please report a bug") @@ -205,7 +204,7 @@ def horizontal_concat( mpd = get_modin() return mpd.concat(dfs, axis=1) - if implementation == "dask": # pragma: no cover + if implementation is Implementation.DASK: # pragma: no cover dd = get_dask() if hasattr(dfs[0], "_series"): return dd.concat([i._series for i in dfs], axis=1) @@ -597,7 +596,7 @@ def generate_unique_token(n_bytes: int, columns: list[str]) -> str: # pragma: n raise AssertionError(msg) -def not_implemented_in(*implementations: list[str]) -> Callable: +def not_implemented_in(*implementations: list[Implementation]) -> Callable: """ Produces method decorator to raise not implemented warnings for given implementations """ diff --git a/narwhals/translate.py b/narwhals/translate.py index 913f8b643..078c842ff 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -344,11 +344,12 @@ def from_native( # noqa: PLR0915 if series_only: # pragma: no cover (todo) msg = "Cannot only use `series_only` with dask.dataframe.DataFrame" raise TypeError(msg) + import dask return DataFrame( PandasLikeDataFrame( native_dataframe, implementation=Implementation.DASK, - backend_version=parse_version(dd.__version__), + backend_version=parse_version(dask.__version__), ), is_polars=False, backend_version=parse_version(pa.__version__), @@ -441,6 +442,20 @@ def from_native( # noqa: PLR0915 is_polars=False, backend_version=parse_version(pa.__version__), ) + elif (dd := get_dask()) is not None and isinstance(native_dataframe, dd.Series): + if not allow_series: # pragma: no cover (todo) + msg = "Please set `allow_series=True`" + raise TypeError(msg) + import dask + return Series( + PandasLikeSeries( + native_dataframe, + implementation=Implementation.DASK, + backend_version=parse_version(dask.__version__), + ), + is_polars=False, + backend_version=parse_version(dask.__version__) + ) elif hasattr(native_dataframe, "__narwhals_series__"): # pragma: no cover if not allow_series: # pragma: no cover (todo) msg = "Please set `allow_series=True`" diff --git a/tests/series/test_common.py b/tests/series/test_common.py index b482e20ad..6c05f98ba 100644 --- a/tests/series/test_common.py +++ b/tests/series/test_common.py @@ -14,6 +14,7 @@ import narwhals.stable.v1 as nw from narwhals.dependencies import get_dask from narwhals.utils import parse_version +from narwhals._pandas_like.utils import Implementation from tests.utils import maybe_get_dask_df df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) @@ -47,9 +48,9 @@ def compute_if_dask(result: Any) -> Any: if ( - hasattr(result, "_series") - and hasattr(result._series, "_implementation") - and result._series._implementation == "dask" + hasattr(result, "_native_series") + and hasattr(result._native_series, "_implementation") + and result._series._implementation == Implementation.DASK ): return result.to_pandas() return result @@ -68,7 +69,7 @@ def test_len(df_raw: Any) -> None: assert result == 3 -@pytest.mark.parametrize("df_raw", [df_pandas, df_polars, df_dask]) +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) def test_is_in(df_raw: Any) -> None: result = nw.from_native(df_raw["a"], series_only=True).is_in([1, 2]) result = compute_if_dask(result) From dd15289c9876f94d1584736ce2f373411201292f Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Mon, 15 Jul 2024 15:00:11 +0100 Subject: [PATCH 08/19] compute avoidance plus conflict-resolve-issue fix --- narwhals/_pandas_like/group_by.py | 8 +++++--- narwhals/_pandas_like/utils.py | 17 ++++------------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 8c428e5c2..ac4d7e9f8 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -59,7 +59,11 @@ def agg( raise ValueError(msg) output_names.extend(expr._output_names) - dataframe_is_empty=self._df._dataframe.empty if self._df._implementation != Implementation.DASK else len(self._df) == 0 + dataframe_is_empty= ( + self._df._native_dataframe.empty + if self._df._implementation != Implementation.DASK + else len(self._df._native_dataframe) == 0 + ) return agg_pandas( self._grouped, exprs, @@ -67,8 +71,6 @@ def agg( output_names, self._from_native_dataframe, dataframe_is_empty=dataframe_is_empty, - self._from_native_dataframe, - dataframe_is_empty=dataframe_is_empty, implementation=implementation, backend_version=self._df._backend_version, ) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 9b6c02f84..4fa2db8d3 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -277,22 +277,13 @@ def native_series_from_iterable( if implementation is Implementation.ARROW: # pragma: no cover dd = get_dask() pd = get_pandas() - if not hasattr(data[0], "compute"): - return ( - pd.Series( - data, - name=name, - index=index, - copy=False, - ) - .pipe(dd.from_pandas) - ) - # TODO: This is a current workaround, but needs more logic to avoid - # computing everything + if hasattr(data[0], "compute"): + return dd.concat([i.to_series() for i in data]) return ( pd.Series( - [i.compute() for i in data], + data, name=name, + index=index, copy=False, ) .pipe(dd.from_pandas) From 7dac812ffc7cd62e320c640c5b57402395c7aa11 Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Mon, 15 Jul 2024 16:45:28 +0100 Subject: [PATCH 09/19] test logic updates post merge fix --- tests/series/test_common.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/series/test_common.py b/tests/series/test_common.py index 4bb41a1d9..ab98df067 100644 --- a/tests/series/test_common.py +++ b/tests/series/test_common.py @@ -15,6 +15,7 @@ from narwhals.utils import parse_version from narwhals._pandas_like.utils import Implementation from tests.utils import maybe_get_dask_df +from tests.conftest import dask_series_constructor data = [1, 3, 2] @@ -22,6 +23,17 @@ data_sorted = [7.0, 8, 9] +def compute_if_dask(result: Any) -> Any: + if ( + hasattr(result, "_native_series") + and hasattr(result._native_series, "_implementation") + and result._series._implementation is Implementation.DASK + ): + + return result.to_pandas() + return result + + def test_len(constructor_series: Any) -> None: series = nw.from_native(constructor_series(data), series_only=True) @@ -240,8 +252,11 @@ def test_quantile( request.applymarker(pytest.mark.xfail) q = 0.3 + if (is_dask_test := constructor_series == dask_series_constructor): + interpolation = "linear" # other interpolation unsupported in dask series = nw.from_native(constructor_series(data_sorted), allow_series=True) + result = series.quantile(quantile=q, interpolation=interpolation) # type: ignore[union-attr] if is_dask_test: result = result.compute() From c727d1c5bcb57252eac81627647dd10cc6d07de2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 15:49:21 +0000 Subject: [PATCH 10/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_pandas_like/group_by.py | 2 +- narwhals/_pandas_like/namespace.py | 2 +- narwhals/_pandas_like/series.py | 1 - narwhals/_pandas_like/utils.py | 33 ++++++++++++++++++------------ narwhals/translate.py | 4 +++- tests/frame/test_common.py | 1 - tests/series/test_common.py | 11 ++-------- 7 files changed, 27 insertions(+), 27 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 9244bac50..2c50376d2 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -59,7 +59,7 @@ def agg( raise ValueError(msg) output_names.extend(expr._output_names) - dataframe_is_empty= ( + dataframe_is_empty = ( self._df._native_dataframe.empty if self._df._implementation != Implementation.DASK else len(self._df._native_dataframe) == 0 diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 3090ee7c5..c8c73408d 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -12,9 +12,9 @@ from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.selectors import PandasSelectorNamespace from narwhals._pandas_like.series import PandasLikeSeries +from narwhals._pandas_like.utils import Implementation from narwhals._pandas_like.utils import create_native_series from narwhals._pandas_like.utils import horizontal_concat -from narwhals._pandas_like.utils import Implementation from narwhals._pandas_like.utils import vertical_concat from narwhals.utils import flatten diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index e2b643d2e..857e03915 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -575,7 +575,6 @@ def quantile( return self._native_series.quantile(q=quantile) message = ( "Dask performs approximate quantile calculations " - "and does not support specific interpolations methods. " "Interpolation keywords other than 'linear' are not supported" ) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 81122b22f..4d6c664a6 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -1,10 +1,9 @@ from __future__ import annotations import secrets -from copy import copy -from functools import wraps from enum import Enum from enum import auto +from functools import wraps from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -58,7 +57,10 @@ def validate_column_comparand(index: Any, other: Any) -> Any: if other.len() == 1: # broadcast return other.item() - if other._native_series.index is not index and other._implementation != Implementation.DASK: + if ( + other._native_series.index is not index + and other._implementation != Implementation.DASK + ): return set_axis( other._native_series, index, @@ -84,7 +86,10 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: if other.len() == 1: # broadcast return other._native_series.iloc[0] - if other._native_series.index is not index and other._implementation is not Implementation.DASK: + if ( + other._native_series.index is not index + and other._implementation is not Implementation.DASK + ): return set_axis( other._native_series, index, @@ -104,6 +109,7 @@ def maybe_evaluate_expr(df: PandasDataFrame, expr: Any) -> Any: return expr._call(df) return expr + def parse_into_expr( implementation: str, into_expr: IntoPandasExpr | IntoArrowExpr ) -> PandasExpr: @@ -262,15 +268,12 @@ def native_series_from_iterable( pd = get_pandas() if hasattr(data[0], "compute"): return dd.concat([i.to_series() for i in data]) - return ( - pd.Series( - data, - name=name, - index=index, - copy=False, - ) - .pipe(dd.from_pandas) - ) + return pd.Series( + data, + name=name, + index=index, + copy=False, + ).pipe(dd.from_pandas) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover @@ -574,12 +577,16 @@ def not_implemented_in(*implementations: list[Implementation]) -> Callable: """ Produces method decorator to raise not implemented warnings for given implementations """ + def check_implementation_wrapper(func: Callable) -> Callable: """Wraps function to return same function + implementation check""" + @wraps(func) def wrapped_func(self, *args, **kwargs): if (implementation := self._implementation) in implementations: raise NotImplementedError(f"Not implemented in {implementation}") return func(self, *args, **kwargs) + return wrapped_func + return check_implementation_wrapper diff --git a/narwhals/translate.py b/narwhals/translate.py index 34133600f..bc8c3564e 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -396,6 +396,7 @@ def from_native( # noqa: PLR0915 msg = "Cannot only use `series_only` with dask.dataframe.DataFrame" raise TypeError(msg) import dask + return DataFrame( PandasLikeDataFrame( native_dataframe, @@ -515,6 +516,7 @@ def from_native( # noqa: PLR0915 msg = "Please set `allow_series=True`" raise TypeError(msg) import dask + return Series( PandasLikeSeries( native_dataframe, @@ -522,7 +524,7 @@ def from_native( # noqa: PLR0915 backend_version=parse_version(dask.__version__), ), is_polars=False, - backend_version=parse_version(dask.__version__) + backend_version=parse_version(dask.__version__), ) elif hasattr(native_object, "__narwhals_series__"): if not allow_series: diff --git a/tests/frame/test_common.py b/tests/frame/test_common.py index 03d448d97..01eb34954 100644 --- a/tests/frame/test_common.py +++ b/tests/frame/test_common.py @@ -129,7 +129,6 @@ def test_expr_binary(request: Any, constructor: Any) -> None: compare_dicts(result_native, expected) - def test_expr_transform(request: Any, constructor: Any) -> None: if "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/series/test_common.py b/tests/series/test_common.py index ab98df067..a1a480ad0 100644 --- a/tests/series/test_common.py +++ b/tests/series/test_common.py @@ -11,13 +11,11 @@ from pandas.testing import assert_series_equal import narwhals.stable.v1 as nw +from narwhals._pandas_like.utils import Implementation from narwhals.dependencies import get_dask from narwhals.utils import parse_version -from narwhals._pandas_like.utils import Implementation -from tests.utils import maybe_get_dask_df from tests.conftest import dask_series_constructor - data = [1, 3, 2] data_dups = [4, 4, 6] data_sorted = [7.0, 8, 9] @@ -29,7 +27,6 @@ def compute_if_dask(result: Any) -> Any: and hasattr(result._native_series, "_implementation") and result._series._implementation is Implementation.DASK ): - return result.to_pandas() return result @@ -44,7 +41,6 @@ def test_len(constructor_series: Any) -> None: assert result == 3 - def test_is_in(request: Any, constructor_series: Any) -> None: if "pyarrow_series" in str(constructor_series): request.applymarker(pytest.mark.xfail) @@ -56,7 +52,6 @@ def test_is_in(request: Any, constructor_series: Any) -> None: assert result[2] - def test_is_in_other(constructor: Any) -> None: df_raw = constructor({"a": data}) with pytest.raises( @@ -68,7 +63,6 @@ def test_is_in_other(constructor: Any) -> None: nw.from_native(df_raw).with_columns(contains=nw.col("a").is_in("sets")) - def test_dtype(constructor_series: Any) -> None: series = nw.from_native(constructor_series(data), series_only=True) result = series.dtype @@ -176,7 +170,6 @@ def test_is_last_distinct(request: Any, constructor_series: Any) -> None: assert (result.to_numpy() == expected).all() - def test_value_counts(request: Any, constructor_series: Any) -> None: if "pyarrow_series" in str(constructor_series): request.applymarker(pytest.mark.xfail) @@ -252,7 +245,7 @@ def test_quantile( request.applymarker(pytest.mark.xfail) q = 0.3 - if (is_dask_test := constructor_series == dask_series_constructor): + if is_dask_test := constructor_series == dask_series_constructor: interpolation = "linear" # other interpolation unsupported in dask series = nw.from_native(constructor_series(data_sorted), allow_series=True) From 94b9dc7a0dca9b84dbd7710cdd05807b0cd488b8 Mon Sep 17 00:00:00 2001 From: benrutter Date: Mon, 15 Jul 2024 19:27:43 +0100 Subject: [PATCH 11/19] docstring update for stablev1 --- narwhals/stable/v1.py | 1 + 1 file changed, 1 insertion(+) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 45468d782..50b767d9c 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -685,6 +685,7 @@ def from_native( - pandas.DataFrame - polars.DataFrame - polars.LazyFrame + - dask.dataframe.DataFrame - anything with a `__narwhals_dataframe__` or `__narwhals_lazyframe__` method - pandas.Series - polars.Series From 8eff3a43844f0891bba74bf2626198cb9249f4df Mon Sep 17 00:00:00 2001 From: benrutter Date: Mon, 15 Jul 2024 20:24:52 +0100 Subject: [PATCH 12/19] type fixes --- narwhals/_pandas_like/dataframe.py | 4 +-- narwhals/_pandas_like/group_by.py | 2 +- narwhals/_pandas_like/series.py | 10 +++---- narwhals/_pandas_like/utils.py | 44 +++++++++++++++++------------- narwhals/translate.py | 10 ++++--- tests/series/test_common.py | 5 ++-- 6 files changed, 41 insertions(+), 34 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 89ddfc803..a9b63b5a6 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -67,9 +67,9 @@ def __native_namespace__(self) -> Any: return get_modin() if self._implementation is Implementation.CUDF: # pragma: no cover return get_cudf() - if self._implementation == "dask": # pragma: no cover + if self._implementation is Implementation.DASK: # pragma: no cover return get_dask() - msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover + msg = f"Expected pandas/modin/cudf/dask, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) def __len__(self) -> int: diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 2c50376d2..6cd6a8e55 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -29,7 +29,7 @@ def __init__(self, df: PandasLikeDataFrame, keys: list[str]) -> None: self._df = df self._keys = list(keys) keywords = {} - if df._implementation != "dask": + if df._implementation is not Implementation.DASK: keywords |= {"as_index": True} self._grouped = self._df._native_dataframe.groupby( list(self._keys), diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 857e03915..151c11cba 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -109,7 +109,7 @@ def __native_namespace__(self) -> Any: return get_modin() if self._implementation is Implementation.CUDF: # pragma: no cover return get_cudf() - if self._implementation == "dask": # pragma: no cover + if self._implementation is Implementation.DASK: # pragma: no cover return get_dask() msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) @@ -179,7 +179,7 @@ def cast( dtype = reverse_translate_dtype(dtype, ser.dtype, self._implementation) return self._from_native_series(ser.astype(dtype)) - @not_implemented_in("dask") + @not_implemented_in(Implementation.DASK) def item(self: Self, index: int | None = None) -> Any: # cuDF doesn't have Series.item(). if index is None: @@ -516,7 +516,7 @@ def to_pandas(self) -> Any: raise AssertionError(msg) # --- descriptive --- - @not_implemented_in("dask") + @not_implemented_in(Implementation.DASK) def is_duplicated(self: Self) -> Self: return self._from_native_series(self._native_series.duplicated(keep=False)) @@ -529,11 +529,11 @@ def is_unique(self: Self) -> Self: def null_count(self: Self) -> int: return self._native_series.isna().sum() # type: ignore[no-any-return] - @not_implemented_in("dask") + @not_implemented_in(Implementation.DASK) def is_first_distinct(self: Self) -> Self: return self._from_native_series(~self._native_series.duplicated(keep="first")) - @not_implemented_in("dask") + @not_implemented_in(Implementation.DASK) def is_last_distinct(self: Self) -> Self: return self._from_native_series(~self._native_series.duplicated(keep="last")) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 4d6c664a6..4e429ea21 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -13,14 +13,19 @@ from narwhals.dependencies import get_cudf from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin +from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas +from narwhals.dependencies import get_pyarrow from narwhals.utils import isinstance_or_issubclass T = TypeVar("T") if TYPE_CHECKING: + from narwhals._arrow.typing import IntoArrowExpr + from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries + from narwhals._pandas_like.typing import IntoPandasLikeExpr from narwhals.dtypes import DType ExprT = TypeVar("ExprT", bound=PandasLikeExpr) @@ -101,18 +106,17 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: raise AssertionError("Please report a bug") -def maybe_evaluate_expr(df: PandasDataFrame, expr: Any) -> Any: +def maybe_evaluate_expr(df: PandasLikeDataFrame, expr: Any) -> Any: """Evaluate `expr` if it's an expression, otherwise return it as is.""" - from narwhals._pandas_like.expr import PandasExpr - if isinstance(expr, PandasExpr): + if isinstance(expr, PandasLikeExpr): return expr._call(df) return expr def parse_into_expr( - implementation: str, into_expr: IntoPandasExpr | IntoArrowExpr -) -> PandasExpr: + implementation: str, into_expr: IntoPandasLikeExpr | IntoArrowExpr +) -> PandasLikeExpr: """Parse `into_expr` as an expression. For example, in Polars, we can do both `df.select('a')` and `df.select(pl.col('a'))`. @@ -127,17 +131,17 @@ def parse_into_expr( from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.series import ArrowSeries - from narwhals._pandas_like.expr import PandasExpr - from narwhals._pandas_like.namespace import PandasNamespace - from narwhals._pandas_like.series import PandasSeries + from narwhals._pandas_like.expr import PandasLikeExpr + from narwhals._pandas_like.namespace import PandasLikeNamespace + from narwhals._pandas_like.series import PandasLikeSeries if implementation == "arrow": - plx: ArrowNamespace | PandasNamespace = ArrowNamespace() + plx: ArrowNamespace | PandasLikeNamespace = ArrowNamespace() else: - plx = PandasNamespace(implementation=implementation) - if isinstance(into_expr, (PandasExpr, ArrowExpr)): + PandasLikeNamespace(implementation=implementation) + if isinstance(into_expr, (PandasLikeExpr, ArrowExpr)): return into_expr # type: ignore[return-value] - if isinstance(into_expr, (PandasSeries, ArrowSeries)): + if isinstance(into_expr, (PandasLikeSeries, ArrowSeries)): return plx._create_expr_from_series(into_expr) # type: ignore[arg-type, return-value] if isinstance(into_expr, str): return plx.col(into_expr) # type: ignore[return-value] @@ -233,7 +237,7 @@ def vertical_concat( mpd = get_modin() return mpd.concat(dfs, axis=0) - if implementation == "dask": # pragma: no cover + if implementation is Implementation.DASK: # pragma: no cover dd = get_dask() return dd.concat(dfs, axis=0) @@ -263,10 +267,10 @@ def native_series_from_iterable( if implementation == "arrow": pa = get_pyarrow() return pa.chunked_array([data]) - if implementation is Implementation.ARROW: # pragma: no cover + if implementation is Implementation.DASK: # pragma: no cover dd = get_dask() pd = get_pandas() - if hasattr(data[0], "compute"): + if hasattr(next(data), "compute"): return dd.concat([i.to_series() for i in data]) return pd.Series( data, @@ -298,7 +302,7 @@ def set_axis( kwargs["copy"] = False else: # pragma: no cover pass - if implementation == "dask": + if implementation is Implementation.DASK: msg = "Setting axis on columns is not currently supported for dask" raise NotImplementedError(msg) return obj.set_axis(index, axis=0, **kwargs) # type: ignore[no-any-return, attr-defined] @@ -532,7 +536,7 @@ def to_datetime(implementation: Implementation) -> Any: return get_modin().to_datetime if implementation is Implementation.CUDF: return get_cudf().to_datetime - if implementation == "dask": + if implementation is Implementation.DASK: return get_dask().to_datetime raise AssertionError @@ -573,12 +577,14 @@ def generate_unique_token(n_bytes: int, columns: list[str]) -> str: # pragma: n raise AssertionError(msg) -def not_implemented_in(*implementations: list[Implementation]) -> Callable: +def not_implemented_in( + *implementations: Implementation, +) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]: """ Produces method decorator to raise not implemented warnings for given implementations """ - def check_implementation_wrapper(func: Callable) -> Callable: + def check_implementation_wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]: """Wraps function to return same function + implementation check""" @wraps(func) diff --git a/narwhals/translate.py b/narwhals/translate.py index bc8c3564e..f8a43e595 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -391,7 +391,7 @@ def from_native( # noqa: PLR0915 backend_version=parse_version(pa.__version__), level="full", ) - elif (dd := get_dask()) is not None and isinstance(native_dataframe, dd.DataFrame): + elif (dd := get_dask()) is not None and isinstance(native_object, dd.DataFrame): if series_only: # pragma: no cover (todo) msg = "Cannot only use `series_only` with dask.dataframe.DataFrame" raise TypeError(msg) @@ -399,12 +399,13 @@ def from_native( # noqa: PLR0915 return DataFrame( PandasLikeDataFrame( - native_dataframe, + native_object, implementation=Implementation.DASK, backend_version=parse_version(dask.__version__), ), is_polars=False, backend_version=parse_version(pa.__version__), + level="full", ) elif hasattr(native_object, "__dataframe__"): if eager_only or series_only: @@ -511,7 +512,7 @@ def from_native( # noqa: PLR0915 backend_version=parse_version(pa.__version__), level="full", ) - elif (dd := get_dask()) is not None and isinstance(native_dataframe, dd.Series): + elif (dd := get_dask()) is not None and isinstance(native_object, dd.Series): if not allow_series: # pragma: no cover (todo) msg = "Please set `allow_series=True`" raise TypeError(msg) @@ -519,12 +520,13 @@ def from_native( # noqa: PLR0915 return Series( PandasLikeSeries( - native_dataframe, + native_object, implementation=Implementation.DASK, backend_version=parse_version(dask.__version__), ), is_polars=False, backend_version=parse_version(dask.__version__), + level="full", ) elif hasattr(native_object, "__narwhals_series__"): if not allow_series: diff --git a/tests/series/test_common.py b/tests/series/test_common.py index a1a480ad0..bd326d56f 100644 --- a/tests/series/test_common.py +++ b/tests/series/test_common.py @@ -12,7 +12,6 @@ import narwhals.stable.v1 as nw from narwhals._pandas_like.utils import Implementation -from narwhals.dependencies import get_dask from narwhals.utils import parse_version from tests.conftest import dask_series_constructor @@ -209,8 +208,8 @@ def test_is_sorted( series = nw.from_native(constructor_series(input_data), series_only=True) result = series.is_sorted(descending=descending) - if (dd := get_dask()) is not None and isinstance(df_raw, dd.DataFrame): - result = result.compute() + if constructor_series == dask_series_constructor: + result = result.compute() # type: ignore assert result == expected From 30450ca953abd3de4e03185256f3fb8965e15ffd Mon Sep 17 00:00:00 2001 From: benrutter Date: Mon, 15 Jul 2024 21:47:10 +0100 Subject: [PATCH 13/19] more type fixes --- narwhals/_pandas_like/group_by.py | 2 +- narwhals/_pandas_like/series.py | 2 +- narwhals/_pandas_like/utils.py | 10 ++++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 6cd6a8e55..10873ce77 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -28,7 +28,7 @@ class PandasLikeGroupBy: def __init__(self, df: PandasLikeDataFrame, keys: list[str]) -> None: self._df = df self._keys = list(keys) - keywords = {} + keywords: dict[str, bool] = {} if df._implementation is not Implementation.DASK: keywords |= {"as_index": True} self._grouped = self._df._native_dataframe.groupby( diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 151c11cba..c6164dd50 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -111,7 +111,7 @@ def __native_namespace__(self) -> Any: return get_cudf() if self._implementation is Implementation.DASK: # pragma: no cover return get_dask() - msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover + msg = f"Expected pandas/modin/cudf/dask, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) def __narwhals_series__(self) -> Self: diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 4e429ea21..2012e5dce 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -9,6 +9,7 @@ from typing import Callable from typing import Iterable from typing import TypeVar +from typing import Self from narwhals.dependencies import get_cudf from narwhals.dependencies import get_dask @@ -270,7 +271,7 @@ def native_series_from_iterable( if implementation is Implementation.DASK: # pragma: no cover dd = get_dask() pd = get_pandas() - if hasattr(next(data), "compute"): + if hasattr(data[0], "compute"): # type: ignore return dd.concat([i.to_series() for i in data]) return pd.Series( data, @@ -579,16 +580,17 @@ def generate_unique_token(n_bytes: int, columns: list[str]) -> str: # pragma: n def not_implemented_in( *implementations: Implementation, -) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]: +) -> Callable[[Callable], Callable]: # type: ignore """ Produces method decorator to raise not implemented warnings for given implementations """ - def check_implementation_wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]: + def check_implementation_wrapper(func: Callable) -> Callable: # type: ignore """Wraps function to return same function + implementation check""" - @wraps(func) + @wraps(func) # type: ignore def wrapped_func(self, *args, **kwargs): + """Checks implementation then carries out wrapped call""" if (implementation := self._implementation) in implementations: raise NotImplementedError(f"Not implemented in {implementation}") return func(self, *args, **kwargs) From fcd3253dd4c19e2d484579b4d59b620998de1ac7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 20:47:30 +0000 Subject: [PATCH 14/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_pandas_like/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 2012e5dce..906d937f1 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -9,7 +9,6 @@ from typing import Callable from typing import Iterable from typing import TypeVar -from typing import Self from narwhals.dependencies import get_cudf from narwhals.dependencies import get_dask From 87bea6d369f3e8fd97411b287c872ba4d646c854 Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Tue, 16 Jul 2024 11:12:35 +0100 Subject: [PATCH 15/19] type fixes and linting --- narwhals/_pandas_like/namespace.py | 1 - narwhals/_pandas_like/utils.py | 67 ++++-------------------------- narwhals/dataframe.py | 3 ++ narwhals/dependencies.py | 2 +- narwhals/stable/v1.py | 11 +++++ narwhals/translate.py | 6 +++ narwhals/utils.py | 4 +- tests/series/test_common.py | 2 +- tests/utils.py | 2 +- 9 files changed, 31 insertions(+), 67 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index c8c73408d..6a250fdc3 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from narwhals._pandas_like.typing import IntoPandasLikeExpr - from narwhals._pandas_like.utils import Implementation class PandasLikeNamespace: diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 906d937f1..6e3481800 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -13,19 +13,14 @@ from narwhals.dependencies import get_cudf from narwhals.dependencies import get_dask from narwhals.dependencies import get_modin -from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_pyarrow from narwhals.utils import isinstance_or_issubclass T = TypeVar("T") if TYPE_CHECKING: - from narwhals._arrow.typing import IntoArrowExpr - from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries - from narwhals._pandas_like.typing import IntoPandasLikeExpr from narwhals.dtypes import DType ExprT = TypeVar("ExprT", bound=PandasLikeExpr) @@ -106,52 +101,6 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: raise AssertionError("Please report a bug") -def maybe_evaluate_expr(df: PandasLikeDataFrame, expr: Any) -> Any: - """Evaluate `expr` if it's an expression, otherwise return it as is.""" - - if isinstance(expr, PandasLikeExpr): - return expr._call(df) - return expr - - -def parse_into_expr( - implementation: str, into_expr: IntoPandasLikeExpr | IntoArrowExpr -) -> PandasLikeExpr: - """Parse `into_expr` as an expression. - - For example, in Polars, we can do both `df.select('a')` and `df.select(pl.col('a'))`. - We do the same in Narwhals: - - - if `into_expr` is already an expression, just return it - - if it's a Series, then convert it to an expression - - if it's a numpy array, then convert it to a Series and then to an expression - - if it's a string, then convert it to an expression - - else, raise - """ - from narwhals._arrow.expr import ArrowExpr - from narwhals._arrow.namespace import ArrowNamespace - from narwhals._arrow.series import ArrowSeries - from narwhals._pandas_like.expr import PandasLikeExpr - from narwhals._pandas_like.namespace import PandasLikeNamespace - from narwhals._pandas_like.series import PandasLikeSeries - - if implementation == "arrow": - plx: ArrowNamespace | PandasLikeNamespace = ArrowNamespace() - else: - PandasLikeNamespace(implementation=implementation) - if isinstance(into_expr, (PandasLikeExpr, ArrowExpr)): - return into_expr # type: ignore[return-value] - if isinstance(into_expr, (PandasLikeSeries, ArrowSeries)): - return plx._create_expr_from_series(into_expr) # type: ignore[arg-type, return-value] - if isinstance(into_expr, str): - return plx.col(into_expr) # type: ignore[return-value] - if (np := get_numpy()) is not None and isinstance(into_expr, np.ndarray): - series = create_native_series(into_expr, implementation=implementation) - return plx._create_expr_from_series(series) # type: ignore[arg-type, return-value] - msg = f"Expected IntoExpr, got {type(into_expr)}" # pragma: no cover - raise AssertionError(msg) - - def create_native_series( iterable: Any, index: Any = None, @@ -264,13 +213,10 @@ def native_series_from_iterable( mpd = get_modin() return mpd.Series(data, name=name, index=index) - if implementation == "arrow": - pa = get_pyarrow() - return pa.chunked_array([data]) if implementation is Implementation.DASK: # pragma: no cover dd = get_dask() pd = get_pandas() - if hasattr(data[0], "compute"): # type: ignore + if hasattr(data[0], "compute"): # type: ignore[index] return dd.concat([i.to_series() for i in data]) return pd.Series( data, @@ -579,19 +525,20 @@ def generate_unique_token(n_bytes: int, columns: list[str]) -> str: # pragma: n def not_implemented_in( *implementations: Implementation, -) -> Callable[[Callable], Callable]: # type: ignore +) -> Callable[[Callable], Callable]: # type: ignore[type-arg] """ Produces method decorator to raise not implemented warnings for given implementations """ - def check_implementation_wrapper(func: Callable) -> Callable: # type: ignore + def check_implementation_wrapper(func: Callable) -> Callable: # type: ignore[type-arg] """Wraps function to return same function + implementation check""" - @wraps(func) # type: ignore - def wrapped_func(self, *args, **kwargs): + @wraps(func) + def wrapped_func(self, *args, **kwargs): # type: ignore[no-untyped-def] # noqa: ANN001, ANN002, ANN003, ANN202 """Checks implementation then carries out wrapped call""" if (implementation := self._implementation) in implementations: - raise NotImplementedError(f"Not implemented in {implementation}") + msg = f"Not implemented in {implementation}" + raise NotImplementedError(msg) return func(self, *args, **kwargs) return wrapped_func diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 85130affc..43aab0a77 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -559,10 +559,13 @@ def __getitem__( @overload def to_dict(self, *, as_series: Literal[True] = ...) -> dict[str, Series]: ... + @overload def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: ... + def to_dict( self, *, as_series: bool = True ) -> dict[str, Series] | dict[str, list[Any]]: diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 22e2c9fda..bf3a11b03 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -33,7 +33,7 @@ def get_modin() -> Any: # pragma: no cover def get_dask() -> Any: - """Get dask.dataframe module (if already improted - else return None).""" + """Get dask.dataframe module (if already imported - else return None).""" return sys.modules.get("dask.dataframe", None) diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 50b767d9c..324109271 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -119,10 +119,13 @@ def lazy(self) -> LazyFrame[Any]: # thing that I need to understand category theory for @overload # type: ignore[override] def to_dict(self, *, as_series: Literal[True] = ...) -> dict[str, Series]: ... + @overload def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: ... + def to_dict( self, *, as_series: bool = True ) -> dict[str, Series] | dict[str, list[Any]]: @@ -450,12 +453,20 @@ class Schema(NwSchema): @overload def _stableify(obj: NwDataFrame[IntoFrameT]) -> DataFrame[IntoFrameT]: ... + + @overload def _stableify(obj: NwLazyFrame[IntoFrameT]) -> LazyFrame[IntoFrameT]: ... + + @overload def _stableify(obj: NwSeries) -> Series: ... + + @overload def _stableify(obj: NwExpr) -> Expr: ... + + @overload def _stableify(obj: Any) -> Any: ... diff --git a/narwhals/translate.py b/narwhals/translate.py index f8a43e595..a7b582659 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -29,12 +29,18 @@ def to_native( narwhals_object: DataFrame[IntoDataFrameT], *, strict: Literal[True] = ... ) -> IntoDataFrameT: ... + + @overload def to_native( narwhals_object: LazyFrame[IntoFrameT], *, strict: Literal[True] = ... ) -> IntoFrameT: ... + + @overload def to_native(narwhals_object: Series, *, strict: Literal[True] = ...) -> Any: ... + + @overload def to_native(narwhals_object: Any, *, strict: bool) -> Any: ... diff --git a/narwhals/utils.py b/narwhals/utils.py index ee19d4b6a..192c436b2 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -312,9 +312,7 @@ def is_ordered_categorical(series: Series) -> bool: isinstance(series._compliant_series, InterchangeSeries) and series.dtype == dtypes.Categorical ): - return series._compliant_series._native_series.describe_categorical[ # type: ignore[no-any-return] - "is_ordered" - ] + return series._compliant_series._native_series.describe_categorical["is_ordered"] # type: ignore[no-any-return] if series.dtype == dtypes.Enum: return True if series.dtype != dtypes.Categorical: diff --git a/tests/series/test_common.py b/tests/series/test_common.py index bd326d56f..b87950899 100644 --- a/tests/series/test_common.py +++ b/tests/series/test_common.py @@ -209,7 +209,7 @@ def test_is_sorted( series = nw.from_native(constructor_series(input_data), series_only=True) result = series.is_sorted(descending=descending) if constructor_series == dask_series_constructor: - result = result.compute() # type: ignore + result = result.compute() # type: ignore[attr-defined] assert result == expected diff --git a/tests/utils.py b/tests/utils.py index d1f12d411..42226970a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -58,7 +58,7 @@ def maybe_get_modin_df(df_pandas: pd.DataFrame) -> Any: def maybe_get_dask_df(df_pandas: pd.DataFrame) -> Any: - """Convert a pandas DataFrame to a Dask Dataframe if Dask is availabile.""" + """Convert a pandas DataFrame to a Dask Dataframe if Dask is available.""" try: import dask.dataframe as dd From 362f2832ea1ac61417535896fc8f4453d3345c1c Mon Sep 17 00:00:00 2001 From: Ben Rutter Date: Tue, 16 Jul 2024 13:06:39 +0100 Subject: [PATCH 16/19] test inclusion and some fixes --- narwhals/_pandas_like/dataframe.py | 3 +++ narwhals/_pandas_like/utils.py | 16 +++++++++++++--- narwhals/translate.py | 2 +- tests/conftest.py | 11 ++++++++--- tests/frame/test_common.py | 8 ++++++-- tests/tpch_q1_test.py | 16 +++++++++++++--- tests/utils.py | 11 ++++++++--- 7 files changed, 52 insertions(+), 15 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index a9b63b5a6..97691dd8b 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -203,6 +203,9 @@ def select( new_series = evaluate_into_exprs(self, *exprs, **named_exprs) if not new_series: # return empty dataframe, like Polars does + if self._implementation is Implementation.DASK: + dd = get_dask() + return self._from_native_dataframe(dd.from_dict({}, npartitions=1)) return self._from_native_dataframe(self._native_dataframe.__class__()) new_series = validate_indices(new_series) df = horizontal_concat( diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 6e3481800..27ae29c07 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -59,7 +59,7 @@ def validate_column_comparand(index: Any, other: Any) -> Any: return other.item() if ( other._native_series.index is not index - and other._implementation != Implementation.DASK + and other._implementation is not Implementation.DASK ): return set_axis( other._native_series, @@ -67,6 +67,15 @@ def validate_column_comparand(index: Any, other: Any) -> Any: implementation=other._implementation, backend_version=other._backend_version, ) + elif ( + other._native_series.index is not index + and other._implementation is Implementation.DASK + ): + msg = ( + "Index mismatch between columns and reindexing is not " + "currently supported within Dask implementation" + ) + raise ValueError(msg) return other._native_series return other @@ -148,8 +157,9 @@ def horizontal_concat( return mpd.concat(dfs, axis=1) if implementation is Implementation.DASK: # pragma: no cover dd = get_dask() - if hasattr(dfs[0], "_series"): - return dd.concat([i._series for i in dfs], axis=1) + pd = get_pandas() + if isinstance(dfs[0], pd.Series): + return dd.concat([i.pipe(dd.from_pandas) for i in dfs], axis=1) return dd.concat(dfs, axis=1) msg = f"Unknown implementation: {implementation}" # pragma: no cover raise TypeError(msg) # pragma: no cover diff --git a/narwhals/translate.py b/narwhals/translate.py index a7b582659..c934f8711 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -410,7 +410,7 @@ def from_native( # noqa: PLR0915 backend_version=parse_version(dask.__version__), ), is_polars=False, - backend_version=parse_version(pa.__version__), + backend_version=parse_version(dask.__version__), level="full", ) elif hasattr(native_object, "__dataframe__"): diff --git a/tests/conftest.py b/tests/conftest.py index eeeab5dcc..bed7ab26f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import contextlib from typing import Any from typing import Callable @@ -11,6 +12,10 @@ from narwhals.typing import IntoDataFrame from narwhals.utils import parse_version +with contextlib.suppress(ImportError): + import dask.dataframe # noqa: F401 + import modin # noqa: F401 + def pytest_addoption(parser: Any) -> None: parser.addoption( @@ -49,9 +54,9 @@ def modin_constructor(obj: Any) -> IntoDataFrame: # pragma: no cover return mpd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return] -def dask_contructor(obj: Any) -> IntoDataFrame: +def dask_constructor(obj: Any) -> IntoDataFrame: dd = get_dask() - return dd.DataFrame(obj) # type: ignore[no-any-return] + return pd.DataFrame(obj).pipe(dd.from_pandas, npartitions=1) # type: ignore[no-any-return] def polars_eager_constructor(obj: Any) -> IntoDataFrame: @@ -81,7 +86,7 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame: eager_constructors.append(modin_constructor) if get_dask() is not None: - eager_constructors.append(dask_contructor) + eager_constructors.append(dask_constructor) @pytest.fixture(params=eager_constructors) diff --git a/tests/frame/test_common.py b/tests/frame/test_common.py index 01eb34954..965282df3 100644 --- a/tests/frame/test_common.py +++ b/tests/frame/test_common.py @@ -10,9 +10,11 @@ import pytest import narwhals.stable.v1 as nw +from narwhals.dependencies import get_dask from narwhals.functions import _get_deps_info from narwhals.functions import _get_sys_info from narwhals.functions import show_versions +from tests.conftest import dask_constructor from tests.utils import compare_dicts data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} @@ -21,8 +23,10 @@ def test_empty_select(constructor: Any) -> None: - result = nw.from_native(constructor({"a": [1, 2, 3]}), eager_only=True).select() - assert result.shape == (0, 0) + result = nw.from_native(constructor({"a": [1, 2, 3]}), eager_only=True).select().shape + if constructor == dask_constructor and (dd := get_dask()) is not None: + result = dd.compute(result)[0] + assert result == (0, 0) def test_std(constructor: Any) -> None: diff --git a/tests/tpch_q1_test.py b/tests/tpch_q1_test.py index 372169715..b25fae278 100644 --- a/tests/tpch_q1_test.py +++ b/tests/tpch_q1_test.py @@ -11,13 +11,14 @@ import pytest import narwhals.stable.v1 as nw +from narwhals.dependencies import get_dask from narwhals.utils import parse_version from tests.utils import compare_dicts @pytest.mark.parametrize( "library", - ["pandas", "polars", "pyarrow"], + ["pandas", "polars", "pyarrow", "dask"], ) @pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning") def test_q1(library: str, request: Any) -> None: @@ -28,6 +29,8 @@ def test_q1(library: str, request: Any) -> None: df_raw["l_shipdate"] = pd.to_datetime(df_raw["l_shipdate"]) elif library == "polars": df_raw = pl.scan_parquet("tests/data/lineitem.parquet") + elif library == "dask" and (dd := get_dask()) is not None: + df_raw = dd.read_parquet("tests/data/lineitem.parquet") else: request.applymarker(pytest.mark.xfail) df_raw = pq.read_table("tests/data/lineitem.parquet") @@ -86,7 +89,7 @@ def test_q1(library: str, request: Any) -> None: @pytest.mark.parametrize( "library", - ["pandas", "polars"], + ["pandas", "polars", "dask"], ) @pytest.mark.filterwarnings( "ignore:.*Passing a BlockManager.*:DeprecationWarning", @@ -98,8 +101,15 @@ def test_q1_w_generic_funcs(library: str, request: Any) -> None: elif library == "pandas": df_raw = pd.read_parquet("tests/data/lineitem.parquet") df_raw["l_shipdate"] = pd.to_datetime(df_raw["l_shipdate"]) - else: + elif library == "polars": df_raw = pl.read_parquet("tests/data/lineitem.parquet") + elif library == "dask" and (dd := get_dask()) is not None: + df_raw = dd.read_parquet("tests/data/lineitem.parquet") + df_raw["l_shipdate"] = dd.to_datetime(df_raw["l_shipdate"]) + else: + request.applymarker(pytest.mark.xfail) + df_raw = pq.read_table("tests/data/lineitem.parquet") + var_1 = datetime(1998, 9, 2) df = nw.from_native(df_raw, eager_only=True) query_result = ( diff --git a/tests/utils.py b/tests/utils.py index 42226970a..47e9c1592 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,8 @@ import pandas as pd +from narwhals._pandas_like.utils import Implementation + def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]: if len(left) != len(right): @@ -21,9 +23,12 @@ def compare_dicts(result: Any, expected: dict[str, Any]) -> None: if hasattr(result, "collect"): result = result.collect() if ( - hasattr(result, "_dataframe") - and hasattr(result._dataframe, "_implementation") - and result._dataframe._implementation == "dask" + hasattr(result, "_native_dataframe") + and hasattr(result._native_dataframe, "_implementation") + and result._dataframe._implementation is Implementation.DASK + ) or ( + hasattr(result, "__native_namespace__") + and "dask" in str(result.__native_namespace__()) ): result = result.to_pandas() if hasattr(result, "columns"): From 2dc41f5e32747aa6d7cc3db283c21d63faff1389 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 19 Jul 2024 20:11:58 +0100 Subject: [PATCH 17/19] fixup --- narwhals/_pandas_like/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 27ae29c07..5151ff445 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -106,8 +106,8 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: backend_version=other._backend_version, ) return other._native_series - return other._series - raise AssertionError("Please report a bug") + msg = "Please report a bug" + raise AssertionError(msg) def create_native_series( From c5f63a2840d8aa44c5537d19bc85df9f71426d25 Mon Sep 17 00:00:00 2001 From: benrutter Date: Sat, 20 Jul 2024 20:40:10 +0100 Subject: [PATCH 18/19] additional functionality for dask around exprs --- narwhals/_expression_parsing.py | 7 +++++-- narwhals/_pandas_like/series.py | 12 +++++++++++- narwhals/_pandas_like/utils.py | 26 +++++++++++++------------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 1cc9d3327..33c0c4fd6 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -13,6 +13,7 @@ from narwhals.dependencies import get_numpy from narwhals.utils import flatten +from narwhals._pandas_like.utils import Implementation if TYPE_CHECKING: from narwhals._arrow.dataframe import ArrowDataFrame @@ -192,10 +193,12 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: if expr._output_names is not None and ( [s.name for s in out] != expr._output_names ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) + if not (hasattr(expr, "_implementation") and expr._implementation is Implementation.DASK): + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) return out + # Try tracking root and output names by combining them from all # expressions appearing in args and kwargs. If any anonymous # expression appears (e.g. nw.all()), then give up on tracking root names diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index c6164dd50..33fa15027 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -179,7 +179,6 @@ def cast( dtype = reverse_translate_dtype(dtype, ser.dtype, self._implementation) return self._from_native_series(ser.astype(dtype)) - @not_implemented_in(Implementation.DASK) def item(self: Self, index: int | None = None) -> Any: # cuDF doesn't have Series.item(). if index is None: @@ -189,7 +188,11 @@ def item(self: Self, index: int | None = None) -> Any: f" or an explicit index is provided (Series is of length {len(self)})" ) raise ValueError(msg) + if self._implementation is Implementation.DASK: + return self._native_series.max() # hack: taking aggregation of 1 item return self._native_series.iloc[0] + if self._implementation is Implementation.DASK: + raise NotImplementedError("Dask does not support index locating") return self._native_series.iloc[index] def to_frame(self) -> Any: @@ -202,6 +205,8 @@ def to_frame(self) -> Any: ) def to_list(self) -> Any: + if self._implementation is Implementation.DASK: + return self._native_series.compute().to_list() return self._native_series.to_list() def is_between( @@ -614,6 +619,11 @@ def __init__(self, series: PandasLikeSeries) -> None: def get_categories(self) -> PandasLikeSeries: s = self._pandas_series._native_series + if self._pandas_series._implementation is Implementation.DASK: + pd = get_pandas() + dd = get_dask() + native_series = pd.Series(s.cat.as_known().cat.categories, name=s.name).pipe(dd.from_pandas) + return self._pandas_series._from_native_series(native_series) return self._pandas_series._from_native_series( s.__class__(s.cat.categories, name=s.name) ) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 5151ff445..db58ccb6b 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -67,15 +67,6 @@ def validate_column_comparand(index: Any, other: Any) -> Any: implementation=other._implementation, backend_version=other._backend_version, ) - elif ( - other._native_series.index is not index - and other._implementation is Implementation.DASK - ): - msg = ( - "Index mismatch between columns and reindexing is not " - "currently supported within Dask implementation" - ) - raise ValueError(msg) return other._native_series return other @@ -128,6 +119,10 @@ def create_native_series( elif implementation is Implementation.CUDF: cudf = get_cudf() series = cudf.Series(iterable, index=index, name="") + elif implementation is Implementation.DASK: + pd = get_pandas() + dd = get_dask() + series = pd.Series(iterable, index=index, name="").pipe(dd.from_pandas) return PandasLikeSeries( series, implementation=implementation, backend_version=backend_version ) @@ -227,7 +222,7 @@ def native_series_from_iterable( dd = get_dask() pd = get_pandas() if hasattr(data[0], "compute"): # type: ignore[index] - return dd.concat([i.to_series() for i in data]) + return dd.concat([i.to_series() for i in data]).rename(name) return pd.Series( data, name=name, @@ -245,6 +240,8 @@ def set_axis( implementation: Implementation, backend_version: tuple[int, ...], ) -> T: + if implementation is Implementation.DASK: + return obj # HACK: dask doesn't really reset indexes so much, so assuming its fine if implementation is Implementation.PANDAS and backend_version < ( 1, ): # pragma: no cover @@ -258,9 +255,6 @@ def set_axis( kwargs["copy"] = False else: # pragma: no cover pass - if implementation is Implementation.DASK: - msg = "Setting axis on columns is not currently supported for dask" - raise NotImplementedError(msg) return obj.set_axis(index, axis=0, **kwargs) # type: ignore[no-any-return, attr-defined] @@ -324,6 +318,12 @@ def translate_dtype(column: Any) -> DType: if str(dtype) == "date32[day][pyarrow]": return dtypes.Date() if str(dtype) == "object": + if (dd := get_dask()) is not None and isinstance(column, dd.Series): + # below we'll try to infer strings or objects from values but + # with dask we can only do this if we compute so we'll avoid and + # treat as a string (this *may* be a bad call, as it does not allow + # for object types which are potentially valid) + return dtypes.String() if (idx := column.first_valid_index()) is not None and isinstance( column.loc[idx], str ): From c38311eb1e64e13cc05504dea6cec2f4811fbb97 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Jul 2024 19:40:27 +0000 Subject: [PATCH 19/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_expression_parsing.py | 8 +++++--- narwhals/_pandas_like/series.py | 4 +++- narwhals/_pandas_like/utils.py | 4 +++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 33c0c4fd6..7940ddeb0 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -11,9 +11,9 @@ from typing import cast from typing import overload +from narwhals._pandas_like.utils import Implementation from narwhals.dependencies import get_numpy from narwhals.utils import flatten -from narwhals._pandas_like.utils import Implementation if TYPE_CHECKING: from narwhals._arrow.dataframe import ArrowDataFrame @@ -193,12 +193,14 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: if expr._output_names is not None and ( [s.name for s in out] != expr._output_names ): # pragma: no cover - if not (hasattr(expr, "_implementation") and expr._implementation is Implementation.DASK): + if not ( + hasattr(expr, "_implementation") + and expr._implementation is Implementation.DASK + ): msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) return out - # Try tracking root and output names by combining them from all # expressions appearing in args and kwargs. If any anonymous # expression appears (e.g. nw.all()), then give up on tracking root names diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 33fa15027..8c43f30f5 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -622,7 +622,9 @@ def get_categories(self) -> PandasLikeSeries: if self._pandas_series._implementation is Implementation.DASK: pd = get_pandas() dd = get_dask() - native_series = pd.Series(s.cat.as_known().cat.categories, name=s.name).pipe(dd.from_pandas) + native_series = pd.Series(s.cat.as_known().cat.categories, name=s.name).pipe( + dd.from_pandas + ) return self._pandas_series._from_native_series(native_series) return self._pandas_series._from_native_series( s.__class__(s.cat.categories, name=s.name) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index db58ccb6b..316495aa4 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -241,7 +241,9 @@ def set_axis( backend_version: tuple[int, ...], ) -> T: if implementation is Implementation.DASK: - return obj # HACK: dask doesn't really reset indexes so much, so assuming its fine + return ( + obj # HACK: dask doesn't really reset indexes so much, so assuming its fine + ) if implementation is Implementation.PANDAS and backend_version < ( 1, ): # pragma: no cover