Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dask expr cast #821

Merged
merged 13 commits into from
Aug 25, 2024
17 changes: 17 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate
from narwhals._dask.utils import reverse_translate_dtype
from narwhals.dependencies import get_dask
from narwhals.utils import generate_unique_token

Expand All @@ -17,6 +18,7 @@

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
from narwhals.dtypes import DType


class DaskExpr:
Expand Down Expand Up @@ -654,6 +656,21 @@ def dt(self: Self) -> DaskExprDateTimeNamespace:
def name(self: Self) -> DaskExprNameNamespace:
return DaskExprNameNamespace(self)

def cast(
self: Self,
dtype: DType | type[DType],
) -> Self:
def func(_input: Any, dtype: DType | type[DType]) -> Any:
dtype = reverse_translate_dtype(dtype)
return _input.astype(dtype)

return self._from_call(
func,
"cast",
dtype,
returns_scalar=False,
)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
49 changes: 49 additions & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
from typing import Any

from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version

if TYPE_CHECKING:
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals.dtypes import DType


def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
Expand Down Expand Up @@ -73,3 +78,47 @@ def parse_exprs_and_named_exprs(
def add_row_index(frame: Any, name: str) -> Any:
frame = frame.assign(**{name: 1})
return frame.assign(**{name: frame[name].cumsum(method="blelloch") - 1})


def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
from narwhals import dtypes

if isinstance_or_issubclass(dtype, dtypes.Float64):
return "float64"
if isinstance_or_issubclass(dtype, dtypes.Float32):
return "float32"
if isinstance_or_issubclass(dtype, dtypes.Int64):
return "int64"
if isinstance_or_issubclass(dtype, dtypes.Int32):
return "int32"
if isinstance_or_issubclass(dtype, dtypes.Int16):
return "int16"
if isinstance_or_issubclass(dtype, dtypes.Int8):
return "int8"
if isinstance_or_issubclass(dtype, dtypes.UInt64):
return "uint64"
if isinstance_or_issubclass(dtype, dtypes.UInt32):
return "uint32"
if isinstance_or_issubclass(dtype, dtypes.UInt16):
return "uint16"
if isinstance_or_issubclass(dtype, dtypes.UInt8):
return "uint8"
if isinstance_or_issubclass(dtype, dtypes.String):
if (pd := get_pandas()) is not None and parse_version(
pd.__version__
) >= parse_version("2.0.0"):
if get_pyarrow() is not None:
return "string[pyarrow]"
return "string[python]" # pragma: no cover
return "object" # pragma: no cover
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return "bool"
if isinstance_or_issubclass(dtype, dtypes.Categorical):
return "category"
if isinstance_or_issubclass(dtype, dtypes.Datetime):
return "datetime64[us]"
if isinstance_or_issubclass(dtype, dtypes.Duration):
return "timedelta64[ns]"

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
6 changes: 1 addition & 5 deletions tests/expr_and_series/binary_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts


def test_expr_binary(constructor: Any, request: Any) -> None:
def test_expr_binary(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
df_raw = constructor(data)
result = nw.from_native(df_raw).with_columns(
a=(1 + 3 * nw.col("a")) * (1 / nw.col("a")),
Expand Down
35 changes: 28 additions & 7 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@

@pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
def test_cast(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
Expand Down Expand Up @@ -98,17 +96,21 @@ def test_cast(constructor: Any, request: Any) -> None:
assert dict(result.collect_schema()) == expected


def test_cast_series(constructor_eager: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor_eager) and parse_version(
def test_cast_series(constructor: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
request.applymarker(pytest.mark.xfail)
if "modin" in str(constructor_eager):
if "modin" in str(constructor):
# TODO(unassigned): in modin, we end up with `'<U0'` dtype
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor_eager(data), eager_only=True).select(
nw.col(key).cast(value) for key, value in schema.items()
df = (
nw.from_native(constructor(data))
.select(nw.col(key).cast(value) for key, value in schema.items())
.lazy()
.collect()
)

expected = {
"a": nw.Int32,
"b": nw.Int16,
Expand Down Expand Up @@ -158,3 +160,22 @@ def test_cast_string() -> None:
s = s.cast(nw.String)
result = nw.to_native(s)
assert str(result.dtype) in ("string", "object", "dtype('O')")


def test_cast_raises_for_unknown_dtype(constructor: Any, request: Any) -> None:
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data)).select(
nw.col(key).cast(value) for key, value in schema.items()
)

class Banana:
pass

with pytest.raises(AssertionError, match=r"Unknown dtype"):
df.select(nw.col("a").cast(Banana))
14 changes: 0 additions & 14 deletions tests/selectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pytest

import narwhals.stable.v1 as nw
from narwhals.dependencies import get_dask_dataframe
from narwhals.selectors import all
from narwhals.selectors import boolean
from narwhals.selectors import by_dtype
Expand Down Expand Up @@ -57,8 +56,6 @@ def test_string(constructor: Any, request: Any) -> None:


def test_categorical(request: Any, constructor: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and parse_version(
pa.__version__
) <= (15,): # pragma: no cover
Expand All @@ -70,17 +67,6 @@ def test_categorical(request: Any, constructor: Any) -> None:
compare_dicts(result, expected)


@pytest.mark.skipif((get_dask_dataframe() is None), reason="too old for dask")
def test_dask_categorical() -> None:
import dask.dataframe as dd

expected = {"b": ["a", "b", "c"]}
df_raw = dd.from_dict(expected, npartitions=1).astype({"b": "category"})
df = nw.from_native(df_raw)
result = df.select(categorical())
compare_dicts(result, expected)


@pytest.mark.parametrize(
("selector", "expected"),
[
Expand Down
3 changes: 0 additions & 3 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,6 @@ def test_group_by_multiple_keys(constructor: Any) -> None:


def test_key_with_nulls(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

if "modin" in str(constructor):
# TODO(unassigned): Modin flaky here?
request.applymarker(pytest.mark.skip)
Expand Down
Loading