Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: consistent to_numpy behaviour for tz-aware #1305

Merged
merged 12 commits into from
Nov 2, 2024
21 changes: 18 additions & 3 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,21 +698,36 @@ def to_numpy(self, dtype: Any = None, copy: bool | None = None) -> Any:
# pandas default differs from Polars, but cuDF default is True
copy = self._implementation is Implementation.CUDF

to_convert = [
key
for key, val in self.schema.items()
if val == self._dtypes.Datetime and val.time_zone is not None # type: ignore[attr-defined]
]
if to_convert:
df = self.with_columns(
self.__narwhals_namespace__()
.col(*to_convert)
.dt.convert_time_zone("UTC")
.dt.replace_time_zone(None)
)._native_frame
else:
df = self._native_frame

if dtype is not None:
return self._native_frame.to_numpy(dtype=dtype, copy=copy)
return df.to_numpy(dtype=dtype, copy=copy)

# pandas returns `object` dtype for nullable dtypes if dtype=None,
# so we cast each Series to numpy and let numpy find a common dtype.
# If there aren't any dtypes where `to_numpy()` is "broken" (i.e. it
# returns Object) then we just call `to_numpy()` on the DataFrame.
for col_dtype in self._native_frame.dtypes:
for col_dtype in df.dtypes:
if str(col_dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING:
import numpy as np # ignore-banned-import

return np.hstack(
[self[col].to_numpy(copy=copy)[:, None] for col in self.columns]
)
return self._native_frame.to_numpy(copy=copy)
return df.to_numpy(copy=copy)

def to_pandas(self) -> Any:
if self._implementation is Implementation.PANDAS:
Expand Down
28 changes: 12 additions & 16 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,34 +511,30 @@ def to_numpy(self, dtype: Any = None, copy: bool | None = None) -> Any:
# the default is meant to be None, but pandas doesn't allow it?
# https://numpy.org/doc/stable/reference/generated/numpy.ndarray.__array__.html
copy = copy or self._implementation is Implementation.CUDF
if self.dtype == self._dtypes.Datetime and self.dtype.time_zone is not None: # type: ignore[attr-defined]
s = self.dt.convert_time_zone("UTC").dt.replace_time_zone(None)._native_series
else:
s = self._native_series

has_missing = self._native_series.isna().any()
if (
has_missing
and str(self._native_series.dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING
):
has_missing = s.isna().any()
if has_missing and str(s.dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING:
if self._implementation is Implementation.PANDAS and self._backend_version < (
1,
): # pragma: no cover
kwargs = {}
else:
kwargs = {"na_value": float("nan")}
return self._native_series.to_numpy(
dtype=dtype
or PANDAS_TO_NUMPY_DTYPE_MISSING[str(self._native_series.dtype)],
return s.to_numpy(
dtype=dtype or PANDAS_TO_NUMPY_DTYPE_MISSING[str(s.dtype)],
copy=copy,
**kwargs,
)
if (
not has_missing
and str(self._native_series.dtype) in PANDAS_TO_NUMPY_DTYPE_NO_MISSING
):
return self._native_series.to_numpy(
dtype=dtype
or PANDAS_TO_NUMPY_DTYPE_NO_MISSING[str(self._native_series.dtype)],
if not has_missing and str(s.dtype) in PANDAS_TO_NUMPY_DTYPE_NO_MISSING:
return s.to_numpy(
dtype=dtype or PANDAS_TO_NUMPY_DTYPE_NO_MISSING[str(s.dtype)],
copy=copy,
)
return self._native_series.to_numpy(dtype=dtype, copy=copy)
return s.to_numpy(dtype=dtype, copy=copy)

def to_pandas(self) -> Any:
if self._implementation is Implementation.PANDAS:
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

import pandas as pd
import polars as pl
Expand All @@ -19,6 +18,7 @@
from narwhals.typing import IntoDataFrame
from narwhals.typing import IntoFrame
from tests.utils import Constructor
from tests.utils import ConstructorEager

with contextlib.suppress(ImportError):
import modin.pandas # noqa: F401
Expand Down Expand Up @@ -117,7 +117,7 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
@pytest.fixture(params=eager_constructors)
def constructor_eager(
request: pytest.FixtureRequest,
) -> Callable[[Any], IntoDataFrame]:
) -> ConstructorEager:
return request.param # type: ignore[no-any-return]


Expand Down
34 changes: 34 additions & 0 deletions tests/frame/to_numpy_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import numpy as np
import pytest

import narwhals.stable.v1 as nw
from tests.utils import PANDAS_VERSION
from tests.utils import PYARROW_VERSION
from tests.utils import is_windows

if TYPE_CHECKING:
from tests.utils import ConstructorEager
Expand All @@ -18,3 +23,32 @@ def test_to_numpy(constructor_eager: ConstructorEager) -> None:
expected = np.array([[1, 3, 2], [4, 4, 6], [7.1, 8, 9]]).T
np.testing.assert_array_equal(result, expected)
assert result.dtype == "float64"


def test_to_numpy_tz_aware(
    constructor_eager: ConstructorEager, request: pytest.FixtureRequest
) -> None:
    """Check `DataFrame.to_numpy` on a frame holding a tz-aware datetime column.

    Two naive datetimes are assigned the "Asia/Kathmandu" time zone, the
    frame is converted to NumPy, and the result is compared against the
    expected datetime values listed below.
    """
    backend = str(constructor_eager)

    # Backend / version / platform combinations on which this test is
    # marked as an expected failure.
    pyarrow_too_old = "pyarrow_table" in backend and PYARROW_VERSION < (12,)
    pandas_too_old = "pandas_pyarrow" in backend and PANDAS_VERSION < (2, 2)
    if (
        pyarrow_too_old
        or pandas_too_old
        or (any(name in backend for name in ("pyarrow", "modin")) and is_windows())
    ):
        request.applymarker(pytest.mark.xfail)

    naive_data = {"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}
    frame = nw.from_native(constructor_eager(naive_data), eager_only=True)
    frame = frame.select(nw.col("a").dt.replace_time_zone("Asia/Kathmandu"))
    result = frame.to_numpy()

    # "M" is NumPy's dtype kind code for datetime64.
    assert result.dtype.kind == "M"

    # One row per input value, single column: the frame result is 2-D.
    expected = np.array(
        [["2019-12-31T18:15:00.000000"], ["2020-01-01T18:15:00.000000"]],
        dtype=result.dtype,
    )
    assert (result == expected).all()
34 changes: 34 additions & 0 deletions tests/series_only/to_numpy_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import numpy as np
import pytest
from numpy.testing import assert_array_equal

import narwhals.stable.v1 as nw
from tests.utils import PANDAS_VERSION
from tests.utils import PYARROW_VERSION
from tests.utils import is_windows

if TYPE_CHECKING:
from tests.utils import ConstructorEager
Expand All @@ -30,3 +34,33 @@ def test_to_numpy(
assert s.shape == (3,)

assert_array_equal(s.to_numpy(), np.array(data, dtype=float))


def test_to_numpy_tz_aware(
    constructor_eager: ConstructorEager, request: pytest.FixtureRequest
) -> None:
    """Check `Series.to_numpy` on a tz-aware datetime column.

    Two naive datetimes are assigned the "Asia/Kathmandu" time zone, the
    column is converted to NumPy, and the result is compared against the
    expected datetime values listed below.
    """
    # Backend / version / platform combinations on which this test is
    # marked as an expected failure.
    if (
        ("pyarrow_table" in str(constructor_eager) and PYARROW_VERSION < (12,))
        or ("pandas_pyarrow" in str(constructor_eager) and PANDAS_VERSION < (2, 2))
        or (
            any(x in str(constructor_eager) for x in ("pyarrow", "modin"))
            and is_windows()
        )
    ):
        # Applying the marker once is enough; a second identical call adds nothing.
        request.applymarker(pytest.mark.xfail)
    df = nw.from_native(
        constructor_eager({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}),
        eager_only=True,
    )
    df = df.select(nw.col("a").dt.replace_time_zone("Asia/Kathmandu"))
    result = df["a"].to_numpy()
    # "M" is NumPy's dtype kind code for datetime64.
    assert result.dtype.kind == "M"
    assert (
        result
        == np.array(
            ["2019-12-31T18:15:00.000000", "2020-01-01T18:15:00.000000"],
            dtype=result.dtype,
        )
    ).all()
Loading