Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improving array casting #1865

Merged
merged 11 commits into from
Feb 10, 2025
Merged
12 changes: 8 additions & 4 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:
for i in range(dtype.num_fields)
]
)

if pa.types.is_list(dtype) or pa.types.is_large_list(dtype):
return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version))
if pa.types.is_fixed_size_list(dtype):
Expand Down Expand Up @@ -141,8 +140,13 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa
]
)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
return NotImplementedError(msg)
inner = narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
)
list_size = dtype.size # type: ignore[union-attr]
return pa.list_(inner, list_size=list_size)

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down Expand Up @@ -220,7 +224,7 @@ def broadcast_and_extract_dataframe_comparand(

if isinstance(other, ArrowSeries):
len_other = len(other)
if len_other == 1:
if len_other == 1 and length != 1:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise for list/array types we end up getting the first element

import numpy as np # ignore-banned-import

value = other._native_series[0]
Expand Down
16 changes: 11 additions & 5 deletions narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,13 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
)
if match_ := re.match(r"(.*)\[\]$", duckdb_dtype):
return dtypes.List(native_to_narwhals_dtype(match_.group(1), version))
if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype):
if match_ := re.match(r"(\w+)((?:\[\d+\])+)", duckdb_dtype):
Copy link
Member Author

@FBruzzesi FBruzzesi Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array type in duckdb can have multiple dimensions. The resulting type is: INNER[d1][d2][...]

With this new regex we can parse multiple instances of the dimensions

duckdb_inner_type = match_.group(1)
duckdb_shape = match_.group(2)
shape = tuple(int(value) for value in re.findall(r"\[(\d+)\]", duckdb_shape))
return dtypes.Array(
native_to_narwhals_dtype(match_.group(1), version),
int(match_.group(2)),
inner=native_to_narwhals_dtype(duckdb_inner_type, version),
shape=shape,
)
if duckdb_dtype.startswith("DECIMAL("):
return dtypes.Decimal()
Expand Down Expand Up @@ -171,8 +174,11 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st
)
return f"STRUCT({inner})"
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "todo"
raise NotImplementedError(msg)
duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape) # type: ignore[union-attr]
while isinstance(dtype.inner, dtypes.Array): # type: ignore[union-attr]
dtype = dtype.inner # type: ignore[union-attr]
inner = narwhals_to_native_dtype(dtype.inner, version) # type: ignore[union-attr]
return f"{inner}{duckdb_shape_fmt}"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First creates the shape [d1][d2]... then find the inner type recursively (first being non array)

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
61 changes: 8 additions & 53 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def broadcast_and_extract_dataframe_comparand(index: Any, other: Any) -> Any:
if isinstance(other, PandasLikeSeries):
len_other = other.len()

if len_other == 1:
if len_other == 1 and len(index) != 1:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly as for pyarrow, otherwise for list/array types we end up getting the first element

# broadcast
s = other._native_series
return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name)
Expand Down Expand Up @@ -387,9 +387,7 @@ def rename(


@lru_cache(maxsize=16)
def non_object_native_to_narwhals_dtype(
dtype: str, version: Version, _implementation: Implementation
) -> DType:
def non_object_native_to_narwhals_dtype(dtype: str, version: Version) -> DType:
dtypes = import_dtypes_module(version)
if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}:
return dtypes.Int64()
Expand Down Expand Up @@ -465,7 +463,7 @@ def native_to_narwhals_dtype(
return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version)
return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version)
if dtype != "object":
return non_object_native_to_narwhals_dtype(dtype, version, implementation)
return non_object_native_to_narwhals_dtype(dtype, version)
if implementation is Implementation.DASK:
# Dask columns are lazy, so we can't inspect values.
# The most useful assumption is probably String
Expand Down Expand Up @@ -649,68 +647,25 @@ def narwhals_to_native_dtype( # noqa: PLR0915
if isinstance_or_issubclass(dtype, dtypes.Enum):
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.List):
from narwhals._arrow.utils import (
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype,
)

if implementation is Implementation.PANDAS and backend_version >= (2, 2):
try:
import pandas as pd
import pyarrow as pa # ignore-banned-import
except ImportError as exc: # pragma: no cover
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}"
raise ImportError(msg) from exc

return pd.ArrowDtype(
pa.list_(
value_type=arrow_narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
)
)
)
else: # pragma: no cover
msg = (
"Converting to List dtype is not supported for implementation "
f"{implementation} and version {version}."
)
return NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Struct):
if isinstance_or_issubclass(dtype, (dtypes.Struct, dtypes.Array, dtypes.List)):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This diff is quite nice πŸ˜‰

if implementation is Implementation.PANDAS and backend_version >= (2, 2):
try:
import pandas as pd
import pyarrow as pa # ignore-banned-import
import pyarrow as pa # ignore-banned-import # noqa: F401
except ImportError as exc: # pragma: no cover
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}"
raise ImportError(msg) from exc
from narwhals._arrow.utils import (
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype,
)

return pd.ArrowDtype(
pa.struct(
[
(
field.name,
arrow_narwhals_to_native_dtype(
field.dtype,
version=version,
),
)
for field in dtype.fields # type: ignore[union-attr]
]
)
)
return pd.ArrowDtype(arrow_narwhals_to_native_dtype(dtype, version=version))
else: # pragma: no cover
msg = (
"Converting to Struct dtype is not supported for implementation "
f"Converting to {dtype} dtype is not supported for implementation "
f"{implementation} and version {version}."
)
return NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
return NotImplementedError(msg)
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
8 changes: 5 additions & 3 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def func(*args: Any, **kwargs: Any) -> Any:

def cast(self: Self, dtype: DType) -> Self:
expr = self._native_expr
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version)
return self._from_native_expr(expr.cast(dtype_pl))

def ewm_mean(
Expand Down Expand Up @@ -193,7 +193,9 @@ def map_batches(
return_dtype: DType | None,
) -> Self:
if return_dtype is not None:
return_dtype_pl = narwhals_to_native_dtype(return_dtype, self._version)
return_dtype_pl = narwhals_to_native_dtype(
return_dtype, self._version, self._backend_version
)
return self._from_native_expr(
self._native_expr.map_batches(function, return_dtype_pl)
)
Expand All @@ -205,7 +207,7 @@ def replace_strict(
) -> Self:
expr = self._native_expr
return_dtype_pl = (
narwhals_to_native_dtype(return_dtype, self._version)
narwhals_to_native_dtype(return_dtype, self._version, self._backend_version)
if return_dtype
else None
)
Expand Down
13 changes: 10 additions & 3 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ def lit(self: Self, value: Any, dtype: DType | None) -> PolarsExpr:

if dtype is not None:
return PolarsExpr(
pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._version)),
pl.lit(
value,
dtype=narwhals_to_native_dtype(
dtype, self._version, self._backend_version
),
),
version=self._version,
backend_version=self._backend_version,
)
Expand Down Expand Up @@ -206,9 +211,11 @@ def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr:
from narwhals._polars.expr import PolarsExpr

native_dtypes = [
narwhals_to_native_dtype(dtype, self._version).__class__
narwhals_to_native_dtype(
dtype, self._version, self._backend_version
).__class__
if isinstance(dtype, type) and issubclass(dtype, DType)
else narwhals_to_native_dtype(dtype, self._version)
else narwhals_to_native_dtype(dtype, self._version, self._backend_version)
for dtype in dtypes
]
return PolarsExpr(
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,15 @@ def __getitem__(self: Self, item: int | slice | Sequence[int]) -> Any | Self:

def cast(self: Self, dtype: DType) -> Self:
ser = self._native_series
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version)
return self._from_native_series(ser.cast(dtype_pl))

def replace_strict(
self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
ser = self._native_series
dtype = (
narwhals_to_native_dtype(return_dtype, self._version)
narwhals_to_native_dtype(return_dtype, self._version, self._backend_version)
if return_dtype
else None
)
Expand Down
31 changes: 16 additions & 15 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,19 @@ def native_to_narwhals_dtype(
native_to_narwhals_dtype(dtype.inner, version, backend_version) # type: ignore[attr-defined]
)
if dtype == pl.Array:
if backend_version < (0, 20, 30): # pragma: no cover
return dtypes.Array(
native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined]
dtype.width, # type: ignore[attr-defined]
)
else:
return dtypes.Array(
native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined]
dtype.size, # type: ignore[attr-defined]
)
outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size # type: ignore[attr-defined]
return dtypes.Array(
inner=native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined]
shape=outer_shape,
)
if dtype == pl.Decimal:
return dtypes.Decimal()
return dtypes.Unknown()


def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl.DataType:
def narwhals_to_native_dtype(
dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...]
) -> pl.DataType:
import polars as pl

dtypes = import_dtypes_module(version)
Expand Down Expand Up @@ -193,20 +190,24 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl
du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
return pl.Duration(time_unit=du_time_unit)
if dtype == dtypes.List:
return pl.List(narwhals_to_native_dtype(dtype.inner, version)) # type: ignore[union-attr]
return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) # type: ignore[union-attr]
if dtype == dtypes.Struct:
return pl.Struct(
fields=[
pl.Field(
name=field.name,
dtype=narwhals_to_native_dtype(field.dtype, version),
dtype=narwhals_to_native_dtype(field.dtype, version, backend_version),
)
for field in dtype.fields # type: ignore[union-attr]
]
)
if dtype == dtypes.Array: # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
raise NotImplementedError(msg)
size = dtype.size # type: ignore[union-attr]
kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size}
return pl.Array(
inner=narwhals_to_native_dtype(dtype.inner, version, backend_version), # type: ignore[union-attr]
**kwargs,
)
return pl.Unknown() # pragma: no cover


Expand Down
28 changes: 15 additions & 13 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,24 @@ def native_to_narwhals_dtype(
return dtypes.Int16()
if isinstance(dtype, pyspark_types.ByteType):
return dtypes.Int8()
string_types = [
pyspark_types.StringType,
pyspark_types.VarcharType,
pyspark_types.CharType,
]
if any(isinstance(dtype, t) for t in string_types):
if isinstance(
dtype,
(pyspark_types.StringType, pyspark_types.VarcharType, pyspark_types.CharType),
):
return dtypes.String()
if isinstance(dtype, pyspark_types.BooleanType):
return dtypes.Boolean()
if isinstance(dtype, pyspark_types.DateType):
return dtypes.Date()
datetime_types = [
pyspark_types.TimestampType,
pyspark_types.TimestampNTZType,
]
if any(isinstance(dtype, t) for t in datetime_types):
if isinstance(dtype, (pyspark_types.TimestampType, pyspark_types.TimestampNTZType)):
return dtypes.Datetime()
if isinstance(dtype, pyspark_types.DecimalType): # pragma: no cover
# TODO(unassigned): cover this in dtypes_test.py
return dtypes.Decimal()
if isinstance(dtype, pyspark_types.ArrayType): # pragma: no cover
return dtypes.List(
inner=native_to_narwhals_dtype(dtype.elementType, version=version)
)
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
return dtypes.Unknown()


Expand Down Expand Up @@ -96,8 +94,12 @@ def narwhals_to_native_dtype(
msg = "Converting to Struct dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
raise NotImplementedError(msg)
inner = narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
)
return pyspark_types.ArrayType(elementType=inner)

if isinstance_or_issubclass(
dtype, (dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8)
): # pragma: no cover
Expand Down
Loading
Loading