Skip to content

Commit

Permalink
feat: arrow series (#550)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Jul 18, 2024
1 parent 187b7c9 commit ec0ad3d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 12 deletions.
11 changes: 11 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,17 @@ def tail(self, n: int) -> Self:
def is_in(self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "is_in", other)

def sample(
self: Self,
n: int | None = None,
fraction: float | None = None,
*,
with_replacement: bool = False,
) -> Self:
return reuse_series_implementation(
self, "sample", n=n, fraction=fraction, with_replacement=with_replacement
)

@property
def dt(self) -> ArrowExprDateTimeNamespace:
return ArrowExprDateTimeNamespace(self)
Expand Down
20 changes: 20 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from narwhals._arrow.utils import reverse_translate_dtype
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_column_comparand
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import get_pyarrow_compute

Expand Down Expand Up @@ -294,6 +295,25 @@ def zip_with(self: Self, mask: Self, other: Self) -> Self:
)
)

def sample(
self: Self,
n: int | None = None,
fraction: float | None = None,
*,
with_replacement: bool = False,
) -> Self:
np = get_numpy()
pc = get_pyarrow_compute()
ser = self._native_series
num_rows = len(self)

if n is None and fraction is not None:
n = int(num_rows * fraction)

idx = np.arange(0, num_rows)
mask = np.random.choice(idx, size=n, replace=with_replacement)
return self._from_native_series(pc.take(ser, mask))

@property
def shape(self) -> tuple[int]:
return (len(self._native_series),)
Expand Down
32 changes: 21 additions & 11 deletions tests/expr_and_series/sample_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw


def test_expr_sample(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_expr_sample(constructor: Any) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).lazy()
result_shape = nw.to_native(df.select(nw.col("a").sample(n=2)).collect()).shape
expected = (2, 1)
assert result_shape == expected
result_shape = nw.to_native(df.collect()["a"].sample(n=2)).shape
expected = (2,) # type: ignore[assignment]
assert result_shape == expected

result_expr = df.select(nw.col("a").sample(n=2)).collect().shape
expected_expr = (2, 1)
assert result_expr == expected_expr

result_series = df.collect()["a"].sample(n=2).shape
expected_series = (2,)
assert result_series == expected_series


def test_expr_sample_fraction(constructor: Any) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3] * 10, "b": [4, 5, 6] * 10})).lazy()

result_expr = df.select(nw.col("a").sample(fraction=0.1)).collect().shape
expected_expr = (3, 1)
assert result_expr == expected_expr

result_series = df.collect()["a"].sample(fraction=0.1).shape
expected_series = (3,)
assert result_series == expected_series
1 change: 0 additions & 1 deletion utils/check_backend_completeness.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"Series.n_unique",
"Series.quantile",
"Series.round",
"Series.sample",
"Series.shift",
"Series.sort",
"Series.to_frame",
Expand Down

0 comments on commit ec0ad3d

Please sign in to comment.