From 2bc3a6f2f7a5655b01d73b3a3088825475cca027 Mon Sep 17 00:00:00 2001 From: Mfon Ekpo <58835748+mfonekpo@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:19:25 +0100 Subject: [PATCH] feat: pyarrow `Series.sum` (#495) * Feat: Series.sum test case * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test_series_sum refactoring * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing according to standard * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactoring code to satisfy code criteria * Fixing Series.sum functionality for Pyarrow DF * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixing compare mistake on Series.sum for Pyarrow * cleaning up code * removing 'pyarrow_table' from test_sum_all and test_renamed_taxicab * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * making sure CI passes all test cases * all test cases passed on CI --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- narwhals/_arrow/expr.py | 3 +++ narwhals/_arrow/series.py | 4 ++++ tests/expr/sum_all_test.py | 7 +------ tests/frame/series_sum_test.py | 23 +++++++++++++++++++++++ tests/stable_api_test.py | 5 +---- utils/check_backend_completeness.py | 1 - 6 files changed, 32 insertions(+), 11 deletions(-) create mode 100644 tests/frame/series_sum_test.py diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index eea0c6cc6..9635b8353 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -165,6 +165,9 @@ def max(self) -> Self: def all(self) -> Self: return reuse_series_implementation(self, "all", returns_scalar=True) + def sum(self) -> Self: + return reuse_series_implementation(self, "sum", returns_scalar=True) + def alias(self, name: str) -> Self: # Define this one manually, so that we can # override `output_names` and not increase depth diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index af9fc7c00..5d1ae9a96 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -155,6 +155,10 @@ def max(self) -> int: pc = get_pyarrow_compute() return pc.max(self._native_series) # type: ignore[no-any-return] + def sum(self) -> int: + pc = get_pyarrow_compute() + return pc.sum(self._native_series) # type: ignore[no-any-return] + def std(self, ddof: int = 1) -> int: pc = get_pyarrow_compute() return pc.stddev(self._native_series, ddof=ddof) # type: ignore[no-any-return] diff --git a/tests/expr/sum_all_test.py b/tests/expr/sum_all_test.py index cf2103f88..2fbfb6251 100644 --- a/tests/expr/sum_all_test.py +++ b/tests/expr/sum_all_test.py @@ -1,15 +1,10 @@ from typing import Any -import pytest - import narwhals.stable.v1 as nw from tests.utils import compare_dicts -def test_sum_all(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_sum_all(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data), eager_only=True) result = df.select(nw.all().sum()) diff --git a/tests/frame/series_sum_test.py b/tests/frame/series_sum_test.py new file mode 100644 index 000000000..3a1c1d95b --- /dev/null +++ b/tests/frame/series_sum_test.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Any + +import narwhals.stable.v1 as nw +from tests.utils import compare_dicts + + +def test_series_sum(constructor: Any) -> None: + data = { + "a": [0, 1, 2, 3, 4], + "b": [1, 2, 3, 5, 3], + "c": [5, 4, None, 2, 1], + } + df = nw.from_native( + constructor(data), strict=False, eager_only=True, allow_series=True + ) + + result = df.select(nw.col("a", "b", "c").sum()) + + expected_sum = {"a": [10], "b": [14], "c": [12]} + + compare_dicts(result, expected_sum) diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index 205eebeed..211ba8652 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -8,7 +8,7 @@ from tests.utils import compare_dicts -def test_renamed_taxicab_norm(request: Any, constructor: Any) -> None: +def test_renamed_taxicab_norm(constructor: Any) -> None: # Suppose we need to rename `_l1_norm` to `_taxicab_norm`. # We need `narwhals.stable.v1` to stay stable. So, we # make the change in `narwhals`, and then add the new method @@ -17,9 +17,6 @@ def test_renamed_taxicab_norm(request: Any, constructor: Any) -> None: # API will still be able to use it, without the main namespace # getting cluttered by the new name. - if "pyarrow_table" in str(constructor): - request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor({"a": [1, 2, 3, -4, 5]})) result = df.with_columns(b=nw.col("a")._taxicab_norm()) expected = {"a": [1, 2, 3, -4, 5], "b": [15] * 5} diff --git a/utils/check_backend_completeness.py b/utils/check_backend_completeness.py index a9b6ae73c..ab2eec287 100644 --- a/utils/check_backend_completeness.py +++ b/utils/check_backend_completeness.py @@ -35,7 +35,6 @@ "Series.sample", "Series.shift", "Series.sort", - "Series.sum", "Series.to_frame", "Series.to_pandas", "Series.unique",