Skip to content

Commit

Permalink
feat: add Expr.cast and Series.cast for PyArrow backend (#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 3, 2024
1 parent 6a1a6f3 commit bf9afea
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 182 deletions.
8 changes: 7 additions & 1 deletion narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING
from typing import Callable

from narwhals._arrow.series import ArrowSeries
from narwhals._pandas_like.utils import reuse_series_implementation
from narwhals._pandas_like.utils import reuse_series_namespace_implementation

Expand All @@ -12,6 +11,8 @@

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals.dtypes import DType


class ArrowExpr:
Expand Down Expand Up @@ -43,6 +44,8 @@ def __repr__(self) -> str: # pragma: no cover

@classmethod
def from_column_names(cls: type[Self], *column_names: str) -> Self:
from narwhals._arrow.series import ArrowSeries

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [
ArrowSeries(
Expand All @@ -65,6 +68,9 @@ def __narwhals_namespace__(self) -> ArrowNamespace:

return ArrowNamespace()

def cast(self, dtype: DType) -> Self:
return reuse_series_implementation(self, "cast", dtype) # type: ignore[type-var]

def cum_sum(self) -> Self:
return reuse_series_implementation(self, "cum_sum") # type: ignore[type-var]

Expand Down
4 changes: 3 additions & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from narwhals import dtypes
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals.utils import flatten

if TYPE_CHECKING:
from typing import Callable

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries


class ArrowNamespace:
Expand Down Expand Up @@ -68,6 +68,8 @@ def _create_expr_from_series(self, series: ArrowSeries) -> ArrowExpr:
)

def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSeries:
from narwhals._arrow.series import ArrowSeries

return ArrowSeries.from_iterable(
[value],
name=series.name,
Expand Down
11 changes: 11 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Any
from typing import Iterable

from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.utils import reverse_translate_dtype
from narwhals._arrow.utils import translate_dtype
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals.dependencies import get_pyarrow
Expand Down Expand Up @@ -44,6 +46,9 @@ def from_iterable(cls: type[Self], data: Iterable[Any], name: str) -> Self:
def __len__(self) -> int:
return len(self._series)

def __narwhals_namespace__(self) -> ArrowNamespace:
return ArrowNamespace()

@property
def name(self) -> str:
return self._name
Expand Down Expand Up @@ -85,6 +90,12 @@ def all(self) -> bool:
def is_empty(self) -> bool:
return len(self) == 0

def cast(self, dtype: DType) -> Self:
pc = get_pyarrow_compute()
ser = self._series
dtype = reverse_translate_dtype(dtype)
return self._from_series(pc.cast(ser, dtype))

@property
def shape(self) -> tuple[int]:
return (len(self._series),)
Expand Down
48 changes: 48 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from typing import Any

from narwhals import dtypes
from narwhals.dependencies import get_pyarrow
from narwhals.utils import isinstance_or_issubclass


def translate_dtype(dtype: Any) -> dtypes.DType:
Expand Down Expand Up @@ -45,3 +48,48 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
if pa.types.is_dictionary(dtype):
return dtypes.Categorical()
raise AssertionError


def reverse_translate_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
from narwhals import dtypes

pa = get_pyarrow()

if isinstance_or_issubclass(dtype, dtypes.Float64):
return pa.float64()
if isinstance_or_issubclass(dtype, dtypes.Float32):
return pa.float32()
if isinstance_or_issubclass(dtype, dtypes.Int64):
return pa.int64()
if isinstance_or_issubclass(dtype, dtypes.Int32):
return pa.int32()
if isinstance_or_issubclass(dtype, dtypes.Int16):
return pa.int16()
if isinstance_or_issubclass(dtype, dtypes.Int8):
return pa.int8()
if isinstance_or_issubclass(dtype, dtypes.UInt64):
return pa.uint64()
if isinstance_or_issubclass(dtype, dtypes.UInt32):
return pa.uint32()
if isinstance_or_issubclass(dtype, dtypes.UInt16):
return pa.uint16()
if isinstance_or_issubclass(dtype, dtypes.UInt8):
return pa.uint8()
if isinstance_or_issubclass(dtype, dtypes.String):
return pa.string()
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return pa.bool_()
if isinstance_or_issubclass(dtype, dtypes.Categorical):
# todo: what should the key be? let's keep it consistent
# with Polars for now
return pa.dictionary(pa.uint32(), pa.string())
if isinstance_or_issubclass(dtype, dtypes.Datetime):
# Use Polars' default
return pa.timestamp("us")
if isinstance_or_issubclass(dtype, dtypes.Duration):
# Use Polars' default
return pa.duration("us")
if isinstance_or_issubclass(dtype, dtypes.Date):
return pa.date32()
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
108 changes: 108 additions & 0 deletions tests/expr/cast_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Any

import pyarrow as pa
import pytest

import narwhals as nw
from narwhals.utils import parse_version


def test_cast(constructor_with_pyarrow: Any, request: Any) -> None:
if "table" in str(constructor_with_pyarrow) and parse_version(
pa.__version__
) <= parse_version("12.0.0"): # pragma: no cover
request.applymarker(pytest.mark.xfail)
data = {
"a": [1],
"b": [1],
"c": [1],
"d": [1],
"e": [1],
"f": [1],
"g": [1],
"h": [1],
"i": [1],
"j": [1],
"k": ["1"],
"l": [1],
"m": [True],
"n": [True],
"o": ["a"],
"p": [1],
}
schema = {
"a": nw.Int64,
"b": nw.Int32,
"c": nw.Int16,
"d": nw.Int8,
"e": nw.UInt64,
"f": nw.UInt32,
"g": nw.UInt16,
"h": nw.UInt8,
"i": nw.Float64,
"j": nw.Float32,
"k": nw.String,
"l": nw.Datetime,
"m": nw.Boolean,
"n": nw.Boolean,
"o": nw.Categorical,
"p": nw.Int64,
}
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True).select(
nw.col(key).cast(value) for key, value in schema.items()
)
result = df.select(
nw.col("a").cast(nw.Int32),
nw.col("b").cast(nw.Int16),
nw.col("c").cast(nw.Int8),
nw.col("d").cast(nw.Int64),
nw.col("e").cast(nw.UInt32),
nw.col("f").cast(nw.UInt16),
nw.col("g").cast(nw.UInt8),
nw.col("h").cast(nw.UInt64),
nw.col("i").cast(nw.Float32),
nw.col("j").cast(nw.Float64),
nw.col("k").cast(nw.String),
nw.col("l").cast(nw.Datetime),
nw.col("m").cast(nw.Int8),
nw.col("n").cast(nw.Int8),
nw.col("o").cast(nw.String),
nw.col("p").cast(nw.Duration),
)
expected = {
"a": nw.Int32,
"b": nw.Int16,
"c": nw.Int8,
"d": nw.Int64,
"e": nw.UInt32,
"f": nw.UInt16,
"g": nw.UInt8,
"h": nw.UInt64,
"i": nw.Float32,
"j": nw.Float64,
"k": nw.String,
"l": nw.Datetime,
"m": nw.Int8,
"n": nw.Int8,
"o": nw.String,
"p": nw.Duration,
}
assert result.schema == expected
result = df.select(
df["a"].cast(nw.Int32),
df["b"].cast(nw.Int16),
df["c"].cast(nw.Int8),
df["d"].cast(nw.Int64),
df["e"].cast(nw.UInt32),
df["f"].cast(nw.UInt16),
df["g"].cast(nw.UInt8),
df["h"].cast(nw.UInt64),
df["i"].cast(nw.Float32),
df["j"].cast(nw.Float64),
df["k"].cast(nw.String),
df["l"].cast(nw.Datetime),
df["m"].cast(nw.Int8),
df["n"].cast(nw.Int8),
df["o"].cast(nw.String),
df["p"].cast(nw.Duration),
)
19 changes: 19 additions & 0 deletions tests/series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
from polars.testing import assert_frame_equal

Expand Down Expand Up @@ -43,6 +44,24 @@ def test_cast_date_datetime_polars() -> None:
assert df.schema == {"a": nw.Date}


def test_cast_date_datetime_pyarrow() -> None:
# polars: date to datetime
dfpa = pa.table({"a": [date(2020, 1, 1), date(2020, 1, 2)]})
df = nw.from_native(dfpa)
df = df.select(nw.col("a").cast(nw.Datetime))
result = nw.to_native(df)
expected = pa.table({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]})
assert result == expected

# pyarrow: datetime to date
dfpa = pa.table({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]})
df = nw.from_native(dfpa)
df = df.select(nw.col("a").cast(nw.Date))
result = nw.to_native(df)
expected = pa.table({"a": [date(2020, 1, 1), date(2020, 1, 2)]})
assert result == expected


@pytest.mark.skipif(
parse_version(pd.__version__) < parse_version("2.0.0"),
reason="pyarrow dtype not available",
Expand Down
Loading

0 comments on commit bf9afea

Please sign in to comment.