Skip to content

Commit

Permalink
sort out dataframe.to_numpy too
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Nov 2, 2024
1 parent 434bd5e commit 75badf8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
18 changes: 16 additions & 2 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,18 +701,32 @@ def to_numpy(self, dtype: Any = None, copy: bool | None = None) -> Any:
if dtype is not None:
return self._native_frame.to_numpy(dtype=dtype, copy=copy)

convert = set()
for key, val in self.schema.items():
if val == self._dtypes.Datetime and val.time_zone is not None: # type: ignore[attr-defined]
convert.add(key)
df = self.with_columns(
*[
self.__narwhals_namespace__()
.col(x)
.dt.convert_time_zone("UTC")
.dt.replace_time_zone(None)
for x in convert
]
)._native_frame

# pandas return `object` dtype for nullable dtypes if dtype=None,
# so we cast each Series to numpy and let numpy find a common dtype.
# If there aren't any dtypes where `to_numpy()` is "broken" (i.e. it
# returns Object) then we just call `to_numpy()` on the DataFrame.
for col_dtype in self._native_frame.dtypes:
for col_dtype in df.dtypes:
if str(col_dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING:
import numpy as np # ignore-banned-import

return np.hstack(
[self[col].to_numpy(copy=copy)[:, None] for col in self.columns]
)
return self._native_frame.to_numpy(copy=copy)
return df.to_numpy(copy=copy)

def to_pandas(self) -> Any:
if self._implementation is Implementation.PANDAS:
Expand Down
19 changes: 19 additions & 0 deletions tests/frame/to_numpy_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -18,3 +19,21 @@ def test_to_numpy(constructor_eager: ConstructorEager) -> None:
expected = np.array([[1, 3, 2], [4, 4, 6], [7.1, 8, 9]]).T
np.testing.assert_array_equal(result, expected)
assert result.dtype == "float64"


def test_to_numpy_tz_aware(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(
constructor_eager({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}),
eager_only=True,
)
df = df.select(nw.col("a").dt.replace_time_zone("Asia/Kathmandu"))
result = df.to_numpy()
# for some reason, NumPy uses 'M' for datetimes
assert result.dtype.kind == "M"
assert (
result
== np.array(
[["2019-12-31T18:15:00.000000"], ["2020-01-01T18:15:00.000000"]],
dtype=result.dtype,
)
).all()
7 changes: 6 additions & 1 deletion tests/series_only/to_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from numpy.testing import assert_array_equal

import narwhals.stable.v1 as nw
from tests.utils import PYARROW_VERSION

if TYPE_CHECKING:
from tests.utils import ConstructorEager
Expand All @@ -33,7 +34,11 @@ def test_to_numpy(
assert_array_equal(s.to_numpy(), np.array(data, dtype=float))


def test_to_numpy_tz_aware(constructor_eager: ConstructorEager) -> None:
def test_to_numpy_tz_aware(
constructor_eager: ConstructorEager, request: pytest.FixtureRequest
) -> None:
if "pyarrow_table" in str(constructor_eager) and PYARROW_VERSION < (12,):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(
constructor_eager({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}),
eager_only=True,
Expand Down

0 comments on commit 75badf8

Please sign in to comment.