From 0ccb9a9d4027e8b620bf6f2e74b78e738ad714d1 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 25 Oct 2023 20:53:02 +0800 Subject: [PATCH] feat: enable eq and neq for array dtype --- .../src/legacy/kernels/comparison.rs | 40 +++++++++++++++---- .../src/chunked_array/comparison/mod.rs | 28 +++++++++++-- crates/polars-core/src/series/comparison.rs | 2 + py-polars/tests/unit/datatypes/test_array.py | 16 ++++++++ 4 files changed, 75 insertions(+), 11 deletions(-) diff --git a/crates/polars-arrow/src/legacy/kernels/comparison.rs b/crates/polars-arrow/src/legacy/kernels/comparison.rs index a3ef62c850a3..341a0fe342d3 100644 --- a/crates/polars-arrow/src/legacy/kernels/comparison.rs +++ b/crates/polars-arrow/src/legacy/kernels/comparison.rs @@ -1,26 +1,52 @@ -use crate::array::{BooleanArray, FixedSizeListArray}; +use crate::array::{Array, BooleanArray, FixedSizeListArray}; use crate::bitmap::utils::count_zeros; use crate::legacy::utils::combine_validities_and; -fn fixed_size_list_cmp(a: &FixedSizeListArray, b: &FixedSizeListArray, func: F) -> BooleanArray +fn fixed_size_list_cmp( + a: &FixedSizeListArray, + b: &FixedSizeListArray, + cmp_func: F1, + func: F2, +) -> BooleanArray where - F: Fn(usize) -> bool, + F1: Fn(&dyn Array, &dyn Array) -> BooleanArray, + F2: Fn(usize) -> bool, { assert_eq!(a.size(), b.size()); - let mask = crate::compute::comparison::eq(a.values().as_ref(), b.values().as_ref()); + let mask = cmp_func(a.values().as_ref(), b.values().as_ref()); let mask = combine_validities_and(Some(mask.values()), mask.validity()).unwrap(); let (slice, offset, _len) = mask.as_slice(); assert_eq!(offset, 0); let width = a.size(); - let iter = (0..a.len()).map(|i| func(count_zeros(slice, i, width))); + let iter = (0..a.len()).map(|i| func(count_zeros(slice, i * width, width))); // range is trustedlen unsafe { BooleanArray::from_trusted_len_values_iter_unchecked(iter) } } pub fn fixed_size_list_eq(a: &FixedSizeListArray, b: &FixedSizeListArray) -> BooleanArray { - fixed_size_list_cmp(a, b, |count_zeros| count_zeros == 0) + fixed_size_list_cmp(a, b, crate::compute::comparison::eq, |count_zeros| { + count_zeros == 0 + }) } pub fn fixed_size_list_neq(a: &FixedSizeListArray, b: &FixedSizeListArray) -> BooleanArray { - fixed_size_list_cmp(a, b, |count_zeros| count_zeros != 0) + fixed_size_list_cmp(a, b, crate::compute::comparison::eq, |count_zeros| { + count_zeros != 0 + }) +} +pub fn fixed_size_list_eq_missing(a: &FixedSizeListArray, b: &FixedSizeListArray) -> BooleanArray { + fixed_size_list_cmp( + a, + b, + crate::compute::comparison::eq_and_validity, + |count_zeros| count_zeros == 0, + ) +} +pub fn fixed_size_list_neq_missing(a: &FixedSizeListArray, b: &FixedSizeListArray) -> BooleanArray { + fixed_size_list_cmp( + a, + b, + crate::compute::comparison::eq_and_validity, + |count_zeros| count_zeros != 0, + ) } diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 7435b5a22864..5d62963f4ddf 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -856,6 +856,9 @@ impl ChunkCompare<&StructChunked> for StructChunked { impl ChunkCompare<&ArrayChunked> for ArrayChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ArrayChunked) -> BooleanChunked { + if self.width() != rhs.width() { + return BooleanChunked::full("", false, self.len()); + } arity::binary_mut_with_options( self, rhs, @@ -865,11 +868,21 @@ impl ChunkCompare<&ArrayChunked> for ArrayChunked { } fn equal_missing(&self, rhs: &ArrayChunked) -> BooleanChunked { - // TODO!: maybe do something else here - self.equal(rhs) + if self.width() != rhs.width() { + return BooleanChunked::full("", false, self.len()); + } + arity::binary_mut_with_options( + self, + rhs, + arrow::legacy::kernels::comparison::fixed_size_list_eq_missing, + "", + ) } fn not_equal(&self, rhs: &ArrayChunked) -> BooleanChunked { + if self.width() != rhs.width() { + return BooleanChunked::full("", true, self.len()); + } arity::binary_mut_with_options( self, rhs, @@ -879,8 +892,15 @@ impl ChunkCompare<&ArrayChunked> for ArrayChunked { } fn not_equal_missing(&self, rhs: &ArrayChunked) -> Self::Item { - // TODO!: maybe do something else here - self.not_equal(rhs) + if self.width() != rhs.width() { + return BooleanChunked::full("", true, self.len()); + } + arity::binary_mut_with_options( + self, + rhs, + arrow::legacy::kernels::comparison::fixed_size_list_neq_missing, + "", + ) } // following are not implemented because gt, lt comparison of series don't make sense diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index f65a6e502e64..46d82980c431 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -28,6 +28,8 @@ macro_rules! impl_compare { DataType::Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()), DataType::Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()), DataType::List(_) => lhs.list().unwrap().$method(rhs.list().unwrap()), + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => lhs.array().unwrap().$method(rhs.array().unwrap()), #[cfg(feature = "dtype-struct")] DataType::Struct(_) => lhs .struct_() diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py index fc4064addbb6..d6138eac972e 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -106,6 +106,22 @@ def test_array_concat() -> None: } +def test_array_equal_and_not_equal() -> None: + left = pl.Series([[1, 2], [3, 5]], dtype=pl.Array(width=2, inner=pl.Int64)) + right = pl.Series([[1, 2], [3, 1]], dtype=pl.Array(width=2, inner=pl.Int64)) + assert_series_equal(left == right, pl.Series([True, False])) + assert_series_equal(left.eq_missing(right), pl.Series([True, False])) + assert_series_equal(left != right, pl.Series([False, True])) + assert_series_equal(left.ne_missing(right), pl.Series([False, True])) + + left = pl.Series([[1, None], [3, None]], dtype=pl.Array(width=2, inner=pl.Int64)) + right = pl.Series([[1, None], [3, 4]], dtype=pl.Array(width=2, inner=pl.Int64)) + assert_series_equal(left == right, pl.Series([False, False])) + assert_series_equal(left.eq_missing(right), pl.Series([True, False])) + assert_series_equal(left != right, pl.Series([True, True])) + assert_series_equal(left.ne_missing(right), pl.Series([False, True])) + + def test_array_init_deprecation() -> None: with pytest.deprecated_call(): pl.Array(2)