diff --git a/docs/api-reference/dtypes.md b/docs/api-reference/dtypes.md index 848e9e58c..a607e9a54 100644 --- a/docs/api-reference/dtypes.md +++ b/docs/api-reference/dtypes.md @@ -16,6 +16,7 @@ - Float32 - Boolean - Categorical + - Enum - String - Datetime - Duration diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index 23267ec4c..1c76b1629 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -11,6 +11,7 @@ Here are the top-level functions available in Narwhals. - concat - from_native - get_native_namespace + - is_ordered_categorical - len - maybe_align_index - maybe_set_index diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 603c20ff7..406d6ee79 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -6,6 +6,7 @@ from narwhals.dtypes import Date from narwhals.dtypes import Datetime from narwhals.dtypes import Duration +from narwhals.dtypes import Enum from narwhals.dtypes import Float32 from narwhals.dtypes import Float64 from narwhals.dtypes import Int8 @@ -36,6 +37,7 @@ from narwhals.translate import get_native_namespace from narwhals.translate import narwhalify from narwhals.translate import to_native +from narwhals.utils import is_ordered_categorical from narwhals.utils import maybe_align_index from narwhals.utils import maybe_convert_dtypes from narwhals.utils import maybe_set_index @@ -47,6 +49,7 @@ "concat", "to_native", "from_native", + "is_ordered_categorical", "maybe_align_index", "maybe_convert_dtypes", "maybe_set_index", @@ -78,6 +81,7 @@ "Object", "Unknown", "Categorical", + "Enum", "String", "Datetime", "Duration", diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 4f3256d65..bc7182b77 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -30,6 +30,7 @@ class ArrowNamespace: Object = dtypes.Object Unknown = dtypes.Unknown Categorical = dtypes.Categorical + Enum = dtypes.Enum String = dtypes.String Datetime = 
dtypes.Datetime Duration = dtypes.Duration diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index a7bc303da..9b129a329 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -28,10 +28,12 @@ def translate_dtype(dtype: Any) -> dtypes.DType: return dtypes.Float64() if pa.types.is_float32(dtype): return dtypes.Float32() - if ( + # bug in coverage? it shows `31->exit` (where `31` is currently the line number of + # the next line), even though both the true and the false outcome of the if condition are covered + if ( # pragma: no cover pa.types.is_string(dtype) or pa.types.is_large_string(dtype) - or pa.types.is_string_view(dtype) + or getattr(pa.types, "is_string_view", lambda _: False)(dtype) ): return dtypes.String() if pa.types.is_date32(dtype): diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 2d6d154f9..e90a197a2 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -35,6 +35,7 @@ class PandasNamespace: Object = dtypes.Object Unknown = dtypes.Unknown Categorical = dtypes.Categorical + Enum = dtypes.Enum String = dtypes.String Datetime = dtypes.Datetime Duration = dtypes.Duration diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 89a29c5ea..66fffdf3d 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -21,6 +21,9 @@ def is_numeric(cls: type[Self]) -> bool: def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] return isinstance_or_issubclass(other, type(self)) + def __hash__(self) -> int: + return hash(self.__class__) + class NumericType(DType): ... @@ -79,6 +82,9 @@ class Duration(TemporalType): ... class Categorical(DType): ... +class Enum(DType): ... + + class Date(TemporalType): ... 
@@ -117,6 +123,9 @@ def translate_dtype(plx: Any, dtype: DType) -> Any: return plx.Boolean if dtype == Categorical: return plx.Categorical + if dtype == Enum: + msg = "Converting to Enum is not (yet) supported" + raise NotImplementedError(msg) if dtype == Datetime: return plx.Datetime if dtype == Duration: @@ -160,6 +169,8 @@ def to_narwhals_dtype(dtype: Any, *, is_polars: bool) -> DType: return Object() if dtype == pl.Categorical: return Categorical() + if dtype == pl.Enum: + return Enum() if dtype == pl.Datetime: return Datetime() if dtype == pl.Duration: diff --git a/narwhals/utils.py b/narwhals/utils.py index 9d38871fe..b597aed50 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -8,8 +8,13 @@ from typing import TypeVar from typing import cast +from narwhals import dtypes +from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars +from narwhals.dependencies import get_pyarrow +from narwhals.translate import to_native if TYPE_CHECKING: from narwhals.dataframe import BaseFrame @@ -257,3 +262,63 @@ def maybe_convert_dtypes(df: T, *args: bool, **kwargs: bool | str) -> T: ) ) return df + + +def is_ordered_categorical(series: Series) -> bool: + """ + Return whether indices of categories are semantically meaningful. + + This is a convenience function for accessing what would otherwise be + the `is_ordered` property from the DataFrame Interchange Protocol, + see https://data-apis.org/dataframe-protocol/latest/API.html. + + - For Polars: + - Enums are always ordered. + - Categoricals are ordered if `dtype.ordering == "physical"`. + - For pandas-like APIs: + - Categoricals are ordered if `series.cat.ordered == True`. + - For PyArrow chunked arrays: + - Categoricals are ordered if `series.type.ordered == True`. 
+ + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data = ["x", "y"] + >>> s_pd = pd.Series(data, dtype=pd.CategoricalDtype(ordered=True)) + >>> s_pl = pl.Series(data, dtype=pl.Categorical(ordering="physical")) + + Let's define a library-agnostic function: + + >>> @nw.narwhalify + ... def func(s): + ... return nw.is_ordered_categorical(s) + + Then, we can pass any supported library to `func`: + + >>> func(s_pd) + True + >>> func(s_pl) + True + """ + if series.dtype == dtypes.Enum: + return True + if series.dtype != dtypes.Categorical: + return False + native_series = to_native(series) + if (pl := get_polars()) is not None and isinstance(native_series, pl.Series): + return native_series.dtype.ordering == "physical" # type: ignore[no-any-return] + if (pd := get_pandas()) is not None and isinstance(native_series, pd.Series): + return native_series.cat.ordered # type: ignore[no-any-return] + if (mpd := get_modin()) is not None and isinstance( + native_series, mpd.Series + ): # pragma: no cover + return native_series.cat.ordered # type: ignore[no-any-return] + if (cudf := get_cudf()) is not None and isinstance( + native_series, cudf.Series + ): # pragma: no cover + return native_series.cat.ordered # type: ignore[no-any-return] + if (pa := get_pyarrow()) is not None and isinstance(native_series, pa.ChunkedArray): + return native_series.type.ordered # type: ignore[no-any-return] + # If it doesn't match any of the above, let's just play it safe and return False. 
+ return False # pragma: no cover diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py index 3caaf58c6..67e634aee 100644 --- a/tests/frame/schema_test.py +++ b/tests/frame/schema_test.py @@ -66,6 +66,7 @@ def test_dtypes() -> None: "o": [datetime(2020, 1, 1)], "p": ["a"], "q": [timedelta(1)], + "r": ["a"], }, schema={ "a": pl.Int64, @@ -85,6 +86,7 @@ def test_dtypes() -> None: "o": pl.Datetime, "p": pl.Categorical, "q": pl.Duration, + "r": pl.Enum(["a", "b"]), }, ) df = nw.DataFrame(df_pl) @@ -107,9 +109,13 @@ def test_dtypes() -> None: "o": nw.Datetime, "p": nw.Categorical, "q": nw.Duration, + "r": nw.Enum, } assert result == expected assert {name: df[name].dtype for name in df.columns} == expected + + # pandas/pyarrow only have categorical, not enum + expected["r"] = nw.Categorical df_pd = df_pl.to_pandas(use_pyarrow_extension_array=True) df = nw.DataFrame(df_pd) result_pd = df.schema @@ -130,3 +136,7 @@ def test_unknown_dtype() -> None: def test_unknown_dtype_polars() -> None: df = pl.DataFrame({"a": [[1, 2, 3]]}) assert nw.from_native(df).schema == {"a": nw.Unknown} + + +def test_hash() -> None: + assert nw.Int64() in {nw.Int64, nw.Int32} diff --git a/tests/series/cast_test.py b/tests/series/cast_test.py index 0192cefe9..0026eceee 100644 --- a/tests/series/cast_test.py +++ b/tests/series/cast_test.py @@ -89,3 +89,19 @@ def test_cast_date_datetime_invalid() -> None: def test_unknown_to_int() -> None: df = pd.DataFrame({"a": pd.period_range("2000", periods=3, freq="M")}) assert nw.from_native(df).select(nw.col("a").cast(nw.Int64)).schema == {"a": nw.Int64} + + +def test_cast_to_enum() -> None: + # we don't yet support metadata in dtypes, so for now disallow this + # seems like a very niche use case anyway, and allowing it later wouldn't be + # backwards-incompatible + df = pl.DataFrame({"a": ["a", "b"]}, schema={"a": pl.Categorical}) + with pytest.raises( + NotImplementedError, match=r"Converting to Enum is not \(yet\) supported" + ): + 
nw.from_native(df).select(nw.col("a").cast(nw.Enum)) + df = pd.DataFrame({"a": ["a", "b"]}, dtype="category") + with pytest.raises( + NotImplementedError, match=r"Converting to Enum is not \(yet\) supported" + ): + nw.from_native(df).select(nw.col("a").cast(nw.Enum)) diff --git a/tests/series/is_ordered_categorical_test.py b/tests/series/is_ordered_categorical_test.py new file mode 100644 index 000000000..e5f9d7c04 --- /dev/null +++ b/tests/series/is_ordered_categorical_test.py @@ -0,0 +1,41 @@ +from typing import Any + +import pandas as pd +import polars as pl +import pyarrow as pa +import pytest + +import narwhals as nw + + +def test_is_ordered_categorical() -> None: + s = pl.Series(["a", "b"], dtype=pl.Categorical) + assert nw.is_ordered_categorical(nw.from_native(s, series_only=True)) + s = pl.Series(["a", "b"], dtype=pl.Categorical(ordering="lexical")) + assert not nw.is_ordered_categorical(nw.from_native(s, series_only=True)) + s = pl.Series(["a", "b"], dtype=pl.Enum(["a", "b"])) + assert nw.is_ordered_categorical(nw.from_native(s, series_only=True)) + s = pd.Series(["a", "b"], dtype=pd.CategoricalDtype(ordered=True)) + assert nw.is_ordered_categorical(nw.from_native(s, series_only=True)) + s = pd.Series(["a", "b"], dtype=pd.CategoricalDtype(ordered=False)) + assert not nw.is_ordered_categorical(nw.from_native(s, series_only=True)) + s = pa.chunked_array( + [pa.array(["a", "b"], type=pa.dictionary(pa.int32(), pa.string()))] + ) + assert not nw.is_ordered_categorical(nw.from_native(s, series_only=True)) + + +def test_is_definitely_not_ordered_categorical( + constructor_series_with_pyarrow: Any, +) -> None: + assert not nw.is_ordered_categorical( + nw.from_native(constructor_series_with_pyarrow([1, 2, 3]), series_only=True) + ) + + +@pytest.mark.xfail(reason="https://github.com/apache/arrow/issues/41017") +def test_is_ordered_categorical_pyarrow() -> None: + s = pa.chunked_array( + [pa.array(["a", "b"], type=pa.dictionary(pa.int32(), pa.string(), 
ordered=True))] + ) + assert nw.is_ordered_categorical(nw.from_native(s, series_only=True))