Skip to content

Commit

Permalink
selector refactor and rest
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jul 13, 2024
1 parent 2096437 commit 280d4c6
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 38 deletions.
50 changes: 24 additions & 26 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,40 +48,43 @@ def modin_constructor(obj: Any) -> IntoDataFrame: # pragma: no cover
return mpd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return]


def polars_constructor(obj: Any) -> IntoDataFrame:
def polars_eager_constructor(obj: Any) -> IntoDataFrame:
return pl.DataFrame(obj)


def polars_lazy_constructor(obj: Any) -> pl.LazyFrame:
return pl.LazyFrame(obj)


def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
return pa.table(obj) # type: ignore[no-any-return]


if parse_version(pd.__version__) >= parse_version("2.0.0"):
params = [pandas_constructor, pandas_nullable_constructor, pandas_pyarrow_constructor]
eager_constructors = [
pandas_constructor,
pandas_nullable_constructor,
pandas_pyarrow_constructor,
]
else: # pragma: no cover
params = [pandas_constructor]
params.append(polars_constructor)
eager_constructors = [pandas_constructor]

eager_constructors.extend([polars_eager_constructor, pyarrow_table_constructor])

if get_modin() is not None: # pragma: no cover
params.append(modin_constructor)
eager_constructors.append(modin_constructor)


@pytest.fixture(params=params)
@pytest.fixture(params=eager_constructors)
def constructor(request: Any) -> Callable[[Any], IntoDataFrame]:
return request.param # type: ignore[no-any-return]


@pytest.fixture(params=[*params, polars_lazy_constructor])
@pytest.fixture(params=[*eager_constructors, polars_lazy_constructor])
def constructor_with_lazy(request: Any) -> Callable[[Any], Any]:
return request.param # type: ignore[no-any-return]


# TODO(Unassigned): once pyarrow has complete coverage, we can remove this one,
# and just put `pa.table` into `constructor`
@pytest.fixture(params=[*params, pa.table])
def constructor_with_pyarrow(request: Any) -> Callable[[Any], IntoDataFrame]:
return request.param # type: ignore[no-any-return]


def pandas_series_constructor(obj: Any) -> Any:
return pd.Series(obj)

Expand All @@ -103,6 +106,10 @@ def polars_series_constructor(obj: Any) -> Any:
return pl.Series(obj)


def pyarrow_series_constructor(obj: Any) -> Any:
return pa.chunked_array([obj])


if parse_version(pd.__version__) >= parse_version("2.0.0"):
params_series = [
pandas_series_constructor,
Expand All @@ -111,22 +118,13 @@ def polars_series_constructor(obj: Any) -> Any:
]
else: # pragma: no cover
params_series = [pandas_series_constructor]
params_series.append(polars_series_constructor)
if get_modin() is not None: # pragma: no cover
params_series.append(modin_series_constructor)


@pytest.fixture(params=params_series)
def constructor_series(request: Any) -> Callable[[Any], Any]:
return request.param # type: ignore[no-any-return]
params_series.extend([polars_series_constructor, pyarrow_series_constructor])


def pyarrow_chunked_array_constructor(obj: Any) -> Any:
return pa.chunked_array([obj])


# TODO(Unassigned): once pyarrow has complete coverage, we can remove this one,
# and just put `pa.table` into `constructor`
@pytest.fixture(params=[*params_series, pyarrow_chunked_array_constructor])
def constructor_series_with_pyarrow(request: Any) -> Callable[[Any], Any]:
@pytest.fixture(params=params_series)
def constructor_series(request: Any) -> Callable[[Any], Any]:
return request.param # type: ignore[no-any-return]
3 changes: 0 additions & 3 deletions tests/expr/sum_horizontal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

@pytest.mark.parametrize("col_expr", [nw.col("a"), "a"])
def test_sumh(constructor: Any, col_expr: Any) -> None:
if "pyarrow_table" in str(constructor):
pytest.xfail()

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df = nw.from_native(constructor(data), eager_only=True)
result = df.with_columns(horizontal_sum=nw.sum_horizontal(col_expr, nw.col("b")))
Expand Down
10 changes: 6 additions & 4 deletions tests/frame/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def test_std(constructor: Any) -> None:

# TODO(Unassigned): https://github.com/narwhals-dev/narwhals/issues/313
@pytest.mark.filterwarnings("ignore:Determining|Resolving.*")
def test_schema(constructor: Any) -> None:
df = nw.from_native(constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]}))
def test_schema(constructor_with_lazy: Any) -> None:
df = nw.from_native(
constructor_with_lazy({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]})
)
result = df.schema
expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64}
assert result == expected
Expand All @@ -64,8 +66,8 @@ def test_schema(constructor: Any) -> None:

# TODO(Unassigned): https://github.com/narwhals-dev/narwhals/issues/313
@pytest.mark.filterwarnings("ignore:Determining|Resolving.*")
def test_columns(constructor: Any) -> None:
df = nw.from_native(constructor(data))
def test_columns(constructor_with_lazy: Any) -> None:
df = nw.from_native(constructor_with_lazy(data))
result = df.columns
expected = ["a", "b", "z"]
assert result == expected
Expand Down
4 changes: 4 additions & 0 deletions tests/stable_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def test_renamed_taxicab_norm(constructor: Any) -> None:
# Here, we check that anyone who wrote code using the old
# API will still be able to use it, without the main namespace
# getting cluttered by the new name.

if "pyarrow_table" in str(constructor):
pytest.xfail()

df = nw.from_native(constructor({"a": [1, 2, 3, -4, 5]}))
result = df.with_columns(b=nw.col("a")._taxicab_norm())
expected = {"a": [1, 2, 3, -4, 5], "b": [15] * 5}
Expand Down
6 changes: 6 additions & 0 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def test_invalid_group_by() -> None:


def test_group_by_iter(constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
pytest.xfail()

df = nw.from_native(constructor(data), eager_only=True)
expected_keys = [(1,), (3,)]
keys = []
Expand All @@ -65,6 +68,9 @@ def test_group_by_iter(constructor: Any) -> None:


def test_group_by_len(constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
pytest.xfail()

result = (
nw.from_native(constructor(data)).group_by("a").agg(nw.col("b").len()).sort("a")
)
Expand Down
17 changes: 16 additions & 1 deletion tests/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,40 @@
}


def test_selecctors(constructor: Any) -> None:
def test_selectors(constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
pytest.xfail()

df = nw.from_native(constructor(data))
result = nw.to_native(df.select(by_dtype([nw.Int64, nw.Float64]) + 1))
expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]}
compare_dicts(result, expected)


def test_numeric(constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
pytest.xfail()

df = nw.from_native(constructor(data))
result = nw.to_native(df.select(numeric() + 1))
expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]}
compare_dicts(result, expected)


def test_boolean(constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
pytest.xfail()

df = nw.from_native(constructor(data))
result = nw.to_native(df.select(boolean()))
expected = {"d": [True, False, True]}
compare_dicts(result, expected)


def test_string(constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
pytest.xfail()

df = nw.from_native(constructor(data))
result = nw.to_native(df.select(string()))
expected = {"b": ["a", "b", "c"]}
Expand Down Expand Up @@ -79,6 +91,9 @@ def test_categorical() -> None:
def test_set_ops(
constructor: Any, selector: nw.selectors.Selector, expected: list[str]
) -> None:
if "pyarrow_table" in str(constructor):
pytest.xfail()

df = nw.from_native(constructor(data))
result = df.select(selector).columns
assert sorted(result) == expected
Expand Down
6 changes: 2 additions & 4 deletions tests/translate/to_native_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
),
],
)
def test_to_native(
constructor_with_pyarrow: Any, method: str, strict: Any, context: Any
) -> None:
df = nw.from_native(constructor_with_pyarrow({"a": [1, 2, 3]}))
def test_to_native(constructor: Any, method: str, strict: Any, context: Any) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3]}))

with context:
nw.to_native(getattr(df, method)(), strict=strict)
Expand Down

0 comments on commit 280d4c6

Please sign in to comment.