Skip to content

Commit

Permalink
Merge pull request #9 from raisadz/main
Browse files Browse the repository at this point in the history
Add tests for common dataframe operations
  • Loading branch information
MarcoGorelli authored Mar 16, 2024
2 parents 9ce6f21 + 23ec222 commit b44d749
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 8 deletions.
5 changes: 4 additions & 1 deletion narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable

from narwhals.dtypes import translate_dtype
from narwhals.utils import flatten

if TYPE_CHECKING:
from narwhals.typing import IntoExpr
Expand Down Expand Up @@ -184,7 +185,9 @@ def max(*columns: str) -> Expr:


def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
return Expr(lambda plx: plx.sum_horizontal(*exprs))
return Expr(
lambda plx: plx.sum_horizontal([extract_native(plx, v) for v in flatten(exprs)])
)


__all__ = [
Expand Down
4 changes: 2 additions & 2 deletions narwhals/pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def filter(
self,
*predicates: IntoPandasExpr | Iterable[IntoPandasExpr],
) -> Self:
from narwhals.pandas_like.namespace import Namespace
from narwhals.pandas_like.namespace import PandasNamespace

plx = Namespace(self._implementation)
plx = PandasNamespace(self._implementation)
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
Expand Down
2 changes: 1 addition & 1 deletion narwhals/pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from narwhals.pandas_like.typing import IntoPandasExpr


class Namespace:
class PandasNamespace:
Int64 = dtypes.Int64
Int32 = dtypes.Int32
Int16 = dtypes.Int16
Expand Down
11 changes: 7 additions & 4 deletions narwhals/pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ def parse_into_exprs(

def parse_into_expr(implementation: str, into_expr: IntoPandasExpr) -> PandasExpr:
from narwhals.expression import Expr
from narwhals.pandas_like.namespace import Namespace
from narwhals.pandas_like.expr import PandasExpr
from narwhals.pandas_like.namespace import PandasNamespace

plx = Namespace(implementation=implementation)
plx = PandasNamespace(implementation=implementation)

if isinstance(into_expr, PandasExpr):
return into_expr
if isinstance(into_expr, Expr):
return into_expr._call(plx)
if isinstance(into_expr, str):
Expand Down Expand Up @@ -141,10 +144,10 @@ def evaluate_into_exprs(

def register_expression_call(expr: ExprT, attr: str, *args: Any, **kwargs: Any) -> ExprT:
from narwhals.pandas_like.expr import PandasExpr
from narwhals.pandas_like.namespace import Namespace
from narwhals.pandas_like.namespace import PandasNamespace
from narwhals.pandas_like.series import PandasSeries

plx = Namespace(implementation=expr._implementation)
plx = PandasNamespace(implementation=expr._implementation)

def func(df: PandasDataFrame) -> list[PandasSeries]:
out: list[PandasSeries] = []
Expand Down
120 changes: 120 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from __future__ import annotations

from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
from tests.utils import compare_dicts

# One column mapping feeds both backends, so every parametrized test below
# compares the pandas and polars results against the same expected values.
_COLUMNS = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_pandas = pd.DataFrame(_COLUMNS)
df_polars = pl.DataFrame(_COLUMNS)


@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
def test_sort(df_raw: Any) -> None:
    """Sort by ``a`` then ``b`` and check that every column is reordered."""
    sorted_native = nw.to_native(nw.DataFrame(df_raw).sort("a", "b"))
    compare_dicts(
        sorted_native,
        {"a": [1, 2, 3], "b": [4, 6, 4], "z": [7.0, 9.0, 8.0]},
    )


@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
def test_filter(df_raw: Any) -> None:
    """Keep only the rows where ``a > 1``; original row order is expected."""
    predicate = nw.col("a") > 1
    filtered_native = nw.to_native(nw.DataFrame(df_raw).filter(predicate))
    compare_dicts(
        filtered_native,
        {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]},
    )


@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
def test_add(df_raw: Any) -> None:
    """Append ``c = a + b`` and ``d = a - mean(a)`` to the original columns."""
    column_sum = nw.col("a") + nw.col("b")
    centered_a = nw.col("a") - nw.col("a").mean()
    frame = nw.DataFrame(df_raw)
    augmented_native = nw.to_native(frame.with_columns(c=column_sum, d=centered_a))
    compare_dicts(
        augmented_native,
        {
            "a": [1, 3, 2],
            "b": [4, 4, 6],
            "z": [7.0, 8.0, 9.0],
            "c": [5, 7, 8],
            "d": [-1.0, 1.0, 0.0],
        },
    )


@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
def test_double(df_raw: Any) -> None:
    """Multiply every column by two via ``nw.all()``."""
    doubled_native = nw.to_native(nw.DataFrame(df_raw).with_columns(nw.all() * 2))
    compare_dicts(
        doubled_native,
        {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]},
    )


@pytest.mark.parametrize(
    "df_raw",
    [df_pandas, df_polars],
)
def test_sumh(df_raw: Any) -> None:
    """Add a row-wise sum of ``a`` and ``b`` built from two column expressions."""
    row_total = nw.sum_horizontal(nw.col("a"), nw.col("b"))
    frame = nw.DataFrame(df_raw)
    summed_native = nw.to_native(frame.with_columns(horizonal_sum=row_total))
    compare_dicts(
        summed_native,
        {
            "a": [1, 3, 2],
            "b": [4, 4, 6],
            "z": [7.0, 8.0, 9.0],
            "horizonal_sum": [5, 7, 8],
        },
    )


@pytest.mark.parametrize(
    "df_raw",
    [df_pandas, df_polars],
)
def test_sumh_literal(df_raw: Any) -> None:
    """Row-wise sum where one operand is the bare string ``"a"``.

    The expected values equal ``a + b``, so the string is expected to be
    resolved as a column name rather than a literal value.
    """
    row_total = nw.sum_horizontal("a", nw.col("b"))
    frame = nw.DataFrame(df_raw)
    summed_native = nw.to_native(frame.with_columns(horizonal_sum=row_total))
    compare_dicts(
        summed_native,
        {
            "a": [1, 3, 2],
            "b": [4, 4, 6],
            "z": [7.0, 8.0, 9.0],
            "horizonal_sum": [5, 7, 8],
        },
    )


@pytest.mark.parametrize(
    "df_raw",
    [df_pandas, df_polars],
)
def test_sum_all(df_raw: Any) -> None:
    """Reduce every column to its sum, giving a single-row result."""
    totals_native = nw.to_native(nw.DataFrame(df_raw).select(nw.all().sum()))
    compare_dicts(totals_native, {"a": [6], "b": [14], "z": [24.0]})


@pytest.mark.parametrize(
    "df_raw",
    [df_pandas, df_polars],
)
def test_double_selected(df_raw: Any) -> None:
    """Select only ``a`` and ``b`` and double them; ``z`` is left out."""
    doubled_pair = nw.col("a", "b") * 2
    selected_native = nw.to_native(nw.DataFrame(df_raw).select(doubled_pair))
    compare_dicts(selected_native, {"a": [2, 6, 4], "b": [8, 8, 12]})

0 comments on commit b44d749

Please sign in to comment.