diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 86d51fa72989..cd592219f684 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -27,42 +27,72 @@ fn prepare_mask(mask: &BooleanArray) -> BooleanArray { } macro_rules! impl_ternary_broadcast { - ($self:ident, $self_len:expr, $other_len:expr, $other:expr, $mask:expr, $ty:ty) => {{ - match ($self_len, $other_len) { - (1, 1) => { - let left = $self.get(0); - let right = $other.get(0); - let mut val: ChunkedArray<$ty> = $mask - .into_no_null_iter() - .map(|mask_val| ternary_apply(mask_val, left, right)) - .collect_trusted(); - val.rename($self.name()); - Ok(val) - } - (_, 1) => { - let right = $other.get(0); - let mut val: ChunkedArray<$ty> = $mask - .into_no_null_iter() - .zip($self) - .map(|(mask_val, left)| ternary_apply(mask_val, left, right)) - .collect_trusted(); - val.rename($self.name()); - Ok(val) - } - (1, _) => { - let left = $self.get(0); - let mut val: ChunkedArray<$ty> = $mask - .into_no_null_iter() - .zip($other) - .map(|(mask_val, right)| ternary_apply(mask_val, left, right)) - .collect_trusted(); - val.rename($self.name()); - Ok(val) - } - (_, _) => Err(polars_err!( - ShapeMismatch: "shapes of `mask` and `other` are not suitable for `zip_with` operation" - )), + ($self:ident, $self_len:expr, $other_len:expr, $mask_len: expr, $other:expr, $mask:expr, $ty:ty) => {{ + match ($self_len, $other_len, $mask_len) { + (1, 1, _) => { + let left = $self.get(0); + let right = $other.get(0); + let mut val: ChunkedArray<$ty> = $mask + .into_no_null_iter() + .map(|mask_val| ternary_apply(mask_val, left, right)) + .collect_trusted(); + val.rename($self.name()); + Ok(val) + } + (_, 1, 1) => { + let right = $other.get(0); + let mask = $mask.get(0).unwrap_or(false); + let mut val: ChunkedArray<$ty> = $self + .into_iter() + .map(|left| ternary_apply(mask, left, right)) + .collect_trusted(); + val.rename($self.name()); + Ok(val) } + (1, _, 1) => { + let left = $self.get(0); + let mask = $mask.get(0).unwrap_or(false); + let mut val: ChunkedArray<$ty> = $other + .into_iter() + .map(|right| ternary_apply(mask, left, right)) + .collect_trusted(); + val.rename($self.name()); + Ok(val) + }, + (1, r_len, mask_len) if r_len == mask_len =>{ + let left = $self.get(0); + let mut val: ChunkedArray<$ty> = $mask + .into_no_null_iter() + .zip($other) + .map(|(mask, right)| ternary_apply(mask, left, right)) + .collect_trusted(); + val.rename($self.name()); + Ok(val) + }, + (l_len, 1, mask_len) if l_len == mask_len => { + let right = $other.get(0); + let mut val: ChunkedArray<$ty> = $mask + .into_no_null_iter() + .zip($self) + .map(|(mask, left)| ternary_apply(mask, left, right)) + .collect_trusted(); + val.rename($self.name()); + Ok(val) + }, + (l_len, r_len, 1) if l_len == r_len => { + let mask = $mask.get(0).unwrap_or(false); + let mut val: ChunkedArray<$ty> = $self + .into_iter() + .zip($other) + .map(|(left, right)| ternary_apply(mask, left, right)) + .collect_trusted(); + val.rename($self.name()); + Ok(val) + }, + (_, _, _) => Err(polars_err!( + ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation" + )), + } }}; } @@ -97,7 +127,7 @@ where ) -> PolarsResult> { // broadcasting path if self.len() != mask.len() || other.len() != mask.len() { - impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, T) + impl_ternary_broadcast!(self, self.len(), other.len(), mask.len(), other, mask, T) } else { zip_with(self, other, mask) } @@ -112,7 +142,15 @@ impl ChunkZip for BooleanChunked { ) -> PolarsResult { // broadcasting path if self.len() != mask.len() || other.len() != mask.len() { - impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, BooleanType) + impl_ternary_broadcast!( + self, + self.len(), + other.len(), + mask.len(), + other, + mask, + BooleanType + ) } else { zip_with(self, other, mask) } @@ -136,7 +174,15 @@ impl ChunkZip for BinaryChunked { other: &BinaryChunked, ) -> PolarsResult { if self.len() != mask.len() || other.len() != mask.len() { - impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, BinaryType) + impl_ternary_broadcast!( + self, + self.len(), + other.len(), + mask.len(), + other, + mask, + BinaryType + ) } else { zip_with(self, other, mask) } diff --git a/py-polars/tests/unit/functions/test_whenthen.py b/py-polars/tests/unit/functions/test_whenthen.py index f7612d156fb9..b6fc02409c6a 100644 --- a/py-polars/tests/unit/functions/test_whenthen.py +++ b/py-polars/tests/unit/functions/test_whenthen.py @@ -256,3 +256,19 @@ def test_when_then_deprecated_string_input() -> None: expected = pl.Series("when", ["b", "c"]) assert_series_equal(result.to_series(), expected) + + +def test_predicate_broadcast() -> None: + df = pl.DataFrame( + { + "key": ["a", "a", "b", "b", "c", "c"], + "val": [1, 2, 3, 4, 5, 6], + } + ) + out = df.group_by("key", maintain_order=True).agg( + agg=pl.when(pl.col("val").min() >= 3).then(pl.col("val")), + ) + assert out.to_dict(as_series=False) == { + "key": ["a", "b", "c"], + "agg": [[None, None], [3, 4], [5, 6]], + }