Skip to content

Commit

Permalink
chore: add constructor fixture for tests (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 1, 2024
1 parent 1e16c17 commit b5c3c51
Show file tree
Hide file tree
Showing 41 changed files with 192 additions and 218 deletions.
37 changes: 26 additions & 11 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal
from typing import Sequence

from narwhals._pandas_like.utils import int_dtype_mapper
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import reverse_translate_dtype
from narwhals._pandas_like.utils import to_datetime
Expand Down Expand Up @@ -609,14 +610,23 @@ def second(self) -> PandasSeries:
)

def millisecond(self) -> PandasSeries:
if "pyarrow" in str(self._series._series.dtype):
msg = ".dt.millisecond not implemented for pyarrow-backed pandas"
raise NotImplementedError(msg)
return self._series._from_series(
self._series._series.dt.microsecond // 1000,
)

def microsecond(self) -> PandasSeries:
if "pyarrow" in str(self._series._series.dtype):
msg = ".dt.microsecond not implemented for pyarrow-backed pandas"
raise NotImplementedError(msg)
return self._series._from_series(self._series._series.dt.microsecond)

def nanosecond(self) -> PandasSeries:
if "pyarrow" in str(self._series._series.dtype):
msg = ".dt.nanosecond not implemented for pyarrow-backed pandas"
raise NotImplementedError(msg)
return self._series._from_series(
(
(self._series._series.dt.microsecond * 1_000)
Expand All @@ -639,54 +649,59 @@ def ordinal_day(self) -> PandasSeries:
def total_minutes(self) -> PandasSeries:
s = self._series._series.dt.total_seconds()
s_sign = (
2 * (s > 0).astype(int) - 1
2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
) # this calculates the sign of each series element
s_abs = s.abs() // 60
if ~s.isna().any():
s_abs = s_abs.astype(int)
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self._series._from_series(s_abs * s_sign)

def total_seconds(self) -> PandasSeries:
s = self._series._series.dt.total_seconds()
s_sign = (
2 * (s > 0).astype(int) - 1
2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
) # this calculates the sign of each series element
s_abs = s.abs() // 1
if ~s.isna().any():
s_abs = s_abs.astype(int)
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self._series._from_series(s_abs * s_sign)

def total_milliseconds(self) -> PandasSeries:
s = self._series._series.dt.total_seconds() * 1e3
s_sign = (
2 * (s > 0).astype(int) - 1
2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
) # this calculates the sign of each series element
s_abs = s.abs() // 1
if ~s.isna().any():
s_abs = s_abs.astype(int)
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self._series._from_series(s_abs * s_sign)

def total_microseconds(self) -> PandasSeries:
s = self._series._series.dt.total_seconds() * 1e6
s_sign = (
2 * (s > 0).astype(int) - 1
2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
) # this calculates the sign of each series element
s_abs = s.abs() // 1
if ~s.isna().any():
s_abs = s_abs.astype(int)
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self._series._from_series(s_abs * s_sign)

def total_nanoseconds(self) -> PandasSeries:
s = self._series._series.dt.total_seconds() * 1e9
s_sign = (
2 * (s > 0).astype(int) - 1
2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
) # this calculates the sign of each series element
s_abs = s.abs() // 1
if ~s.isna().any():
s_abs = s_abs.astype(int)
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self._series._from_series(s_abs * s_sign)

def to_string(self, format: str) -> PandasSeries: # noqa: A002
# Polars' parser treats `'%.f'` as pandas does `'.%f'`
format = format.replace("%.f", ".%f")
# PyArrow interprets `'%S'` as "seconds, plus fractional seconds"
# and doesn't support `%f`
if "pyarrow" not in str(self._series._series.dtype):
format = format.replace("%S%.f", "%S.%f")
else:
format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
return self._series._from_series(self._series._series.dt.strftime(format))
8 changes: 8 additions & 0 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,11 @@ def to_datetime(implementation: str) -> Any:
if implementation == "cudf":
return get_cudf().to_datetime
raise AssertionError


def int_dtype_mapper(dtype: Any) -> str:
if "pyarrow" in str(dtype):
return "Int64[pyarrow]"
if str(dtype).lower() != str(dtype): # pragma: no cover
return "Int64"
return "int64"
10 changes: 4 additions & 6 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def drop_nulls(self) -> Self:
>>> import polars as pl
>>> import pandas as pd
>>> import narwhals as nw
>>> data = {"a": [1.0, 2.0, None], "ba": [1, None, 2.0]}
>>> data = {"a": [1.0, 2.0, None], "ba": [1.0, None, 2.0]}
>>> df_pd = pd.DataFrame(data)
>>> df_pl = pl.DataFrame(data)
Expand Down Expand Up @@ -1828,9 +1828,7 @@ def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Se
>>> def func(df_any):
... df = nw.from_native(df_any)
... df = df.pipe(
... lambda _df: _df.select([x for x in _df.columns if len(x) == 1])
... )
... df = df.pipe(lambda _df: _df.select("a"))
... return nw.to_native(df)
We can then pass either pandas or Polars:
Expand Down Expand Up @@ -1866,7 +1864,7 @@ def drop_nulls(self) -> Self:
>>> import polars as pl
>>> import pandas as pd
>>> import narwhals as nw
>>> data = {"a": [1.0, 2.0, None], "ba": [1, None, 2.0]}
>>> data = {"a": [1.0, 2.0, None], "ba": [1.0, None, 2.0]}
>>> df_pd = pd.DataFrame(data)
>>> df_pl = pl.LazyFrame(data)
Expand Down Expand Up @@ -1972,7 +1970,7 @@ def columns(self) -> list[str]:
... }
... ).select("foo", "bar")
>>> lf = nw.from_native(lf_pl)
>>> lf.columns
>>> lf.columns # doctest: +SKIP
['foo', 'bar']
"""
return super().columns
Expand Down
2 changes: 1 addition & 1 deletion narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2787,7 +2787,7 @@ def to_string(self, format: str) -> Expr: # noqa: A002
Therefore, we make the following adjustments:
- for pandas-like libraries, we replace `".%f"` with `"%.f"`.
- for pandas-like libraries, we replace `"%S.%f"` with `"%S%.f"`.
- for PyArrow, we replace `"%S.%f"` with `"%S"`.
Workarounds like these don't make us happy, and we try to avoid them as
Expand Down
2 changes: 1 addition & 1 deletion narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2775,7 +2775,7 @@ def to_string(self, format: str) -> Series: # noqa: A002
Therefore, we make the following adjustments:
- for pandas-like libraries, we replace `".%f"` with `"%.f"`.
- for pandas-like libraries, we replace `"%S.%f"` with `"%S%.f"`.
- for PyArrow, we replace `"%S.%f"` with `"%S"`.
Workarounds like these don't make us happy, and we try to avoid them as
Expand Down
50 changes: 50 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import os
from typing import Any
from typing import Callable

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

from narwhals.dependencies import get_modin
from narwhals.typing import IntoDataFrame
from narwhals.utils import parse_version


def pytest_addoption(parser: Any) -> None:
parser.addoption(
Expand All @@ -21,3 +30,44 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> Any: # pragma: no
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)


def pandas_constructor(obj: Any) -> IntoDataFrame:
return pd.DataFrame(obj) # type: ignore[no-any-return]


def pandas_nullable_constructor(obj: Any) -> IntoDataFrame:
return pd.DataFrame(obj).convert_dtypes() # type: ignore[no-any-return]


def pandas_pyarrow_constructor(obj: Any) -> IntoDataFrame:
return pd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return]


def modin_constructor(obj: Any) -> IntoDataFrame: # pragma: no cover
return pd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return]


def polars_constructor(obj: Any) -> IntoDataFrame:
return pl.DataFrame(obj)


if parse_version(pd.__version__) >= parse_version("1.5.0"):
params = [pandas_constructor, pandas_nullable_constructor, pandas_pyarrow_constructor]
else: # pragma: no cover
params = [pandas_constructor]
params.append(polars_constructor)
if os.environ.get("CI") and get_modin() is not None: # pragma: no cover
params.append(modin_constructor)


@pytest.fixture(params=params)
def constructor(request: Any) -> Callable[[Any], IntoDataFrame]:
return request.param # type: ignore[no-any-return]


# TODO: once pyarrow has complete coverage, we can remove this one,
# and just put `pa.table` into `constructor`
@pytest.fixture(params=[*params, pa.table])
def constructor_with_pyarrow(request: Any) -> Callable[[Any], IntoDataFrame]:
return request.param # type: ignore[no-any-return]
22 changes: 22 additions & 0 deletions tests/expr/any_all_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any

import narwhals as nw
from tests.utils import compare_dicts


def test_any_all(constructor: Any) -> None:
df = nw.from_native(
constructor(
{
"a": [True, False, True],
"b": [True, True, True],
"c": [False, False, False],
}
)
)
result = nw.to_native(df.select(nw.all().all()))
expected = {"a": [False], "b": [True], "c": [False]}
compare_dicts(result, expected)
result = nw.to_native(df.select(nw.all().any()))
expected = {"a": [True], "b": [True], "c": [False]}
compare_dicts(result, expected)
5 changes: 0 additions & 5 deletions tests/expr/cat/get_categories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@

from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
from tests.utils import compare_dicts

data = {"a": ["one", "two", "two"]}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
def test_get_categories(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
df = df.select(nw.col("a").cast(nw.Categorical))
Expand Down
10 changes: 2 additions & 8 deletions tests/expr/cum_sum_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from typing import Any

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals as nw
from tests.utils import compare_dicts

Expand All @@ -15,9 +10,8 @@
}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame, pa.table])
def test_cum_sum_simple(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
def test_cum_sum_simple(constructor_with_pyarrow: Any) -> None:
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True)
result = df.select(nw.all().cum_sum())
expected = {
"a": [0, 1, 3, 6, 10],
Expand Down
5 changes: 0 additions & 5 deletions tests/expr/diff_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
from tests.utils import compare_dicts

Expand All @@ -14,7 +10,6 @@
}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
def test_over_single(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
result = df.with_columns(c_diff=nw.col("c").diff()).filter(nw.col("i") > 0)
Expand Down
5 changes: 0 additions & 5 deletions tests/expr/fill_null_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
from tests.utils import compare_dicts

Expand All @@ -14,7 +10,6 @@
}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
def test_over_single(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
result = df.with_columns(nw.all().fill_null(99))
Expand Down
5 changes: 0 additions & 5 deletions tests/expr/filter_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
from tests.utils import compare_dicts

Expand All @@ -15,7 +11,6 @@
}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
def test_filter(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
result = df.select(nw.col("a").filter(nw.col("i") < 2, nw.col("c") == 5))
Expand Down
3 changes: 0 additions & 3 deletions tests/expr/is_between_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
Expand All @@ -14,7 +12,6 @@
}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
@pytest.mark.parametrize(
("closed", "expected"),
[
Expand Down
5 changes: 0 additions & 5 deletions tests/expr/is_duplicated_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
from tests.utils import compare_dicts

Expand All @@ -13,7 +9,6 @@
}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
def test_is_duplicated(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)
result = df.select(nw.all().is_duplicated())
Expand Down
Loading

0 comments on commit b5c3c51

Please sign in to comment.