diff --git a/docs/api-reference/dtypes.md b/docs/api-reference/dtypes.md
index b1e313ba3..d84b0684f 100644
--- a/docs/api-reference/dtypes.md
+++ b/docs/api-reference/dtypes.md
@@ -18,6 +18,7 @@
         - Categorical
         - String
         - Datetime
+        - Object
       show_root_heading: false
       show_source: false
       show_bases: false
diff --git a/narwhals/__init__.py b/narwhals/__init__.py
index a19bf6683..3acbccac6 100644
--- a/narwhals/__init__.py
+++ b/narwhals/__init__.py
@@ -11,6 +11,7 @@
 from narwhals.dtypes import Int16
 from narwhals.dtypes import Int32
 from narwhals.dtypes import Int64
+from narwhals.dtypes import Object
 from narwhals.dtypes import String
 from narwhals.dtypes import UInt8
 from narwhals.dtypes import UInt16
@@ -72,6 +73,7 @@
     "Float64",
     "Float32",
     "Boolean",
+    "Object",
     "Categorical",
     "String",
     "Datetime",
diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py
index 8e49e04f2..7a9deab22 100644
--- a/narwhals/_pandas_like/dataframe.py
+++ b/narwhals/_pandas_like/dataframe.py
@@ -135,7 +135,8 @@ def iter_rows(
     @property
     def schema(self) -> dict[str, DType]:
         return {
-            col: translate_dtype(dtype) for col, dtype in self._dataframe.dtypes.items()
+            col: translate_dtype(self._dataframe.loc[:, col])
+            for col in self._dataframe.columns
         }
 
     # --- reshape ---
diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py
index 95768d704..a1150ba95 100644
--- a/narwhals/_pandas_like/namespace.py
+++ b/narwhals/_pandas_like/namespace.py
@@ -32,6 +32,7 @@ class PandasNamespace:
     Float64 = dtypes.Float64
     Float32 = dtypes.Float32
     Boolean = dtypes.Boolean
+    Object = dtypes.Object
     Categorical = dtypes.Categorical
     String = dtypes.String
     Datetime = dtypes.Datetime
diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py
index d1b0c0c95..09d9ea569 100644
--- a/narwhals/_pandas_like/series.py
+++ b/narwhals/_pandas_like/series.py
@@ -151,7 +151,7 @@ def shape(self) -> tuple[int]:
 
     @property
     def dtype(self) -> DType:
-        return translate_dtype(self._series.dtype)
+        return translate_dtype(self._series)
 
     def cast(
         self,
diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py
index 0678b799e..da1e44f43 100644
--- a/narwhals/_pandas_like/utils.py
+++ b/narwhals/_pandas_like/utils.py
@@ -373,9 +373,10 @@ def set_axis(obj: T, index: Any, implementation: str) -> T:
     return obj.set_axis(index, axis=0)  # type: ignore[no-any-return, attr-defined]
 
 
-def translate_dtype(dtype: Any) -> DType:
+def translate_dtype(column: Any) -> DType:
     from narwhals import dtypes
 
+    dtype = column.dtype
     if dtype in ("int64", "Int64", "Int64[pyarrow]"):
         return dtypes.Int64()
     if dtype in ("int32", "Int32", "Int32[pyarrow]"):
@@ -412,7 +413,15 @@ def translate_dtype(dtype: Any) -> DType:
     if str(dtype) == "date32[day][pyarrow]":
         return dtypes.Date()
     if dtype == "object":
-        return dtypes.String()
+        if (idx := column.first_valid_index()) is not None and isinstance(
+            column.loc[idx], str
+        ):
+            # Infer based on first non-missing value.
+            # For pandas pre 3.0, this isn't perfect.
+            # After pandas 3.0, pandas has a dedicated string dtype
+            # which is inferred by default.
+            return dtypes.String()
+        return dtypes.Object()
     msg = f"Unknown dtype: {dtype}"  # pragma: no cover
     raise AssertionError(msg)
 
diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py
index 93f3d99c9..7c3f48694 100644
--- a/narwhals/dtypes.py
+++ b/narwhals/dtypes.py
@@ -64,6 +64,9 @@ class String(DType): ...
 class Boolean(DType): ...
 
 
+class Object(DType): ...
+
+
 class Datetime(TemporalType): ...
 
 
@@ -145,6 +148,8 @@ def to_narwhals_dtype(dtype: Any, *, is_polars: bool) -> DType:
         return String()
     if dtype == pl.Boolean:
         return Boolean()
+    if dtype == pl.Object:
+        return Object()
     if dtype == pl.Categorical:
         return Categorical()
     if dtype == pl.Datetime:
diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py
index 36b65ff5e..8f3d3487a 100644
--- a/tests/frame/schema_test.py
+++ b/tests/frame/schema_test.py
@@ -1,5 +1,10 @@
 from datetime import datetime
 from datetime import timezone
+from typing import Any
+
+import pandas as pd
+import polars as pl
+import pytest
 
 import narwhals as nw
 
@@ -12,3 +17,25 @@ def test_schema_comparison() -> None:
     assert {"a": nw.String()} != {"a": nw.Int32()}
 
     assert {"a": nw.Int32()} == {"a": nw.Int32()}
+
+
+def test_object() -> None:
+    df = pd.DataFrame({"a": [1, 2, 3]}).astype(object)
+    result = nw.from_native(df).schema
+    assert result["a"] == nw.Object
+
+
+def test_string_disguised_as_object() -> None:
+    df = pd.DataFrame({"a": ["foo", "bar"]}).astype(object)
+    result = nw.from_native(df).schema
+    assert result["a"] == nw.String
+
+
+@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
+def test_actual_object(constructor: Any) -> None:
+    class Foo: ...
+
+    data = {"a": [Foo()]}
+    df = nw.from_native(constructor(data))
+    result = df.schema
+    assert result == {"a": nw.Object}