From ec0ad3d95bcf836f8e1bac6ac6ef8e929d2f6087 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Thu, 18 Jul 2024 22:56:42 +0200 Subject: [PATCH] feat: arrow series (#550) --- narwhals/_arrow/expr.py | 11 ++++++++++ narwhals/_arrow/series.py | 20 +++++++++++++++++ tests/expr_and_series/sample_test.py | 32 ++++++++++++++++++---------- utils/check_backend_completeness.py | 1 - 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 576645b69..f8bfbffce 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -201,6 +201,17 @@ def tail(self, n: int) -> Self: def is_in(self, other: ArrowExpr | Any) -> Self: return reuse_series_implementation(self, "is_in", other) + def sample( + self: Self, + n: int | None = None, + fraction: float | None = None, + *, + with_replacement: bool = False, + ) -> Self: + return reuse_series_implementation( + self, "sample", n=n, fraction=fraction, with_replacement=with_replacement + ) + @property def dt(self) -> ArrowExprDateTimeNamespace: return ArrowExprDateTimeNamespace(self) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index e02b89d3a..c287dd81f 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -8,6 +8,7 @@ from narwhals._arrow.utils import reverse_translate_dtype from narwhals._arrow.utils import translate_dtype from narwhals._arrow.utils import validate_column_comparand +from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pyarrow from narwhals.dependencies import get_pyarrow_compute @@ -294,6 +295,25 @@ def zip_with(self: Self, mask: Self, other: Self) -> Self: ) ) + def sample( + self: Self, + n: int | None = None, + fraction: float | None = None, + *, + with_replacement: bool = False, + ) -> Self: + np = get_numpy() + pc = get_pyarrow_compute() + ser = self._native_series + num_rows = len(self) + + if n is None and fraction is not None: + n = int(num_rows * fraction) + + idx = np.arange(0, num_rows) + mask = np.random.choice(idx, size=n, replace=with_replacement) + return self._from_native_series(pc.take(ser, mask)) + @property def shape(self) -> tuple[int]: return (len(self._native_series),) diff --git a/tests/expr_and_series/sample_test.py b/tests/expr_and_series/sample_test.py index 4fb54837f..a19c686e6 100644 --- a/tests/expr_and_series/sample_test.py +++ b/tests/expr_and_series/sample_test.py @@ -1,17 +1,27 @@ from typing import Any -import pytest - import narwhals.stable.v1 as nw -def test_expr_sample(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_expr_sample(constructor: Any) -> None: df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).lazy() - result_shape = nw.to_native(df.select(nw.col("a").sample(n=2)).collect()).shape - expected = (2, 1) - assert result_shape == expected - result_shape = nw.to_native(df.collect()["a"].sample(n=2)).shape - expected = (2,) # type: ignore[assignment] - assert result_shape == expected + + result_expr = df.select(nw.col("a").sample(n=2)).collect().shape + expected_expr = (2, 1) + assert result_expr == expected_expr + + result_series = df.collect()["a"].sample(n=2).shape + expected_series = (2,) + assert result_series == expected_series + + +def test_expr_sample_fraction(constructor: Any) -> None: + df = nw.from_native(constructor({"a": [1, 2, 3] * 10, "b": [4, 5, 6] * 10})).lazy() + + result_expr = df.select(nw.col("a").sample(fraction=0.1)).collect().shape + expected_expr = (3, 1) + assert result_expr == expected_expr + + result_series = df.collect()["a"].sample(fraction=0.1).shape + expected_series = (3,) + assert result_series == expected_series diff --git a/utils/check_backend_completeness.py b/utils/check_backend_completeness.py index 827eae5c6..90506dc4b 100644 --- a/utils/check_backend_completeness.py +++ b/utils/check_backend_completeness.py @@ -29,7 +29,6 @@ "Series.n_unique", "Series.quantile", "Series.round", - "Series.sample", "Series.shift", "Series.sort", "Series.to_frame",