From f141dd27faa54e7974494aaa8e962d31ed237b26 Mon Sep 17 00:00:00 2001 From: raisa-toptal <> Date: Sat, 16 Mar 2024 16:29:59 +0000 Subject: [PATCH 1/3] add tests for common dataframe operations --- tests/test_common.py | 120 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 tests/test_common.py diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 000000000..fda506420 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import Any + +import pandas as pd +import polars as pl +import pytest + +import narwhals as nw +from tests.utils import compare_dicts + +df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) +df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_polars], +) +def test_sort(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.sort("a", "b") + result_native = nw.to_native(result) + expected = { + "a": [1, 2, 3], + "b": [4, 6, 4], + "z": [7.0, 9.0, 8.0], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_polars], +) +def test_filter(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.filter(nw.col("a") > 1) + result_native = nw.to_native(result) + expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_polars], +) +def test_add(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.with_columns( + c=nw.col("a") + nw.col("b"), + d=nw.col("a") - nw.col("a").mean(), + ) + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8.0, 9.0], + "c": [5, 7, 8], + "d": [-1.0, 1.0, 0.0], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize( + "df_raw", + [df_pandas, df_polars], +) +def test_double(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.with_columns(nw.all() * 2) + result_native = nw.to_native(result) + expected = {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def test_sumh(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.with_columns(horizonal_sum=nw.sum_horizontal(nw.col("a"), nw.col("b"))) + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8.0, 9.0], + "horizonal_sum": [5, 7, 8], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def test_sumh_literal(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.with_columns(horizonal_sum=nw.sum_horizontal("a", nw.col("b"))) + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "z": [7.0, 8.0, 9.0], + "horizonal_sum": [5, 7, 8], + } + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def test_sum_all(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.select(nw.all().sum()) + result_native = nw.to_native(result) + expected = {"a": [6], "b": [14], "z": [24.0]} + compare_dicts(result_native, expected) + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_polars]) +def test_double_selected(df_raw: Any) -> None: + df = nw.DataFrame(df_raw) + result = df.select(nw.col("a", "b") * 2) + result_native = nw.to_native(result) + expected = {"a": [2, 6, 4], "b": [8, 8, 12]} + compare_dicts(result_native, expected) From 1268b98e6d3970c3b6a34667acacb62b660920f0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 16 Mar 2024 16:35:43 +0000 Subject: [PATCH 2/3] fixup sum_horizontal --- narwhals/expression.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/narwhals/expression.py b/narwhals/expression.py index 83a059344..11a9cb0a5 100644 --- a/narwhals/expression.py +++ b/narwhals/expression.py @@ -6,6 +6,7 @@ from typing import Iterable from narwhals.dtypes import translate_dtype +from narwhals.utils import flatten if TYPE_CHECKING: from narwhals.typing import IntoExpr @@ -184,7 +185,9 @@ def max(*columns: str) -> Expr: def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: - return Expr(lambda plx: plx.sum_horizontal(*exprs)) + return Expr( + lambda plx: plx.sum_horizontal([extract_native(plx, v) for v in flatten(exprs)]) + ) __all__ = [ From 23ec222a85dbb9d10b9c090ec215854c4d3d89c6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 16 Mar 2024 16:43:12 +0000 Subject: [PATCH 3/3] fixup pandas as well --- narwhals/pandas_like/dataframe.py | 4 ++-- narwhals/pandas_like/namespace.py | 2 +- narwhals/pandas_like/utils.py | 11 +++++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/narwhals/pandas_like/dataframe.py b/narwhals/pandas_like/dataframe.py index 036261ca0..689d00b41 100644 --- a/narwhals/pandas_like/dataframe.py +++ b/narwhals/pandas_like/dataframe.py @@ -99,9 +99,9 @@ def filter( self, *predicates: IntoPandasExpr | Iterable[IntoPandasExpr], ) -> Self: - from narwhals.pandas_like.namespace import Namespace + from narwhals.pandas_like.namespace import PandasNamespace - plx = Namespace(self._implementation) + plx = PandasNamespace(self._implementation) expr = plx.all_horizontal(*predicates) # Safety: all_horizontal's expression only returns a single column. mask = expr._call(self)[0] diff --git a/narwhals/pandas_like/namespace.py b/narwhals/pandas_like/namespace.py index 5fb009ae2..bc418f7bd 100644 --- a/narwhals/pandas_like/namespace.py +++ b/narwhals/pandas_like/namespace.py @@ -19,7 +19,7 @@ from narwhals.pandas_like.typing import IntoPandasExpr -class Namespace: +class PandasNamespace: Int64 = dtypes.Int64 Int32 = dtypes.Int32 Int16 = dtypes.Int16 diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py index bf7ffd1bb..17f59971c 100644 --- a/narwhals/pandas_like/utils.py +++ b/narwhals/pandas_like/utils.py @@ -97,10 +97,13 @@ def parse_into_exprs( def parse_into_expr(implementation: str, into_expr: IntoPandasExpr) -> PandasExpr: from narwhals.expression import Expr - from narwhals.pandas_like.namespace import Namespace + from narwhals.pandas_like.expr import PandasExpr + from narwhals.pandas_like.namespace import PandasNamespace - plx = Namespace(implementation=implementation) + plx = PandasNamespace(implementation=implementation) + if isinstance(into_expr, PandasExpr): + return into_expr if isinstance(into_expr, Expr): return into_expr._call(plx) if isinstance(into_expr, str): @@ -141,10 +144,10 @@ def evaluate_into_exprs( def register_expression_call(expr: ExprT, attr: str, *args: Any, **kwargs: Any) -> ExprT: from narwhals.pandas_like.expr import PandasExpr - from narwhals.pandas_like.namespace import Namespace + from narwhals.pandas_like.namespace import PandasNamespace from narwhals.pandas_like.series import PandasSeries - plx = Namespace(implementation=expr._implementation) + plx = PandasNamespace(implementation=expr._implementation) def func(df: PandasDataFrame) -> list[PandasSeries]: out: list[PandasSeries] = []