From 1b8fd7d480181e146b4fcf6140168b7dbeb32b63 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 25 Jan 2025 15:48:08 +0100 Subject: [PATCH 1/5] feat: improving array casting --- narwhals/_arrow/utils.py | 11 ++++-- narwhals/_duckdb/utils.py | 16 ++++++--- narwhals/_pandas_like/utils.py | 61 +++++----------------------------- narwhals/_polars/utils.py | 21 +++++------- narwhals/_spark_like/utils.py | 28 ++++++++-------- narwhals/dtypes.py | 51 ++++++++++++++++++++++------ tests/dtypes_test.py | 24 ++++++++++--- 7 files changed, 111 insertions(+), 101 deletions(-) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 2040c4e2e5..a3d2b09c70 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -141,8 +141,13 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa ] ) if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - msg = "Converting to Array dtype is not supported yet" - return NotImplementedError(msg) + inner = narwhals_to_native_dtype( + dtype.inner, # type: ignore[union-attr] + version=version, + ) + list_size = dtype.size # type: ignore[union-attr] + return pa.list_(inner, list_size=list_size) + msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) @@ -220,7 +225,7 @@ def broadcast_and_extract_dataframe_comparand( if isinstance(other, ArrowSeries): len_other = len(other) - if len_other == 1: + if len_other == 1 and length != 1: import numpy as np # ignore-banned-import value = other._native_series[0] diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index d45123267f..6a10eaed0a 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -133,10 +133,13 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType: ) if match_ := re.match(r"(.*)\[\]$", duckdb_dtype): return dtypes.List(native_to_narwhals_dtype(match_.group(1), version)) - if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype): + if match_ := re.match(r"(\w+)((?:\[\d+\])+)", duckdb_dtype): + duckdb_inner_type = match_.group(1) + duckdb_shape = match_.group(2) + shape = tuple(int(value) for value in re.findall(r"\[(\d+)\]", duckdb_shape)) return dtypes.Array( - native_to_narwhals_dtype(match_.group(1), version), - int(match_.group(2)), + inner=native_to_narwhals_dtype(duckdb_inner_type, version), + shape=shape, ) if duckdb_dtype.startswith("DECIMAL("): return dtypes.Decimal() @@ -193,8 +196,11 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st ) return f"STRUCT({inner})" if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - msg = "todo" - raise NotImplementedError(msg) + duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape) # type: ignore[union-attr] + while isinstance(dtype.inner, dtypes.Array): # type: ignore[union-attr] + dtype = dtype.inner # type: ignore[union-attr] + inner = narwhals_to_native_dtype(dtype.inner, version) # type: ignore[union-attr] + return f"{inner}{duckdb_shape_fmt}" msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 883af4bbdc..1b8ae58dab 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -160,7 +160,7 @@ def broadcast_and_extract_dataframe_comparand(index: Any, other: Any) -> Any: if isinstance(other, PandasLikeSeries): len_other = other.len() - if len_other == 1: + if len_other == 1 and len(index) != 1: # broadcast s = other._native_series return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name) @@ -387,9 +387,7 @@ def rename( @lru_cache(maxsize=16) -def non_object_native_to_narwhals_dtype( - dtype: str, version: Version, _implementation: Implementation -) -> DType: +def non_object_native_to_narwhals_dtype(dtype: str, version: Version) -> DType: dtypes = import_dtypes_module(version) if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: return dtypes.Int64() @@ -465,7 +463,7 @@ def native_to_narwhals_dtype( return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version) return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version) if dtype != "object": - return non_object_native_to_narwhals_dtype(dtype, version, implementation) + return non_object_native_to_narwhals_dtype(dtype, version) if implementation is Implementation.DASK: # Dask columns are lazy, so we can't inspect values. # The most useful assumption is probably String @@ -649,38 +647,11 @@ def narwhals_to_native_dtype( # noqa: PLR0915 if isinstance_or_issubclass(dtype, dtypes.Enum): msg = "Converting to Enum is not (yet) supported" raise NotImplementedError(msg) - if isinstance_or_issubclass(dtype, dtypes.List): - from narwhals._arrow.utils import ( - narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, - ) - - if implementation is Implementation.PANDAS and backend_version >= (2, 2): - try: - import pandas as pd - import pyarrow as pa # ignore-banned-import - except ImportError as exc: # pragma: no cover - msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" - raise ImportError(msg) from exc - - return pd.ArrowDtype( - pa.list_( - value_type=arrow_narwhals_to_native_dtype( - dtype.inner, # type: ignore[union-attr] - version=version, - ) - ) - ) - else: # pragma: no cover - msg = ( - "Converting to List dtype is not supported for implementation " - f"{implementation} and version {version}." - ) - return NotImplementedError(msg) - if isinstance_or_issubclass(dtype, dtypes.Struct): + if isinstance_or_issubclass(dtype, (dtypes.Struct, dtypes.Array, dtypes.List)): if implementation is Implementation.PANDAS and backend_version >= (2, 2): try: import pandas as pd - import pyarrow as pa # ignore-banned-import + import pyarrow as pa # ignore-banned-import # noqa: F401 except ImportError as exc: # pragma: no cover msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" raise ImportError(msg) from exc @@ -688,29 +659,13 @@ def narwhals_to_native_dtype( # noqa: PLR0915 narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, ) - return pd.ArrowDtype( - pa.struct( - [ - ( - field.name, - arrow_narwhals_to_native_dtype( - field.dtype, - version=version, - ), - ) - for field in dtype.fields # type: ignore[union-attr] - ] - ) - ) + return pd.ArrowDtype(arrow_narwhals_to_native_dtype(dtype, version=version)) else: # pragma: no cover msg = ( - "Converting to Struct dtype is not supported for implementation " + f"Converting to {dtype} dtype is not supported for implementation " f"{implementation} and version {version}." ) - return NotImplementedError(msg) - if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - msg = "Converting to Array dtype is not supported yet" - return NotImplementedError(msg) + raise NotImplementedError(msg) msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 94510d2efb..3ae61f8383 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -132,16 +132,11 @@ def native_to_narwhals_dtype( native_to_narwhals_dtype(dtype.inner, version, backend_version) # type: ignore[attr-defined] ) if dtype == pl.Array: - if backend_version < (0, 20, 30): # pragma: no cover - return dtypes.Array( - native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined] - dtype.width, # type: ignore[attr-defined] - ) - else: - return dtypes.Array( - native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined] - dtype.size, # type: ignore[attr-defined] - ) + outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size # type: ignore[attr-defined] + return dtypes.Array( + inner=native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined] + shape=outer_shape, + ) if dtype == pl.Decimal: return dtypes.Decimal() return dtypes.Unknown() @@ -205,8 +200,10 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl ] ) if dtype == dtypes.Array: # pragma: no cover - msg = "Converting to Array dtype is not supported yet" - raise NotImplementedError(msg) + return pl.Array( + inner=narwhals_to_native_dtype(dtype.inner, version), # type: ignore[union-attr] + shape=dtype.size, # type: ignore[union-attr] + ) return pl.Unknown() # pragma: no cover diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 37a2426d44..42e362aaf4 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -41,26 +41,24 @@ def native_to_narwhals_dtype( return dtypes.Int16() if isinstance(dtype, pyspark_types.ByteType): return dtypes.Int8() - string_types = [ - pyspark_types.StringType, - pyspark_types.VarcharType, - pyspark_types.CharType, - ] - if any(isinstance(dtype, t) for t in string_types): + if isinstance( + dtype, + (pyspark_types.StringType, pyspark_types.VarcharType, pyspark_types.CharType), + ): return dtypes.String() if isinstance(dtype, pyspark_types.BooleanType): return dtypes.Boolean() if isinstance(dtype, pyspark_types.DateType): return dtypes.Date() - datetime_types = [ - pyspark_types.TimestampType, - pyspark_types.TimestampNTZType, - ] - if any(isinstance(dtype, t) for t in datetime_types): + if isinstance(dtype, (pyspark_types.TimestampType, pyspark_types.TimestampNTZType)): return dtypes.Datetime() if isinstance(dtype, pyspark_types.DecimalType): # pragma: no cover # TODO(unassigned): cover this in dtypes_test.py return dtypes.Decimal() + if isinstance(dtype, pyspark_types.ArrayType): # pragma: no cover + return dtypes.List( + inner=native_to_narwhals_dtype(dtype.elementType, version=version) + ) return dtypes.Unknown() @@ -97,8 +95,12 @@ def narwhals_to_native_dtype( msg = "Converting to Struct dtype is not supported yet" raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - msg = "Converting to Array dtype is not supported yet" - raise NotImplementedError(msg) + inner = narwhals_to_native_dtype( + dtype.inner, # type: ignore[union-attr] + version=version, + ) + return pyspark_types.ArrayType(elementType=inner) + if isinstance_or_issubclass( dtype, (dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8) ): # pragma: no cover diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 1a615b4800..35fb86307c 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -731,14 +731,37 @@ class Array(DType): Array(Int32, 2) """ + inner: DType | type[DType] + size: int + shape: tuple[int, ...] + def __init__( - self: Self, inner: DType | type[DType], width: int | None = None + self: Self, + inner: DType | type[DType], + shape: int | tuple[int, ...] | None = None, ) -> None: - self.inner = inner - if width is None: - error = "`width` must be specified when initializing an `Array`" - raise TypeError(error) - self.width = width + inner_shape: tuple[int, ...] = inner.shape if isinstance(inner, Array) else () + + if shape is None: # pragma: no cover + msg = "Array constructor is missing the required argument `shape`" + raise TypeError(msg) + + if isinstance(shape, int): + self.inner = inner + self.size = shape + self.shape = (shape, *inner_shape) + + elif isinstance(shape, tuple): + if len(shape) > 1: + inner = Array(inner, shape[1:]) + + self.inner = inner + self.size = shape[0] + self.shape = shape + inner_shape + + else: + msg = f"invalid input for shape: {shape!r}" + raise TypeError(msg) def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[override] # This equality check allows comparison of type classes and type instances. @@ -751,16 +774,24 @@ def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[over if type(other) is type and issubclass(other, self.__class__): return True elif isinstance(other, self.__class__): - return self.inner == other.inner + if self.shape != other.shape: + return False + else: + return self.inner == other.inner else: return False def __hash__(self: Self) -> int: - return hash((self.__class__, self.inner, self.width)) + return hash((self.__class__, self.inner, self.shape)) + + def __repr__(self) -> str: + # Get leaf type + dtype = self.inner + while isinstance(dtype, Array): + dtype = dtype.inner - def __repr__(self: Self) -> str: class_name = self.__class__.__name__ - return f"{class_name}({self.inner!r}, {self.width})" + return f"{class_name}({dtype!r}, shape={self.shape})" class Date(TemporalType): diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 97ca384c88..ad4193e75d 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -3,6 +3,7 @@ from datetime import datetime from datetime import timedelta from datetime import timezone +from typing import TYPE_CHECKING from typing import Literal import numpy as np @@ -15,6 +16,9 @@ from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION +if TYPE_CHECKING: + from tests.utils import Constructor + @pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) @pytest.mark.parametrize("time_zone", ["Europe/Rome", timezone.utc, None]) @@ -75,7 +79,7 @@ def test_array_valid() -> None: assert dtype == nw.Array assert dtype != nw.Array(nw.Float32, 2) assert dtype != nw.Duration - assert repr(dtype) == "Array(, 2)" + assert repr(dtype) == "Array(, shape=(2,))" dtype = nw.Array(nw.Array(nw.Int64, 2), 2) assert dtype == nw.Array(nw.Array(nw.Int64, 2), 2) assert dtype == nw.Array @@ -83,7 +87,7 @@ def test_array_valid() -> None: assert dtype in {nw.Array(nw.Array(nw.Int64, 2), 2)} with pytest.raises( - TypeError, match="`width` must be specified when initializing an `Array`" + TypeError, match="Array constructor is missing the required argument `shape`" ): dtype = nw.Array(nw.Int64) @@ -133,13 +137,23 @@ def test_polars_2d_array() -> None: df = pl.DataFrame( {"a": [[[1, 2], [3, 4], [5, 6]]]}, schema={"a": pl.Array(pl.Int64, (3, 2))} ) - assert nw.from_native(df).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64, 2), 3) + assert nw.from_native(df).collect_schema()["a"] == nw.Array(nw.Int64(), (3, 2)) assert nw.from_native(df.to_arrow()).collect_schema()["a"] == nw.Array( - nw.Array(nw.Int64, 2), 3 + nw.Array(nw.Int64(), 2), 3 ) assert nw.from_native( df.to_pandas(use_pyarrow_extension_array=True) - ).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64, 2), 3) + ).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64(), 2), 3) + + +def test_2d_array(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if any(x in str(constructor) for x in ("dask", "modin", "cudf", "pyspark")): + request.applymarker(pytest.mark.xfail) + data = {"a": [[[1, 2], [3, 4], [5, 6]]]} + df = nw.from_native(constructor(data)).with_columns( + a=nw.col("a").cast(nw.Array(nw.Int64(), (3, 2))) + ) + assert df.collect_schema()["a"] == nw.Array(nw.Int64(), (3, 2)) def test_second_time_unit() -> None: From ab25f4d8ab575bea305d3aa9b4f899e3310db1b6 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 25 Jan 2025 16:45:28 +0100 Subject: [PATCH 2/5] WIP --- narwhals/_polars/expr.py | 8 +++++--- narwhals/_polars/namespace.py | 12 ++++++++++-- narwhals/_polars/series.py | 4 ++-- narwhals/_polars/utils.py | 14 +++++++++----- narwhals/dtypes.py | 2 +- narwhals/functions.py | 19 +++++++++++++++---- tests/dtypes_test.py | 25 +++++++------------------ 7 files changed, 49 insertions(+), 35 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 98a692a674..233b281cf7 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -47,7 +47,7 @@ def func(*args: Any, **kwargs: Any) -> Any: def cast(self: Self, dtype: DType) -> Self: expr = self._native_expr - dtype_pl = narwhals_to_native_dtype(dtype, self._version) + dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) return self._from_native_expr(expr.cast(dtype_pl)) def ewm_mean( @@ -193,7 +193,9 @@ def map_batches( return_dtype: DType | None, ) -> Self: if return_dtype is not None: - return_dtype_pl = narwhals_to_native_dtype(return_dtype, self._version) + return_dtype_pl = narwhals_to_native_dtype( + return_dtype, self._version, self._backend_version + ) return self._from_native_expr( self._native_expr.map_batches(function, return_dtype_pl) ) @@ -205,7 +207,7 @@ def replace_strict( ) -> Self: expr = self._native_expr return_dtype_pl = ( - narwhals_to_native_dtype(return_dtype, self._version) + narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) if return_dtype else None ) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 32e53b372f..5f3bf67c33 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -120,7 +120,12 @@ def lit(self: Self, value: Any, dtype: DType | None) -> PolarsExpr: if dtype is not None: return PolarsExpr( - pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._version)), + pl.lit( + value, + dtype=narwhals_to_native_dtype( + dtype, self._version, self._backend_version + ), + ), version=self._version, backend_version=self._backend_version, ) @@ -221,7 +226,10 @@ def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: return PolarsExpr( pl.selectors.by_dtype( - [narwhals_to_native_dtype(dtype, self._version) for dtype in dtypes] + [ + narwhals_to_native_dtype(dtype, self._version, self._backend_version) + for dtype in dtypes + ] ), version=self._version, backend_version=self._backend_version, diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 6a4d50d11d..38e84f3f98 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -130,7 +130,7 @@ def __getitem__(self: Self, item: int | slice | Sequence[int]) -> Any | Self: def cast(self: Self, dtype: DType) -> Self: ser = self._native_series - dtype_pl = narwhals_to_native_dtype(dtype, self._version) + dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) return self._from_native_series(ser.cast(dtype_pl)) def replace_strict( @@ -138,7 +138,7 @@ def replace_strict( ) -> Self: ser = self._native_series dtype = ( - narwhals_to_native_dtype(return_dtype, self._version) + narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) if return_dtype else None ) diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 3ae61f8383..a69818adc9 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -142,7 +142,9 @@ def native_to_narwhals_dtype( return dtypes.Unknown() -def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl.DataType: +def narwhals_to_native_dtype( + dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...] +) -> pl.DataType: import polars as pl dtypes = import_dtypes_module(version) @@ -188,21 +190,23 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") return pl.Duration(time_unit=du_time_unit) if dtype == dtypes.List: - return pl.List(narwhals_to_native_dtype(dtype.inner, version)) # type: ignore[union-attr] + return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) # type: ignore[union-attr] if dtype == dtypes.Struct: return pl.Struct( fields=[ pl.Field( name=field.name, - dtype=narwhals_to_native_dtype(field.dtype, version), + dtype=narwhals_to_native_dtype(field.dtype, version, backend_version), ) for field in dtype.fields # type: ignore[union-attr] ] ) if dtype == dtypes.Array: # pragma: no cover + size = dtype.size # type: ignore[union-attr] + kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size} return pl.Array( - inner=narwhals_to_native_dtype(dtype.inner, version), # type: ignore[union-attr] - shape=dtype.size, # type: ignore[union-attr] + inner=narwhals_to_native_dtype(dtype.inner, version, backend_version), # type: ignore[union-attr] + **kwargs, ) return pl.Unknown() # pragma: no cover diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 35fb86307c..44303fc7cf 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -751,7 +751,7 @@ def __init__( self.size = shape self.shape = (shape, *inner_shape) - elif isinstance(shape, tuple): + elif isinstance(shape, tuple) and isinstance(shape[0], int): if len(shape) > 1: inner = Array(inner, shape[1:]) diff --git a/narwhals/functions.py b/narwhals/functions.py index 7fd648dc1d..c703c678d7 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -326,7 +326,10 @@ def _new_series_impl( narwhals_to_native_dtype as polars_narwhals_to_native_dtype, ) - dtype_pl = polars_narwhals_to_native_dtype(dtype, version=version) + backend_version = parse_version(native_namespace.__version__) + dtype_pl = polars_narwhals_to_native_dtype( + dtype, version=version, backend_version=backend_version + ) else: dtype_pl = None @@ -441,7 +444,7 @@ def from_dict( ) -def _from_dict_impl( +def _from_dict_impl( # noqa: PLR0915 data: dict[str, Any], schema: dict[str, DType] | Schema | None = None, *, @@ -471,8 +474,11 @@ def _from_dict_impl( narwhals_to_native_dtype as polars_narwhals_to_native_dtype, ) + backend_version = parse_version(native_namespace.__version__) schema_pl = { - name: polars_narwhals_to_native_dtype(dtype, version=version) + name: polars_narwhals_to_native_dtype( + dtype, version=version, backend_version=backend_version + ) for name, dtype in schema.items() } else: @@ -713,8 +719,13 @@ def _from_numpy_impl( narwhals_to_native_dtype as polars_narwhals_to_native_dtype, ) + backend_version = parse_version(native_namespace.__version__) schema = { - name: polars_narwhals_to_native_dtype(dtype, version=version) # type: ignore[misc] + name: polars_narwhals_to_native_dtype( # type: ignore[misc] + dtype, + version=version, + backend_version=backend_version, + ) for name, dtype in schema.items() } elif schema is None: diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index ad4193e75d..1892abf514 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -77,6 +77,7 @@ def test_array_valid() -> None: dtype = nw.Array(nw.Int64, 2) assert dtype == nw.Array(nw.Int64, 2) assert dtype == nw.Array + assert dtype != nw.Array(nw.Int64, 3) assert dtype != nw.Array(nw.Float32, 2) assert dtype != nw.Duration assert repr(dtype) == "Array(, shape=(2,))" @@ -89,7 +90,10 @@ def test_array_valid() -> None: with pytest.raises( TypeError, match="Array constructor is missing the required argument `shape`" ): - dtype = nw.Array(nw.Int64) + nw.Array(nw.Int64) + + with pytest.raises(TypeError, match="invalid input for shape"): + nw.Array(nw.Int64(), shape="invalid_type") # type: ignore[arg-type] def test_struct_valid() -> None: @@ -129,23 +133,7 @@ def test_struct_hashes() -> None: assert len({hash(tp) for tp in (dtypes)}) == 3 -@pytest.mark.skipif( - POLARS_VERSION < (1,) or PANDAS_VERSION < (2, 2), - reason="`shape` is only available after 1.0", -) -def test_polars_2d_array() -> None: - df = pl.DataFrame( - {"a": [[[1, 2], [3, 4], [5, 6]]]}, schema={"a": pl.Array(pl.Int64, (3, 2))} - ) - assert nw.from_native(df).collect_schema()["a"] == nw.Array(nw.Int64(), (3, 2)) - assert nw.from_native(df.to_arrow()).collect_schema()["a"] == nw.Array( - nw.Array(nw.Int64(), 2), 3 - ) - assert nw.from_native( - df.to_pandas(use_pyarrow_extension_array=True) - ).collect_schema()["a"] == nw.Array(nw.Array(nw.Int64(), 2), 3) - - +@pytest.mark.skipif(PANDAS_VERSION < (2, 2), reason="old pandas") def test_2d_array(constructor: Constructor, request: pytest.FixtureRequest) -> None: if any(x in str(constructor) for x in ("dask", "modin", "cudf", "pyspark")): request.applymarker(pytest.mark.xfail) @@ -154,6 +142,7 @@ def test_2d_array(constructor: Constructor, request: pytest.FixtureRequest) -> N a=nw.col("a").cast(nw.Array(nw.Int64(), (3, 2))) ) assert df.collect_schema()["a"] == nw.Array(nw.Int64(), (3, 2)) + assert df.collect_schema()["a"] == nw.Array(nw.Array(nw.Int64(), 2), 3) def test_second_time_unit() -> None: From 65287c9fa90df186dff1fc8cf49ed6039467f6d6 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 25 Jan 2025 17:17:43 +0100 Subject: [PATCH 3/5] update docstring --- narwhals/dtypes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 44303fc7cf..8fe7f10d6a 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -711,7 +711,7 @@ class Array(DType): Arguments: inner: The datatype of the values within each array. - width: the length of each array. + shape: the length of each array. Examples: >>> import pandas as pd @@ -724,11 +724,11 @@ class Array(DType): >>> ser_pa = pa.chunked_array([data], type=pa.list_(pa.int32(), 2)) >>> nw.from_native(ser_pd, series_only=True).dtype - Array(Int32, 2) + Array(Int32, shape=(2,)) >>> nw.from_native(ser_pl, series_only=True).dtype - Array(Int32, 2) + Array(Int32, shape=(2,)) >>> nw.from_native(ser_pa, series_only=True).dtype - Array(Int32, 2) + Array(Int32, shape=(2,)) """ inner: DType | type[DType] From 31dd90f7640924aff7f8502c329e71b9a6fa7352 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 27 Jan 2025 16:10:00 +0100 Subject: [PATCH 4/5] rm imports --- narwhals/_duckdb/dataframe.py | 4 ---- narwhals/_polars/dataframe.py | 5 ----- narwhals/_polars/expr.py | 6 ------ narwhals/_polars/namespace.py | 10 ---------- narwhals/_polars/series.py | 14 -------------- narwhals/_polars/utils.py | 2 -- narwhals/_spark_like/utils.py | 2 -- 7 files changed, 43 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index c34028e841..bc7bf081f3 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -315,8 +315,6 @@ def collect_schema(self: Self) -> dict[str, DType]: def unique(self: Self, subset: Sequence[str] | None, keep: str) -> Self: if subset is not None: - import duckdb - rel = self._native_frame # Sanitise input if any(x not in rel.columns for x in subset): @@ -365,8 +363,6 @@ def sort( return self._from_native_frame(result) def drop_nulls(self: Self, subset: list[str] | None) -> Self: - import duckdb - rel = self._native_frame subset_ = subset if subset is not None else rel.columns keep_condition = " and ".join(f'"{col}" is not null' for col in subset_) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 727912f11a..0473a171a9 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -101,8 +101,6 @@ def _from_native_object( def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: - import polars as pl - args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] try: return self._from_native_object( @@ -176,7 +174,6 @@ def __getitem__(self: Self, item: Any) -> Any: ) msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover raise TypeError(msg) # pragma: no cover - import polars as pl if ( isinstance(item, tuple) @@ -412,8 +409,6 @@ def collect_schema(self: Self) -> dict[str, DType]: } def collect(self: Self) -> PolarsDataFrame: - import polars as pl - try: result = self._native_frame.collect() except pl.exceptions.ColumnNotFoundError as e: diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 233b281cf7..2b6969a3db 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -79,8 +79,6 @@ def ewm_mean( **extra_kwargs, ) if self._backend_version < (1,): # pragma: no cover - import polars as pl - return self._from_native_expr( pl.when(~expr.is_null()).then(native_expr).otherwise(None) ) @@ -352,14 +350,10 @@ def len(self: Self) -> PolarsExpr: native_result = native_expr.list.len() if self._expr._backend_version < (1, 16): # pragma: no cover - import polars as pl - native_result: pl.Expr = ( # type: ignore[no-redef] pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32()) ) elif self._expr._backend_version < (1, 17): # pragma: no cover - import polars as pl - native_result = native_result.cast(pl.UInt32()) return self._expr._from_native_expr(native_result) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 33fb84dfc6..60789862c9 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -225,8 +225,6 @@ def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: ) def numeric(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( @@ -236,8 +234,6 @@ def numeric(self: Self) -> PolarsExpr: ) def boolean(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( @@ -247,8 +243,6 @@ def boolean(self: Self) -> PolarsExpr: ) def string(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( @@ -258,8 +252,6 @@ def string(self: Self) -> PolarsExpr: ) def categorical(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( @@ -269,8 +261,6 @@ def categorical(self: Self) -> PolarsExpr: ) def all(self: Self) -> PolarsExpr: - import polars as pl - from narwhals._polars.expr import PolarsExpr return PolarsExpr( diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 38e84f3f98..3fcc231bcf 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -77,8 +77,6 @@ def _from_native_object(self: Self, series: T) -> T: ... def _from_native_object( self: Self, series: pl.Series | pl.DataFrame | T ) -> Self | PolarsDataFrame | T: - import polars as pl - if isinstance(series, pl.Series): return self._from_native_series(series) if isinstance(series, pl.DataFrame): @@ -244,8 +242,6 @@ def median(self: Self) -> Any: return self._native_series.median() def to_dummies(self: Self, *, separator: str, drop_first: bool) -> PolarsDataFrame: - import polars as pl - from narwhals._polars.dataframe import PolarsDataFrame if self._backend_version < (0, 20, 15): @@ -294,8 +290,6 @@ def ewm_mean( **extra_kwargs, ) if self._backend_version < (1,): # pragma: no cover - import polars as pl - return self._from_native_series( pl.select( pl.when(~native_series.is_null()).then(native_result).otherwise(None) @@ -405,8 +399,6 @@ def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: result = self._native_series.sort(descending=descending) if nulls_last: - import polars as pl - is_null = result.is_null() result = pl.concat([result.filter(~is_null), result.filter(is_null)]) else: @@ -433,8 +425,6 @@ def value_counts( from narwhals._polars.dataframe import PolarsDataFrame if self._backend_version < (1, 0, 0): - import polars as pl - value_name_ = name or ("proportion" if normalize else "count") result = self._native_series.value_counts(sort=sort, parallel=parallel) @@ -547,15 +537,11 @@ def len(self: Self) -> PolarsSeries: native_result = native_series.list.len() if self._series._backend_version < (1, 16): # pragma: no cover - import polars as pl - native_result = pl.select( pl.when(~native_series.is_null()).then(native_result).otherwise(None) )[native_series.name].cast(pl.UInt32()) elif self._series._backend_version < (1, 17): # pragma: no cover - import polars as pl - native_result = native_series.cast(pl.UInt32()) return self._series._from_native_series(native_result) diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index a69818adc9..f0ee621bf2 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -145,8 +145,6 @@ def native_to_narwhals_dtype( def narwhals_to_native_dtype( dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...] ) -> pl.DataType: - import polars as pl - dtypes = import_dtypes_module(version) if dtype == dtypes.Float64: diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index aa3ccce0e4..b0a613e8a4 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -15,8 +15,6 @@ from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: - from pyspark.sql import Column - from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.expr import SparkLikeExpr from narwhals.dtypes import DType From 690fc9938e2a64cad1514c7bd530722f4d046519 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 9 Feb 2025 22:24:41 +0100 Subject: [PATCH 5/5] avoid while loops, shape docstring --- narwhals/_duckdb/utils.py | 12 +++++++----- narwhals/dtypes.py | 10 +++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index fee5f7af9f..eb69d86b9e 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -192,11 +192,13 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st ) return f"STRUCT({inner})" if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape) # type: ignore[union-attr] - while isinstance(dtype.inner, dtypes.Array): # type: ignore[union-attr] - dtype = dtype.inner # type: ignore[union-attr] - inner = narwhals_to_native_dtype(dtype.inner, version) # type: ignore[union-attr] - return f"{inner}{duckdb_shape_fmt}" + shape: tuple[int] = dtype.shape # type: ignore[union-attr] + duckdb_shape_fmt = "".join(f"[{item}]" for item in shape) + inner_dtype = dtype + for _ in shape: + inner_dtype = inner_dtype.inner # type: ignore[union-attr] + duckdb_inner = narwhals_to_native_dtype(inner_dtype, version) + return f"{duckdb_inner}{duckdb_shape_fmt}" msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index c38aacc35a..7d9cf31361 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -765,7 +765,7 @@ class Array(NestedType): Arguments: inner: The datatype of the values within each array. - shape: the length of each array. + shape: The shape of the arrays. Examples: >>> import pandas as pd @@ -840,12 +840,12 @@ def __hash__(self: Self) -> int: def __repr__(self) -> str: # Get leaf type - dtype = self.inner - while isinstance(dtype, Array): - dtype = dtype.inner + dtype_ = self + for _ in self.shape: + dtype_ = dtype_.inner # type: ignore[assignment] class_name = self.__class__.__name__ - return f"{class_name}({dtype!r}, shape={self.shape})" + return f"{class_name}({dtype_!r}, shape={self.shape})" class Date(TemporalType):