From 2b200f08181fe0a3ed108c950ee43a13e9f02492 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Sun, 27 Aug 2023 01:22:56 +0800 Subject: [PATCH] feat(rust, python): repeat_by should also support broadcasting of LHS (#10735) --- .../src/chunked_array/ops/repeat_by.rs | 136 +++++++++++------- py-polars/tests/unit/dataframe/test_df.py | 57 ++++++-- 2 files changed, 132 insertions(+), 61 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/repeat_by.rs b/crates/polars-core/src/chunked_array/ops/repeat_by.rs index 419065689dff..69e83a042879 100644 --- a/crates/polars-core/src/chunked_array/ops/repeat_by.rs +++ b/crates/polars-core/src/chunked_array/ops/repeat_by.rs @@ -8,8 +8,8 @@ type LargeListArray = ListArray; fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> { polars_ensure!( - (length_srs == length_by) | (length_by == 1), - ComputeError: "Length of repeat_by argument needs to be 1 or equal to the length of the Series. Series length {}, by length {}", + (length_srs == length_by) | (length_by == 1) | (length_srs == 1), + ComputeError: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}", length_srs, length_by ); Ok(()) @@ -22,48 +22,66 @@ where fn repeat_by(&self, by: &IdxCa) -> PolarsResult { check_lengths(self.len(), by.len())?; - if (self.len() != by.len()) & (by.len() == 1) { - return self.repeat_by(&IdxCa::new( + match (self.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(self, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat(opt_v.copied()).take(*by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { + LargeListArray::from_iter_primitive_trusted_len( + iter, + T::get_dtype().to_arrow(), + ) + } + })) + }, + (_, 1) => self.repeat_by(&IdxCa::new( self.name(), std::iter::repeat(by.get(0).unwrap()) .take(self.len()) .collect::>(), - )); + )), + (1, _) => { + let new_array = self.new_from_index(0, by.len()); + new_array.repeat_by(by) + }, + // we have already checked the length + _ => unreachable!(), } - - Ok(arity::binary(self, by, |arr, by| { - let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { - opt_by.map(|by| std::iter::repeat(opt_v.copied()).take(*by as usize)) - }); - - // SAFETY: length of iter is trusted. - unsafe { - LargeListArray::from_iter_primitive_trusted_len(iter, T::get_dtype().to_arrow()) - } - })) } } + impl RepeatBy for BooleanChunked { fn repeat_by(&self, by: &IdxCa) -> PolarsResult { check_lengths(self.len(), by.len())?; - if (self.len() != by.len()) & (by.len() == 1) { - return self.repeat_by(&IdxCa::new( + match (self.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(self, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { LargeListArray::from_iter_bool_trusted_len(iter) } + })) + }, + (_, 1) => self.repeat_by(&IdxCa::new( self.name(), std::iter::repeat(by.get(0).unwrap()) .take(self.len()) .collect::>(), - )); + )), + (1, _) => { + let new_array = self.new_from_index(0, by.len()); + new_array.repeat_by(by) + }, + // we have already checked the length + _ => unreachable!(), } - - Ok(arity::binary(self, by, |arr, by| { - let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { - opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) - }); - - // SAFETY: length of iter is trusted. - unsafe { LargeListArray::from_iter_bool_trusted_len(iter) } - })) } } impl RepeatBy for Utf8Chunked { @@ -71,23 +89,30 @@ impl RepeatBy for Utf8Chunked { // TODO! dispatch via binary. check_lengths(self.len(), by.len())?; - if (self.len() != by.len()) & (by.len() == 1) { - return self.repeat_by(&IdxCa::new( + match (self.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(self, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { LargeListArray::from_iter_utf8_trusted_len(iter, self.len()) } + })) + }, + (_, 1) => self.repeat_by(&IdxCa::new( self.name(), std::iter::repeat(by.get(0).unwrap()) .take(self.len()) .collect::>(), - )); + )), + (1, _) => { + let new_array = self.new_from_index(0, by.len()); + new_array.repeat_by(by) + }, + // we have already checked the length + _ => unreachable!(), } - - Ok(arity::binary(self, by, |arr, by| { - let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { - opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) - }); - - // SAFETY: length of iter is trusted. - unsafe { LargeListArray::from_iter_utf8_trusted_len(iter, self.len()) } - })) } } @@ -95,22 +120,29 @@ impl RepeatBy for BinaryChunked { fn repeat_by(&self, by: &IdxCa) -> PolarsResult { check_lengths(self.len(), by.len())?; - if (self.len() != by.len()) & (by.len() == 1) { - return self.repeat_by(&IdxCa::new( + match (self.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(self, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { LargeListArray::from_iter_binary_trusted_len(iter, self.len()) } + })) + }, + (_, 1) => self.repeat_by(&IdxCa::new( self.name(), std::iter::repeat(by.get(0).unwrap()) .take(self.len()) .collect::>(), - )); + )), + (1, _) => { + let new_array = self.new_from_index(0, by.len()); + new_array.repeat_by(by) + }, + // we have already checked the length + _ => unreachable!(), } - - Ok(arity::binary(self, by, |arr, by| { - let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { - opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize)) - }); - - // SAFETY: length of iter is trusted. - unsafe { LargeListArray::from_iter_binary_trusted_len(iter, self.len()) } - })) } } diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 0841c35dc6ba..269a34fe2578 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -1692,20 +1692,47 @@ def test_repeat_by_unequal_lengths_panic() -> None: ) with pytest.raises( pl.ComputeError, - match="""Length of repeat_by argument needs to be 1 or equal to the length of the Series.""", + match="repeat_by argument and the Series should have equal length, " + "or at least one of them should have length 1", ): df.select(pl.col("a").repeat_by(pl.Series([2, 2]))) +@pytest.mark.parametrize( + ("value", "values_expect"), + [ + (1.2, [[1.2], [1.2, 1.2], [1.2, 1.2, 1.2]]), + (True, [[True], [True, True], [True, True, True]]), + ("x", [["x"], ["x", "x"], ["x", "x", "x"]]), + (b"a", [[b"a"], [b"a", b"a"], [b"a", b"a", b"a"]]), + ], +) +def test_repeat_by_broadcast_left( + value: float | bool | str, values_expect: list[list[float | bool | str]] +) -> None: + df = pl.DataFrame( + { + "n": [1, 2, 3], + } + ) + expected = pl.DataFrame({"values": values_expect}) + result = df.select(pl.lit(value).repeat_by(pl.col("n")).alias("values")) + assert_frame_equal(result, expected) + + @pytest.mark.parametrize( ("a", "a_expected"), [ ([1.2, 2.2, 3.3], [[1.2, 1.2, 1.2], [2.2, 2.2, 2.2], [3.3, 3.3, 3.3]]), ([True, False], [[True, True, True], [False, False, False]]), (["x", "y", "z"], [["x", "x", "x"], ["y", "y", "y"], ["z", "z", "z"]]), + ( + [b"a", b"b", b"c"], + [[b"a", b"a", b"a"], [b"b", b"b", b"b"], [b"c", b"c", b"c"]], + ), ], ) -def test_repeat_by_parameterized( +def test_repeat_by_broadcast_right( a: list[float | bool | str], a_expected: list[list[float | bool | str]] ) -> None: df = pl.DataFrame( @@ -1720,13 +1747,25 @@ def test_repeat_by_parameterized( assert_frame_equal(result, expected) -def test_repeat_by() -> None: - df = pl.DataFrame({"name": ["foo", "bar"], "n": [2, 3]}) - out = df.select(pl.col("n").repeat_by("n")) - s = out["n"] - - assert s[0].to_list() == [2, 2] - assert s[1].to_list() == [3, 3, 3] +@pytest.mark.parametrize( + ("a", "a_expected"), + [ + (["foo", "bar"], [["foo", "foo"], ["bar", "bar", "bar"]]), + ([1, 2], [[1, 1], [2, 2, 2]]), + ([True, False], [[True, True], [False, False, False]]), + ( + [b"a", b"b"], + [[b"a", b"a"], [b"b", b"b", b"b"]], + ), + ], +) +def test_repeat_by( + a: list[float | bool | str], a_expected: list[list[float | bool | str]] +) -> None: + df = pl.DataFrame({"a": a, "n": [2, 3]}) + expected = pl.DataFrame({"a": a_expected}) + result = df.select(pl.col("a").repeat_by("n")) + assert_frame_equal(result, expected) def test_join_dates() -> None: