Skip to content

Commit

Permalink
fix: zip_with also broadcast mask (#12309)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Nov 8, 2023
1 parent 1d65aac commit 81b02f3
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 38 deletions.
122 changes: 84 additions & 38 deletions crates/polars-core/src/chunked_array/ops/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,42 +27,72 @@ fn prepare_mask(mask: &BooleanArray) -> BooleanArray {
}

// Builds the result of a `zip_with` call for inputs whose lengths differ, by
// broadcasting the inputs that have length 1.
//
// Every successful arm evaluates to `Ok(ChunkedArray<$ty>)` renamed to
// `$self.name()`; length combinations matched by no arm evaluate to a
// `ShapeMismatch` error.
//
// `ternary_apply` is defined outside this listing; presumably it yields `left`
// when the mask value is true and `right` otherwise — TODO confirm.
//
// NOTE(review): two rule bodies appear back to back below and the first one is
// never closed. This looks like the before/after sides of a diff shown without
// +/- markers — confirm which rule is current before relying on it.
macro_rules! impl_ternary_broadcast {
// Rule that dispatches on the lengths of `self` and `other` only; the length
// of the mask is not inspected here.
($self:ident, $self_len:expr, $other_len:expr, $other:expr, $mask:expr, $ty:ty) => {{
match ($self_len, $other_len) {
// Both value inputs hold a single element: one output per mask entry.
(1, 1) => {
let left = $self.get(0);
let right = $other.get(0);
let mut val: ChunkedArray<$ty> = $mask
.into_no_null_iter()
.map(|mask_val| ternary_apply(mask_val, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
}
// `other` is a single element: pair each mask entry with the `self` entry
// at the same position.
(_, 1) => {
let right = $other.get(0);
let mut val: ChunkedArray<$ty> = $mask
.into_no_null_iter()
.zip($self)
.map(|(mask_val, left)| ternary_apply(mask_val, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
}
// `self` is a single element: pair each mask entry with the `other` entry
// at the same position.
(1, _) => {
let left = $self.get(0);
let mut val: ChunkedArray<$ty> = $mask
.into_no_null_iter()
.zip($other)
.map(|(mask_val, right)| ternary_apply(mask_val, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
}
// Neither value input has length 1: nothing to broadcast.
(_, _) => Err(polars_err!(
ShapeMismatch: "shapes of `mask` and `other` are not suitable for `zip_with` operation"
)),
// Rule that additionally receives `$mask_len`, so the mask itself can be the
// broadcast input. Arms are tried top to bottom, so `(1, 1, _)` takes
// precedence over the later arms for any mask length.
($self:ident, $self_len:expr, $other_len:expr, $mask_len: expr, $other:expr, $mask:expr, $ty:ty) => {{
match ($self_len, $other_len, $mask_len) {
// `self` and `other` are single elements: the output has one entry per
// mask entry.
(1, 1, _) => {
let left = $self.get(0);
let right = $other.get(0);
let mut val: ChunkedArray<$ty> = $mask
.into_no_null_iter()
.map(|mask_val| ternary_apply(mask_val, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
}
// `other` and the mask are single elements: the output follows `self`.
(_, 1, 1) => {
let right = $other.get(0);
// When `get(0)` yields `None`, the mask is taken as `false`.
let mask = $mask.get(0).unwrap_or(false);
let mut val: ChunkedArray<$ty> = $self
.into_iter()
.map(|left| ternary_apply(mask, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
}
// `self` and the mask are single elements: the output follows `other`.
(1, _, 1) => {
let left = $self.get(0);
// When `get(0)` yields `None`, the mask is taken as `false`.
let mask = $mask.get(0).unwrap_or(false);
let mut val: ChunkedArray<$ty> = $other
.into_iter()
.map(|right| ternary_apply(mask, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
},
// Only `self` is a single element; `other` and the mask must agree in length.
(1, r_len, mask_len) if r_len == mask_len =>{
let left = $self.get(0);
let mut val: ChunkedArray<$ty> = $mask
.into_no_null_iter()
.zip($other)
.map(|(mask, right)| ternary_apply(mask, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
},
// Only `other` is a single element; `self` and the mask must agree in length.
(l_len, 1, mask_len) if l_len == mask_len => {
let right = $other.get(0);
let mut val: ChunkedArray<$ty> = $mask
.into_no_null_iter()
.zip($self)
.map(|(mask, left)| ternary_apply(mask, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
},
// Only the mask is a single element; `self` and `other` must agree in length.
(l_len, r_len, 1) if l_len == r_len => {
// When `get(0)` yields `None`, the mask is taken as `false`.
let mask = $mask.get(0).unwrap_or(false);
let mut val: ChunkedArray<$ty> = $self
.into_iter()
.zip($other)
.map(|(left, right)| ternary_apply(mask, left, right))
.collect_trusted();
val.rename($self.name());
Ok(val)
},
// Any other length combination cannot be broadcast.
(_, _, _) => Err(polars_err!(
ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation"
)),
}
}};
}

Expand Down Expand Up @@ -97,7 +127,7 @@ where
) -> PolarsResult<ChunkedArray<T>> {
// broadcasting path
if self.len() != mask.len() || other.len() != mask.len() {
impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, T)
impl_ternary_broadcast!(self, self.len(), other.len(), mask.len(), other, mask, T)
} else {
zip_with(self, other, mask)
}
Expand All @@ -112,7 +142,15 @@ impl ChunkZip<BooleanType> for BooleanChunked {
) -> PolarsResult<BooleanChunked> {
// broadcasting path
if self.len() != mask.len() || other.len() != mask.len() {
impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, BooleanType)
impl_ternary_broadcast!(
self,
self.len(),
other.len(),
mask.len(),
other,
mask,
BooleanType
)
} else {
zip_with(self, other, mask)
}
Expand All @@ -136,7 +174,15 @@ impl ChunkZip<BinaryType> for BinaryChunked {
other: &BinaryChunked,
) -> PolarsResult<BinaryChunked> {
if self.len() != mask.len() || other.len() != mask.len() {
impl_ternary_broadcast!(self, self.len(), other.len(), other, mask, BinaryType)
impl_ternary_broadcast!(
self,
self.len(),
other.len(),
mask.len(),
other,
mask,
BinaryType
)
} else {
zip_with(self, other, mask)
}
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/functions/test_whenthen.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,19 @@ def test_when_then_deprecated_string_input() -> None:

expected = pl.Series("when", ["b", "c"])
assert_series_equal(result.to_series(), expected)


def test_predicate_broadcast() -> None:
    """Check `when(...).then(...)` inside `agg` when the predicate is a per-group aggregate.

    The predicate compares each group's minimum against 3, while the `then`
    branch is the group's full `val` column, so the two sides of the
    expression differ in length within a group.
    """
    frame = pl.DataFrame(
        {
            "key": ["a", "a", "b", "b", "c", "c"],
            "val": [1, 2, 3, 4, 5, 6],
        }
    )

    group_min_at_least_three = pl.col("val").min() >= 3
    kept_values = pl.when(group_min_at_least_three).then(pl.col("val"))
    result = frame.group_by("key", maintain_order=True).agg(agg=kept_values)

    # Group "a" has minimum 1, so both of its entries are expected to be null;
    # groups "b" and "c" keep their values unchanged.
    expected = {
        "key": ["a", "b", "c"],
        "agg": [[None, None], [3, 4], [5, 6]],
    }
    assert result.to_dict(as_series=False) == expected

0 comments on commit 81b02f3

Please sign in to comment.