Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust, python): repeat_by should also support broadcasting of LHS #10735

Merged
merged 2 commits into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 84 additions & 52 deletions crates/polars-core/src/chunked_array/ops/repeat_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ type LargeListArray = ListArray<i64>;

fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> {
polars_ensure!(
(length_srs == length_by) | (length_by == 1),
ComputeError: "Length of repeat_by argument needs to be 1 or equal to the length of the Series. Series length {}, by length {}",
(length_srs == length_by) | (length_by == 1) | (length_srs == 1),
ComputeError: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}",
length_srs, length_by
);
Ok(())
Expand All @@ -22,95 +22,127 @@ where
fn repeat_by(&self, by: &IdxCa) -> PolarsResult<ListChunked> {
check_lengths(self.len(), by.len())?;

if (self.len() != by.len()) & (by.len() == 1) {
return self.repeat_by(&IdxCa::new(
match (self.len(), by.len()) {
(left_len, right_len) if left_len == right_len => {
Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v.copied()).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe {
LargeListArray::from_iter_primitive_trusted_len(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we can share most of the code with these types, but before that, maybe we should unify this call first. 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, might be a good follow up!

iter,
T::get_dtype().to_arrow(),
)
}
}))
},
(_, 1) => self.repeat_by(&IdxCa::new(
self.name(),
std::iter::repeat(by.get(0).unwrap())
.take(self.len())
.collect::<Vec<IdxSize>>(),
));
)),
(1, _) => {
let new_array = self.new_from_index(0, by.len());
new_array.repeat_by(by)
},
// we have already checked the length
_ => unreachable!(),
}

Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v.copied()).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe {
LargeListArray::from_iter_primitive_trusted_len(iter, T::get_dtype().to_arrow())
}
}))
}
}

impl RepeatBy for BooleanChunked {
fn repeat_by(&self, by: &IdxCa) -> PolarsResult<ListChunked> {
check_lengths(self.len(), by.len())?;

if (self.len() != by.len()) & (by.len() == 1) {
return self.repeat_by(&IdxCa::new(
match (self.len(), by.len()) {
(left_len, right_len) if left_len == right_len => {
Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_bool_trusted_len(iter) }
}))
},
(_, 1) => self.repeat_by(&IdxCa::new(
self.name(),
std::iter::repeat(by.get(0).unwrap())
.take(self.len())
.collect::<Vec<IdxSize>>(),
));
)),
(1, _) => {
let new_array = self.new_from_index(0, by.len());
new_array.repeat_by(by)
},
// we have already checked the length
_ => unreachable!(),
}

Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_bool_trusted_len(iter) }
}))
}
}
impl RepeatBy for Utf8Chunked {
fn repeat_by(&self, by: &IdxCa) -> PolarsResult<ListChunked> {
// TODO! dispatch via binary.
check_lengths(self.len(), by.len())?;

if (self.len() != by.len()) & (by.len() == 1) {
return self.repeat_by(&IdxCa::new(
match (self.len(), by.len()) {
(left_len, right_len) if left_len == right_len => {
Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_utf8_trusted_len(iter, self.len()) }
}))
},
(_, 1) => self.repeat_by(&IdxCa::new(
self.name(),
std::iter::repeat(by.get(0).unwrap())
.take(self.len())
.collect::<Vec<IdxSize>>(),
));
)),
(1, _) => {
let new_array = self.new_from_index(0, by.len());
new_array.repeat_by(by)
},
// we have already checked the length
_ => unreachable!(),
}

Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_utf8_trusted_len(iter, self.len()) }
}))
}
}

impl RepeatBy for BinaryChunked {
fn repeat_by(&self, by: &IdxCa) -> PolarsResult<ListChunked> {
check_lengths(self.len(), by.len())?;

if (self.len() != by.len()) & (by.len() == 1) {
return self.repeat_by(&IdxCa::new(
match (self.len(), by.len()) {
(left_len, right_len) if left_len == right_len => {
Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_binary_trusted_len(iter, self.len()) }
}))
},
(_, 1) => self.repeat_by(&IdxCa::new(
self.name(),
std::iter::repeat(by.get(0).unwrap())
.take(self.len())
.collect::<Vec<IdxSize>>(),
));
)),
(1, _) => {
let new_array = self.new_from_index(0, by.len());
new_array.repeat_by(by)
},
// we have already checked the length
_ => unreachable!(),
}

Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_binary_trusted_len(iter, self.len()) }
}))
}
}
57 changes: 48 additions & 9 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,20 +1692,47 @@ def test_repeat_by_unequal_lengths_panic() -> None:
)
with pytest.raises(
pl.ComputeError,
match="""Length of repeat_by argument needs to be 1 or equal to the length of the Series.""",
match="repeat_by argument and the Series should have equal length, "
"or at least one of them should have length 1",
):
df.select(pl.col("a").repeat_by(pl.Series([2, 2])))


@pytest.mark.parametrize(
("value", "values_expect"),
[
(1.2, [[1.2], [1.2, 1.2], [1.2, 1.2, 1.2]]),
(True, [[True], [True, True], [True, True, True]]),
("x", [["x"], ["x", "x"], ["x", "x", "x"]]),
(b"a", [[b"a"], [b"a", b"a"], [b"a", b"a", b"a"]]),
],
)
def test_repeat_by_broadcast_left(
value: float | bool | str, values_expect: list[list[float | bool | str]]
) -> None:
df = pl.DataFrame(
{
"n": [1, 2, 3],
}
)
expected = pl.DataFrame({"values": values_expect})
result = df.select(pl.lit(value).repeat_by(pl.col("n")).alias("values"))
assert_frame_equal(result, expected)


@pytest.mark.parametrize(
("a", "a_expected"),
[
([1.2, 2.2, 3.3], [[1.2, 1.2, 1.2], [2.2, 2.2, 2.2], [3.3, 3.3, 3.3]]),
([True, False], [[True, True, True], [False, False, False]]),
(["x", "y", "z"], [["x", "x", "x"], ["y", "y", "y"], ["z", "z", "z"]]),
(
[b"a", b"b", b"c"],
[[b"a", b"a", b"a"], [b"b", b"b", b"b"], [b"c", b"c", b"c"]],
),
],
)
def test_repeat_by_parameterized(
def test_repeat_by_broadcast_right(
a: list[float | bool | str], a_expected: list[list[float | bool | str]]
) -> None:
df = pl.DataFrame(
Expand All @@ -1720,13 +1747,25 @@ def test_repeat_by_parameterized(
assert_frame_equal(result, expected)


def test_repeat_by() -> None:
df = pl.DataFrame({"name": ["foo", "bar"], "n": [2, 3]})
out = df.select(pl.col("n").repeat_by("n"))
s = out["n"]

assert s[0].to_list() == [2, 2]
assert s[1].to_list() == [3, 3, 3]
@pytest.mark.parametrize(
("a", "a_expected"),
[
(["foo", "bar"], [["foo", "foo"], ["bar", "bar", "bar"]]),
([1, 2], [[1, 1], [2, 2, 2]]),
([True, False], [[True, True], [False, False, False]]),
(
[b"a", b"b"],
[[b"a", b"a"], [b"b", b"b", b"b"]],
),
],
)
def test_repeat_by(
a: list[float | bool | str], a_expected: list[list[float | bool | str]]
) -> None:
df = pl.DataFrame({"a": a, "n": [2, 3]})
expected = pl.DataFrame({"a": a_expected})
result = df.select(pl.col("a").repeat_by("n"))
assert_frame_equal(result, expected)


def test_join_dates() -> None:
Expand Down