Skip to content

Commit

Permalink
feat: add is_ordered_categorical (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 2, 2024
1 parent d77db26 commit c1f1a01
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/dtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- Float32
- Boolean
- Categorical
- Enum
- String
- Datetime
- Duration
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Here are the top-level functions available in Narwhals.
- concat
- from_native
- get_native_namespace
- is_ordered_categorical
- len
- maybe_align_index
- maybe_set_index
Expand Down
4 changes: 4 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from narwhals.dtypes import Date
from narwhals.dtypes import Datetime
from narwhals.dtypes import Duration
from narwhals.dtypes import Enum
from narwhals.dtypes import Float32
from narwhals.dtypes import Float64
from narwhals.dtypes import Int8
Expand Down Expand Up @@ -36,6 +37,7 @@
from narwhals.translate import get_native_namespace
from narwhals.translate import narwhalify
from narwhals.translate import to_native
from narwhals.utils import is_ordered_categorical
from narwhals.utils import maybe_align_index
from narwhals.utils import maybe_convert_dtypes
from narwhals.utils import maybe_set_index
Expand All @@ -47,6 +49,7 @@
"concat",
"to_native",
"from_native",
"is_ordered_categorical",
"maybe_align_index",
"maybe_convert_dtypes",
"maybe_set_index",
Expand Down Expand Up @@ -78,6 +81,7 @@
"Object",
"Unknown",
"Categorical",
"Enum",
"String",
"Datetime",
"Duration",
Expand Down
1 change: 1 addition & 0 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ArrowNamespace:
Object = dtypes.Object
Unknown = dtypes.Unknown
Categorical = dtypes.Categorical
Enum = dtypes.Enum
String = dtypes.String
Datetime = dtypes.Datetime
Duration = dtypes.Duration
Expand Down
6 changes: 4 additions & 2 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
return dtypes.Float64()
if pa.types.is_float32(dtype):
return dtypes.Float32()
if (
# bug in coverage? it shows `31->exit` (where `31` is currently the line number of
# the next line), even though both when the if condition is true and false are covered
if ( # pragma: no cover
pa.types.is_string(dtype)
or pa.types.is_large_string(dtype)
or pa.types.is_string_view(dtype)
or getattr(pa.types, "is_string_view", lambda _: False)(dtype)
):
return dtypes.String()
if pa.types.is_date32(dtype):
Expand Down
1 change: 1 addition & 0 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class PandasNamespace:
Object = dtypes.Object
Unknown = dtypes.Unknown
Categorical = dtypes.Categorical
Enum = dtypes.Enum
String = dtypes.String
Datetime = dtypes.Datetime
Duration = dtypes.Duration
Expand Down
11 changes: 11 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def is_numeric(cls: type[Self]) -> bool:
def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
return isinstance_or_issubclass(other, type(self))

def __hash__(self) -> int:
return hash(self.__class__)


class NumericType(DType): ...

Expand Down Expand Up @@ -79,6 +82,9 @@ class Duration(TemporalType): ...
class Categorical(DType): ...


class Enum(DType): ...


class Date(TemporalType): ...


Expand Down Expand Up @@ -117,6 +123,9 @@ def translate_dtype(plx: Any, dtype: DType) -> Any:
return plx.Boolean
if dtype == Categorical:
return plx.Categorical
if dtype == Enum:
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
if dtype == Datetime:
return plx.Datetime
if dtype == Duration:
Expand Down Expand Up @@ -160,6 +169,8 @@ def to_narwhals_dtype(dtype: Any, *, is_polars: bool) -> DType:
return Object()
if dtype == pl.Categorical:
return Categorical()
if dtype == pl.Enum:
return Enum()
if dtype == pl.Datetime:
return Datetime()
if dtype == pl.Duration:
Expand Down
65 changes: 65 additions & 0 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@
from typing import TypeVar
from typing import cast

from narwhals import dtypes
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_polars
from narwhals.dependencies import get_pyarrow
from narwhals.translate import to_native

if TYPE_CHECKING:
from narwhals.dataframe import BaseFrame
Expand Down Expand Up @@ -257,3 +262,63 @@ def maybe_convert_dtypes(df: T, *args: bool, **kwargs: bool | str) -> T:
)
)
return df


def is_ordered_categorical(series: Series) -> bool:
"""
Return whether indices of categories are semantically meaningful.
This is a convenience function to accessing what would otherwise be
the `is_ordered` property from the DataFrame Interchange Protocol,
see https://data-apis.org/dataframe-protocol/latest/API.html.
- For Polars:
- Enums are always ordered.
- Categoricals are ordered if `dtype.ordering == "physical"`.
- For pandas-like APIs:
- Categoricals are ordered if `dtype.cat.ordered == True`.
- For PyArrow table:
- Categoricals are ordered if `dtype.type.ordered == True`.
Examples:
>>> import narwhals as nw
>>> import pandas as pd
>>> import polars as pl
>>> data = ["x", "y"]
>>> s_pd = pd.Series(data, dtype=pd.CategoricalDtype(ordered=True))
>>> s_pl = pl.Series(data, dtype=pl.Categorical(ordering="physical"))
Let's define a library-agnostic function:
>>> @nw.narwhalify
... def func(s):
... return nw.is_ordered_categorical(s)
Then, we can pass any supported library to `func`:
>>> func(s_pd)
True
>>> func(s_pl)
True
"""
if series.dtype == dtypes.Enum:
return True
if series.dtype != dtypes.Categorical:
return False
native_series = to_native(series)
if (pl := get_polars()) is not None and isinstance(native_series, pl.Series):
return native_series.dtype.ordering == "physical" # type: ignore[no-any-return]
if (pd := get_pandas()) is not None and isinstance(native_series, pd.Series):
return native_series.cat.ordered # type: ignore[no-any-return]
if (mpd := get_modin()) is not None and isinstance(
native_series, mpd.Series
): # pragma: no cover
return native_series.cat.ordered # type: ignore[no-any-return]
if (cudf := get_cudf()) is not None and isinstance(
native_series, cudf.Series
): # pragma: no cover
return native_series.cat.ordered # type: ignore[no-any-return]
if (pa := get_pyarrow()) is not None and isinstance(native_series, pa.ChunkedArray):
return native_series.type.ordered # type: ignore[no-any-return]
# If it doesn't match any of the above, let's just play it safe and return False.
return False # pragma: no cover
10 changes: 10 additions & 0 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_dtypes() -> None:
"o": [datetime(2020, 1, 1)],
"p": ["a"],
"q": [timedelta(1)],
"r": ["a"],
},
schema={
"a": pl.Int64,
Expand All @@ -85,6 +86,7 @@ def test_dtypes() -> None:
"o": pl.Datetime,
"p": pl.Categorical,
"q": pl.Duration,
"r": pl.Enum(["a", "b"]),
},
)
df = nw.DataFrame(df_pl)
Expand All @@ -107,9 +109,13 @@ def test_dtypes() -> None:
"o": nw.Datetime,
"p": nw.Categorical,
"q": nw.Duration,
"r": nw.Enum,
}
assert result == expected
assert {name: df[name].dtype for name in df.columns} == expected

# pandas/pyarrow only have categorical, not enum
expected["r"] = nw.Categorical
df_pd = df_pl.to_pandas(use_pyarrow_extension_array=True)
df = nw.DataFrame(df_pd)
result_pd = df.schema
Expand All @@ -130,3 +136,7 @@ def test_unknown_dtype() -> None:
def test_unknown_dtype_polars() -> None:
df = pl.DataFrame({"a": [[1, 2, 3]]})
assert nw.from_native(df).schema == {"a": nw.Unknown}


def test_hash() -> None:
assert nw.Int64() in {nw.Int64, nw.Int32}
16 changes: 16 additions & 0 deletions tests/series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,19 @@ def test_cast_date_datetime_invalid() -> None:
def test_unknown_to_int() -> None:
df = pd.DataFrame({"a": pd.period_range("2000", periods=3, freq="M")})
assert nw.from_native(df).select(nw.col("a").cast(nw.Int64)).schema == {"a": nw.Int64}


def test_cast_to_enum() -> None:
# we don't yet support metadata in dtypes, so for now disallow this
# seems like a very niche use case anyway, and allowing it later wouldn't be
# backwards-incompatible
df = pl.DataFrame({"a": ["a", "b"]}, schema={"a": pl.Categorical})
with pytest.raises(
NotImplementedError, match=r"Converting to Enum is not \(yet\) supported"
):
nw.from_native(df).select(nw.col("a").cast(nw.Enum))
df = pd.DataFrame({"a": ["a", "b"]}, dtype="category")
with pytest.raises(
NotImplementedError, match=r"Converting to Enum is not \(yet\) supported"
):
nw.from_native(df).select(nw.col("a").cast(nw.Enum))
41 changes: 41 additions & 0 deletions tests/series/is_ordered_categorical_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals as nw


def test_is_ordered_categorical() -> None:
s = pl.Series(["a", "b"], dtype=pl.Categorical)
assert nw.is_ordered_categorical(nw.from_native(s, series_only=True))
s = pl.Series(["a", "b"], dtype=pl.Categorical(ordering="lexical"))
assert not nw.is_ordered_categorical(nw.from_native(s, series_only=True))
s = pl.Series(["a", "b"], dtype=pl.Enum(["a", "b"]))
assert nw.is_ordered_categorical(nw.from_native(s, series_only=True))
s = pd.Series(["a", "b"], dtype=pd.CategoricalDtype(ordered=True))
assert nw.is_ordered_categorical(nw.from_native(s, series_only=True))
s = pd.Series(["a", "b"], dtype=pd.CategoricalDtype(ordered=False))
assert not nw.is_ordered_categorical(nw.from_native(s, series_only=True))
s = pa.chunked_array(
[pa.array(["a", "b"], type=pa.dictionary(pa.int32(), pa.string()))]
)
assert not nw.is_ordered_categorical(nw.from_native(s, series_only=True))


def test_is_definitely_not_ordered_categorical(
constructor_series_with_pyarrow: Any,
) -> None:
assert not nw.is_ordered_categorical(
nw.from_native(constructor_series_with_pyarrow([1, 2, 3]), series_only=True)
)


@pytest.mark.xfail(reason="https://github.com/apache/arrow/issues/41017")
def test_is_ordered_categorical_pyarrow() -> None:
s = pa.chunked_array(
[pa.array(["a", "b"], type=pa.dictionary(pa.int32(), pa.string(), ordered=True))]
)
assert nw.is_ordered_categorical(nw.from_native(s, series_only=True))

0 comments on commit c1f1a01

Please sign in to comment.