Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Dask Support Implementation #484

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import cast
from typing import overload

from narwhals._pandas_like.utils import Implementation
from narwhals.dependencies import get_numpy
from narwhals.utils import flatten

Expand Down Expand Up @@ -192,8 +193,12 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]:
if expr._output_names is not None and (
[s.name for s in out] != expr._output_names
): # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)
if not (
hasattr(expr, "_implementation")
and expr._implementation is Implementation.DASK
):
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)
return out

# Try tracking root and output names by combining them from all
Expand Down
22 changes: 19 additions & 3 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals._pandas_like.utils import validate_dataframe_comparand
from narwhals._pandas_like.utils import validate_indices
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_dask
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
Expand Down Expand Up @@ -66,7 +67,9 @@ def __native_namespace__(self) -> Any:
return get_modin()
if self._implementation is Implementation.CUDF: # pragma: no cover
return get_cudf()
msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover
if self._implementation is Implementation.DASK: # pragma: no cover
return get_dask()
msg = f"Expected pandas/modin/cudf/dask, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

def __len__(self) -> int:
Expand Down Expand Up @@ -200,6 +203,9 @@ def select(
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
if not new_series:
# return empty dataframe, like Polars does
if self._implementation is Implementation.DASK:
dd = get_dask()
return self._from_native_dataframe(dd.from_dict({}, npartitions=1))
return self._from_native_dataframe(self._native_dataframe.__class__())
new_series = validate_indices(new_series)
df = horizontal_concat(
Expand Down Expand Up @@ -312,9 +318,15 @@ def sort(

# --- convert ---
def collect(self) -> PandasLikeDataFrame:
# Build and return a PandasLikeDataFrame from this frame.
# Dask: the result of `.compute()` on the native frame is used and tagged
# Implementation.PANDAS; otherwise the native frame and the implementation tag
# are passed through unchanged.
if self._implementation is Implementation.DASK:
return_df = self._native_dataframe.compute()
return_implementation = Implementation.PANDAS
else:
return_df = self._native_dataframe
return_implementation = self._implementation
return PandasLikeDataFrame(
# NOTE(review): the next two arguments look like the pre-change lines that this
# PR removes (superseded by return_df / return_implementation) — confirm.
self._native_dataframe,
implementation=self._implementation,
return_df,
implementation=return_implementation,
# NOTE(review): backend_version is forwarded as stored on this frame, so after
# the Dask branch it is still the value this frame was created with — confirm
# that is the intended version for a frame tagged PANDAS.
backend_version=self._backend_version,
)

Expand Down Expand Up @@ -487,13 +499,17 @@ def to_numpy(self) -> Any:
import numpy as np

return np.hstack([self[col].to_numpy()[:, None] for col in self.columns])
if self._implementation is Implementation.DASK:
return self._native_dataframe.compute().to_numpy()
return self._native_dataframe.to_numpy()

def to_pandas(self) -> Any:
# Convert the native frame to pandas, dispatching on the implementation tag.
if self._implementation is Implementation.PANDAS:
# Already tagged pandas: the native frame itself is returned.
return self._native_dataframe
if self._implementation is Implementation.MODIN: # pragma: no cover
return self._native_dataframe._to_pandas()
if self._implementation is Implementation.DASK: # pragma: no cover
# Dask: return the result of `.compute()` on the native frame.
return self._native_dataframe.compute()
# Any other implementation: delegate to the native `to_pandas()`.
return self._native_dataframe.to_pandas() # pragma: no cover

def write_parquet(self, file: Any) -> Any:
Expand Down
11 changes: 10 additions & 1 deletion narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ class PandasLikeGroupBy:
def __init__(self, df: PandasLikeDataFrame, keys: list[str]) -> None:
# Store the frame and a copy of the key list, then build the native groupby.
self._df = df
self._keys = list(keys)
# Extra groupby keyword arguments; `as_index` is added only for non-Dask frames.
# NOTE(review): dict `|=` requires Python 3.9+ — confirm against the project's
# minimum supported version (a plain `keywords["as_index"] = True` avoids it).
keywords: dict[str, bool] = {}
if df._implementation is not Implementation.DASK:
keywords |= {"as_index": True}
self._grouped = self._df._native_dataframe.groupby(
list(self._keys),
sort=False,
# NOTE(review): this literal `as_index=True` looks like the pre-change line that
# `**keywords` replaces; keeping both would pass `as_index` twice — confirm.
as_index=True,
dropna=False,
**keywords,
)

def agg(
Expand All @@ -57,13 +61,18 @@ def agg(
raise ValueError(msg)
output_names.extend(expr._output_names)

dataframe_is_empty = (
self._df._native_dataframe.empty
if self._df._implementation != Implementation.DASK
else len(self._df._native_dataframe) == 0
)
Comment on lines +64 to +68
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just use `self._df.is_empty()` here instead?

return agg_pandas(
self._grouped,
exprs,
self._keys,
output_names,
self._from_native_dataframe,
dataframe_is_empty=self._df._native_dataframe.empty,
dataframe_is_empty=dataframe_is_empty,
implementation=implementation,
backend_version=self._df._backend_version,
)
Expand Down
9 changes: 7 additions & 2 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.selectors import PandasSelectorNamespace
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.utils import Implementation
from narwhals._pandas_like.utils import create_native_series
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import vertical_concat
from narwhals.utils import flatten

if TYPE_CHECKING:
from narwhals._pandas_like.typing import IntoPandasLikeExpr
from narwhals._pandas_like.utils import Implementation


class PandasLikeNamespace:
Expand Down Expand Up @@ -78,10 +78,15 @@ def _create_expr_from_callable(
def _create_series_from_scalar(
self, value: Any, series: PandasLikeSeries
) -> PandasLikeSeries:
# Wrap `value` in a one-element PandasLikeSeries that takes its name from
# `series`. The index is the `[0:1]` slice of `series`' native index, or None
# when this namespace is tagged Dask.
index = (
series._native_series.index[0:1]
if self._implementation is not Implementation.DASK
else None
)
return PandasLikeSeries._from_iterable(
[value],
name=series._native_series.name,
# NOTE(review): this direct `index=` argument looks like the pre-change line
# replaced by `index=index` below; both cannot be kept in one call — confirm.
index=series._native_series.index[0:1],
index=index,
implementation=self._implementation,
backend_version=self._backend_version,
)
Expand Down
36 changes: 34 additions & 2 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from narwhals._pandas_like.utils import Implementation
from narwhals._pandas_like.utils import int_dtype_mapper
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import not_implemented_in
from narwhals._pandas_like.utils import reverse_translate_dtype
from narwhals._pandas_like.utils import to_datetime
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_column_comparand
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_dask
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow_compute
Expand Down Expand Up @@ -107,12 +109,15 @@ def __native_namespace__(self) -> Any:
return get_modin()
if self._implementation is Implementation.CUDF: # pragma: no cover
return get_cudf()
msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover
if self._implementation is Implementation.DASK: # pragma: no cover
return get_dask()
msg = f"Expected pandas/modin/cudf/dask, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

def __narwhals_series__(self) -> Self:
# Returns this object unchanged.
return self

@not_implemented_in(Implementation.DASK)
def __getitem__(self, idx: int | slice | Sequence[int]) -> Any:
if isinstance(idx, int):
return self._native_series.iloc[idx]
Expand Down Expand Up @@ -152,7 +157,7 @@ def _from_iterable(
)

def __len__(self) -> int:
# Number of elements in the native series.
# NOTE(review): `return self.shape[0]` looks like the pre-change line that the
# `len(...)` form below replaces — confirm.
return self.shape[0]
return len(self._native_series)

@property
def name(self) -> str:
Expand Down Expand Up @@ -183,7 +188,11 @@ def item(self: Self, index: int | None = None) -> Any:
f" or an explicit index is provided (Series is of length {len(self)})"
)
raise ValueError(msg)
if self._implementation is Implementation.DASK:
return self._native_series.max() # hack: taking aggregation of 1 item
return self._native_series.iloc[0]
if self._implementation is Implementation.DASK:
raise NotImplementedError("Dask does not support index locating")
return self._native_series.iloc[index]

def to_frame(self) -> Any:
Expand All @@ -196,6 +205,8 @@ def to_frame(self) -> Any:
)

def to_list(self) -> Any:
# Return the native series' `to_list()` result.
if self._implementation is Implementation.DASK:
# Dask: call `.compute()` first, then `to_list()` on the computed result.
return self._native_series.compute().to_list()
return self._native_series.to_list()

def is_between(
Expand Down Expand Up @@ -504,10 +515,13 @@ def to_pandas(self) -> Any:
return self._native_series.to_pandas()
elif self._implementation is Implementation.MODIN: # pragma: no cover
return self._native_series._to_pandas()
elif self._implementation is Implementation.DASK: # pragma: no cover
return self._native_series.compute()
msg = f"Unknown implementation: {self._implementation}" # pragma: no cover
raise AssertionError(msg)

# --- descriptive ---
@not_implemented_in(Implementation.DASK)
def is_duplicated(self: Self) -> Self:
# Wrap the native `duplicated(keep=False)` result in a new series.
# NOTE(review): `not_implemented_in` is defined outside this view; presumably it
# rejects calls when the series is tagged Dask — confirm.
return self._from_native_series(self._native_series.duplicated(keep=False))

Expand All @@ -520,9 +534,11 @@ def is_unique(self: Self) -> Self:
def null_count(self: Self) -> int:
# Sum of the native `isna()` mask.
# NOTE(review): there is no Dask-specific branch here; on a Dask series the sum
# is presumably a lazy scalar rather than an int — confirm.
return self._native_series.isna().sum() # type: ignore[no-any-return]

@not_implemented_in(Implementation.DASK)
def is_first_distinct(self: Self) -> Self:
# Wrap the inverse (`~`) of the native `duplicated(keep="first")` mask in a new
# series.
# NOTE(review): `not_implemented_in` is defined outside this view; presumably it
# rejects calls when the series is tagged Dask — confirm.
return self._from_native_series(~self._native_series.duplicated(keep="first"))

@not_implemented_in(Implementation.DASK)
def is_last_distinct(self: Self) -> Self:
# Wrap the inverse (`~`) of the native `duplicated(keep="last")` mask in a new
# series.
# NOTE(review): `not_implemented_in` is defined outside this view; presumably it
# rejects calls when the series is tagged Dask — confirm.
return self._from_native_series(~self._native_series.duplicated(keep="last"))

Expand Down Expand Up @@ -559,6 +575,15 @@ def quantile(
quantile: float,
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
) -> Any:
if self._implementation is Implementation.DASK:
if interpolation == "linear":
return self._native_series.quantile(q=quantile)
message = (
"Dask performs approximate quantile calculations "
"and does not support specific interpolations methods. "
"Interpolation keywords other than 'linear' are not supported"
)
raise NotImplementedError(message)
return self._native_series.quantile(q=quantile, interpolation=interpolation)

def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries:
Expand Down Expand Up @@ -594,6 +619,13 @@ def __init__(self, series: PandasLikeSeries) -> None:

def get_categories(self) -> PandasLikeSeries:
    """Return the categories of the wrapped categorical series as a new series.

    For Dask-backed series the categories are collected into a pandas Series
    and converted back to a single-partition Dask series; for every other
    implementation the native series class is instantiated directly from
    ``s.cat.categories``. The result keeps the name of the original series.
    """
    s = self._pandas_series._native_series
    if self._pandas_series._implementation is Implementation.DASK:
        pd = get_pandas()
        dd = get_dask()
        # `as_known()` is called first so that `.cat.categories` is available
        # even when the Dask categorical was created with unknown categories.
        categories = pd.Series(s.cat.as_known().cat.categories, name=s.name)
        # `from_pandas` needs an explicit partitioning in the classic Dask
        # implementation; a single partition matches how the empty frame is
        # built elsewhere in this change (`from_dict({}, npartitions=1)`).
        native_series = dd.from_pandas(categories, npartitions=1)
        return self._pandas_series._from_native_series(native_series)
    return self._pandas_series._from_native_series(
        s.__class__(s.cat.categories, name=s.name)
    )
Expand Down
Loading
Loading