diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 70af1c119377..3515cd6e4678 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -680,59 +680,84 @@ impl ChunkCompare<&str> for Utf8Chunked { } } -impl ChunkCompare<&ListChunked> for ListChunked { - type Item = BooleanChunked; - fn equal(&self, rhs: &ListChunked) -> BooleanChunked { - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - self.amortized_iter() +#[doc(hidden)] +fn _list_comparison_helper(lhs: &ListChunked, rhs: &ListChunked, op: F) -> BooleanChunked +where + F: Fn(Option<&Series>, Option<&Series>) -> Option, +{ + match (lhs.len(), rhs.len()) { + (_, 1) => { + let right = rhs.get(0).map(|s| s.with_name("")); + // SAFETY: values within iterator do not outlive the iterator itself + unsafe { + lhs.amortized_iter() + .map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref())) + .collect_trusted() + } + }, + (1, _) => { + let left = lhs.get(0).map(|s| s.with_name("")); + // SAFETY: values within iterator do not outlive the iterator itself + unsafe { + rhs.amortized_iter() + .map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref()))) + .collect_trusted() + } + }, + // SAFETY: values within iterator do not outlive the iterator itself + _ => unsafe { + lhs.amortized_iter() .zip(rhs.amortized_iter()) - .map(|(left, right)| match (left, right) { - (Some(l), Some(r)) => Some(l.as_ref().series_equal_missing(r.as_ref())), - _ => None, + .map(|(left, right)| { + op( + left.as_ref().map(|us| us.as_ref()), + right.as_ref().map(|us| us.as_ref()), + ) }) .collect_trusted() - } + }, + } +} + +impl ChunkCompare<&ListChunked> for ListChunked { + type Item = BooleanChunked; + fn equal(&self, rhs: &ListChunked) -> BooleanChunked { + let _series_equal = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { + (Some(l), Some(r)) => Some(l.series_equal(r)), + _ => None, + }; + + _list_comparison_helper(self, rhs, _series_equal) } fn equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - self.amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| match (left, right) { - (Some(l), Some(r)) => l.as_ref().series_equal_missing(r.as_ref()), - (None, None) => true, - _ => false, - }) - .collect_trusted() - } + let _series_equal_missing = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { + (Some(l), Some(r)) => Some(l.series_equal_missing(r)), + (None, None) => Some(true), + _ => Some(false), + }; + + _list_comparison_helper(self, rhs, _series_equal_missing) } fn not_equal(&self, rhs: &ListChunked) -> BooleanChunked { - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - self.amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| match (left, right) { - (Some(l), Some(r)) => Some(!l.as_ref().series_equal_missing(r.as_ref())), - _ => None, - }) - .collect_trusted() - } + let _series_not_equal = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { + (Some(l), Some(r)) => Some(!l.series_equal(r)), + _ => None, + }; + + _list_comparison_helper(self, rhs, _series_not_equal) } fn not_equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { - unsafe { - self.amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| match (left, right) { - (Some(l), Some(r)) => !l.as_ref().series_equal_missing(r.as_ref()), - (None, None) => false, - _ => true, - }) - .collect_trusted() - } + let _series_not_equal_missing = + |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { + (Some(l), Some(r)) => Some(!l.series_equal_missing(r)), + (None, None) => Some(false), + _ => Some(true), + }; + + _list_comparison_helper(self, rhs, _series_not_equal_missing) } // The following are not implemented because gt, lt comparison of series don't make sense. @@ -1227,6 +1252,17 @@ mod test { assert_eq!(Vec::from(&c), &[Some(true), Some(false), None]) } + #[test] + fn list_broadcasting_lists() { + let s_el = Series::new("", &[1, 2, 3]); + let s_lhs = Series::new("", &[s_el.clone(), s_el.clone()]); + let s_rhs = Series::new("", &[s_el.clone()]); + + let result = s_lhs.list().unwrap().equal(s_rhs.list().unwrap()); + assert_eq!(result.len(), 2); + assert!(result.all()); + } + #[test] fn test_broadcasting_bools() { let a = BooleanChunked::from_slice("", &[true, false, true]); diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 434968829031..990e789573cc 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -288,6 +288,9 @@ def test_equality() -> None: a = pl.Series("name", ["ham", "foo", "bar"]) assert_series_equal((a == "ham"), pl.Series("name", [True, False, False])) + a = pl.Series("name", [[1], [1, 2], [2, 3]]) + assert_series_equal((a == [1]), pl.Series("name", [True, False, False])) + def test_agg() -> None: series = pl.Series("a", [1, 2])