Skip to content

Commit

Permalink
fix(rust): Add broadcasting for list comparisons (#10857)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-peters authored Sep 3, 2023
1 parent f6fe1b9 commit b2e6435
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 41 deletions.
118 changes: 77 additions & 41 deletions crates/polars-core/src/chunked_array/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -680,59 +680,84 @@ impl ChunkCompare<&str> for Utf8Chunked {
}
}

impl ChunkCompare<&ListChunked> for ListChunked {
type Item = BooleanChunked;
fn equal(&self, rhs: &ListChunked) -> BooleanChunked {
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
self.amortized_iter()
/// Compares two list columns row by row, broadcasting a side that has exactly one row.
///
/// `op` is called once per output row with the optional sub-series taken from each side
/// and returns the optional boolean stored for that row.
///
/// Dispatch on `(lhs.len(), rhs.len())`:
/// - `rhs` has one row: that row is fetched once and paired with every row of `lhs`.
/// - otherwise `lhs` has one row: that row is fetched once and paired with every row of `rhs`.
/// - otherwise: rows of `lhs` and `rhs` are zipped and paired positionally.
// NOTE(review): this listing comes from a rendered diff with the +/- markers stripped;
// some pre-change lines appear interleaved with the new body (flagged below) — confirm
// against the actual file before relying on the exact text.
#[doc(hidden)]
fn _list_comparison_helper<F>(lhs: &ListChunked, rhs: &ListChunked, op: F) -> BooleanChunked
where
F: Fn(Option<&Series>, Option<&Series>) -> Option<bool>,
{
match (lhs.len(), rhs.len()) {
// Right-hand side has a single row (this arm also handles the 1-vs-1 case, since it is
// matched first).
(_, 1) => {
// Fetch the single rhs row once, outside the per-row closure; it is renamed to "".
let right = rhs.get(0).map(|s| s.with_name(""));
// SAFETY: values within iterator do not outlive the iterator itself
unsafe {
lhs.amortized_iter()
.map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref()))
.collect_trusted()
}
},
// Left-hand side has a single row: iterate over rhs instead, keeping the argument order
// of `op` as (left, right).
(1, _) => {
let left = lhs.get(0).map(|s| s.with_name(""));
// SAFETY: values within iterator do not outlive the iterator itself
unsafe {
rhs.amortized_iter()
.map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref())))
.collect_trusted()
}
},
// No broadcasting: pair the rows of both sides positionally.
// SAFETY: values within iterator do not outlive the iterator itself
_ => unsafe {
lhs.amortized_iter()
.zip(rhs.amortized_iter())
// NOTE(review): the next three lines look like the removed pre-change closure shown by
// the diff view (the `.map(` is never closed); the `.map` that calls `op` below appears
// to be its replacement — confirm.
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => Some(l.as_ref().series_equal_missing(r.as_ref())),
_ => None,
.map(|(left, right)| {
op(
left.as_ref().map(|us| us.as_ref()),
right.as_ref().map(|us| us.as_ref()),
)
})
.collect_trusted()
}
},
}
}

impl ChunkCompare<&ListChunked> for ListChunked {
type Item = BooleanChunked;
/// Row-wise equality of two list columns.
///
/// A row yields `Some(l.series_equal(r))` when both sides hold a sub-series and `None`
/// when either side is missing. Iteration and broadcasting are delegated to
/// `_list_comparison_helper`.
fn equal(&self, rhs: &ListChunked) -> BooleanChunked {
    _list_comparison_helper(self, rhs, |left: Option<&Series>, right: Option<&Series>| {
        // `zip` is `Some` only when both sides are present, which mirrors the
        // `(Some, Some)` / fallback split of an explicit match.
        left.zip(right).map(|(l, r)| l.series_equal(r))
    })
}

/// Row-wise equality in which a missing row is treated as a comparable value.
///
/// Per row: both present -> result of `series_equal_missing`; both missing -> `true`;
/// exactly one missing -> `false`. Every row therefore gets a definite boolean.
fn equal_missing(&self, rhs: &ListChunked) -> BooleanChunked {
// NOTE(review): this listing comes from a rendered diff with the +/- markers stripped.
// The `unsafe { ... }` block directly below looks like the removed pre-change body
// (plain zip, no broadcasting); the closure after it plus the helper call appear to be
// the replacement — confirm against the actual file.
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => l.as_ref().series_equal_missing(r.as_ref()),
(None, None) => true,
_ => false,
})
.collect_trusted()
}
// Same per-row rule expressed as a closure; each result is wrapped in `Some` because the
// helper's `op` returns `Option<bool>`.
let _series_equal_missing = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) {
(Some(l), Some(r)) => Some(l.series_equal_missing(r)),
(None, None) => Some(true),
_ => Some(false),
};

// Iteration and single-row broadcasting are handled by the helper.
_list_comparison_helper(self, rhs, _series_equal_missing)
}

/// Row-wise inequality of two list columns.
///
/// Per row: both present -> negated sub-series equality; either side missing -> `None`.
fn not_equal(&self, rhs: &ListChunked) -> BooleanChunked {
// NOTE(review): this listing comes from a rendered diff with the +/- markers stripped.
// The `unsafe { ... }` block directly below looks like the removed pre-change body; the
// closure after it plus the helper call appear to be the replacement — confirm against
// the actual file.
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => Some(!l.as_ref().series_equal_missing(r.as_ref())),
_ => None,
})
.collect_trusted()
}
// Note the block above negates `series_equal_missing`, while this closure negates
// `series_equal`.
let _series_not_equal = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) {
(Some(l), Some(r)) => Some(!l.series_equal(r)),
_ => None,
};

// Iteration and single-row broadcasting are handled by the helper.
_list_comparison_helper(self, rhs, _series_not_equal)
}

/// Row-wise inequality in which a missing row is treated as a comparable value.
///
/// Per row: both present -> negated `series_equal_missing`; both missing -> `false`;
/// exactly one missing -> `true`. Every row therefore gets a definite boolean.
fn not_equal_missing(&self, rhs: &ListChunked) -> BooleanChunked {
// NOTE(review): this listing comes from a rendered diff with the +/- markers stripped.
// The `unsafe { ... }` block directly below looks like the removed pre-change body
// (plain zip, no broadcasting); the closure after it plus the helper call appear to be
// the replacement — confirm against the actual file.
unsafe {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => !l.as_ref().series_equal_missing(r.as_ref()),
(None, None) => false,
_ => true,
})
.collect_trusted()
}
// Same per-row rule expressed as a closure; each result is wrapped in `Some` because the
// helper's `op` returns `Option<bool>`.
let _series_not_equal_missing =
|lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) {
(Some(l), Some(r)) => Some(!l.series_equal_missing(r)),
(None, None) => Some(false),
_ => Some(true),
};

// Iteration and single-row broadcasting are handled by the helper.
_list_comparison_helper(self, rhs, _series_not_equal_missing)
}

// The following are not implemented because gt/lt comparisons of series don't make sense.
Expand Down Expand Up @@ -1227,6 +1252,17 @@ mod test {
assert_eq!(Vec::from(&c), &[Some(true), Some(false), None])
}

#[test]
fn list_broadcasting_lists() {
    // Two identical rows on the left, one identical row on the right: the single
    // right-hand row should be compared against each left-hand row.
    let element = Series::new("", &[1, 2, 3]);
    let two_rows = Series::new("", &[element.clone(), element.clone()]);
    let one_row = Series::new("", &[element.clone()]);

    let lhs = two_rows.list().unwrap();
    let rhs = one_row.list().unwrap();
    let result = lhs.equal(rhs);

    // The output has one entry per left-hand row, and every entry is true.
    assert_eq!(result.len(), 2);
    assert!(result.all());
}

#[test]
fn test_broadcasting_bools() {
let a = BooleanChunked::from_slice("", &[true, false, true]);
Expand Down
3 changes: 3 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ def test_equality() -> None:
a = pl.Series("name", ["ham", "foo", "bar"])
assert_series_equal((a == "ham"), pl.Series("name", [True, False, False]))

a = pl.Series("name", [[1], [1, 2], [2, 3]])
assert_series_equal((a == [1]), pl.Series("name", [True, False, False]))


def test_agg() -> None:
series = pl.Series("a", [1, 2])
Expand Down

0 comments on commit b2e6435

Please sign in to comment.