From 5354f4b6f1a955dcb9b5ea3d35680e7778f820d4 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 25 Sep 2023 10:04:45 +0200 Subject: [PATCH] make statistics more general --- crates/polars-io/src/parquet/predicates.rs | 78 ++++++++++--------- .../src/physical_plan/expressions/apply.rs | 4 +- 2 files changed, 43 insertions(+), 39 deletions(-) diff --git a/crates/polars-io/src/parquet/predicates.rs b/crates/polars-io/src/parquet/predicates.rs index 1dfc6b231b5a..be64653ed438 100644 --- a/crates/polars-io/src/parquet/predicates.rs +++ b/crates/polars-io/src/parquet/predicates.rs @@ -1,4 +1,3 @@ -use arrow::compute::concatenate::concatenate; use arrow::io::parquet::read::statistics::{deserialize, Statistics}; use arrow::io::parquet::read::RowGroupMetaData; use polars_core::prelude::*; @@ -12,22 +11,35 @@ use crate::ArrowResult; /// - min value /// - null_count #[cfg_attr(debug_assertions, derive(Debug))] -pub struct ColumnStats(Statistics, Field); +pub struct ColumnStats { + field: Field, + // The array may hold the null count for every row group, + // or for a single row group. + null_count: Option, + min_value: Option, + max_value: Option, +} impl ColumnStats { - pub fn dtype(&self) -> DataType { - self.1.data_type().clone() + fn from_arrow_stats(stats: Statistics, field: &ArrowField) -> Self { + Self { + field: field.into(), + null_count: Some(Series::try_from(("", stats.null_count)).unwrap()), + min_value: Some(Series::try_from(("", stats.min_value)).unwrap()), + max_value: Some(Series::try_from(("", stats.max_value)).unwrap()), + } + } + + pub fn dtype(&self) -> &DataType { + self.field.data_type() } pub fn null_count(&self) -> Option { - match self.1.data_type() { + match self.field.data_type() { #[cfg(feature = "dtype-struct")] DataType::Struct(_) => None, _ => { - // the array holds the null count for every row group - // so we sum them to get them of the whole file. - let s = Series::try_from(("", self.0.null_count.clone())).unwrap(); - + let s = self.null_count.as_ref()?; // if all null, there are no statistics. if s.null_count() != s.len() { s.sum() @@ -39,64 +51,56 @@ impl ColumnStats { } pub fn to_min_max(&self) -> Option { - let max_val = &*self.0.max_value; - let min_val = &*self.0.min_value; + let max_val = self.max_value.as_ref()?; + let min_val = self.min_value.as_ref()?; - let dtype = DataType::from(min_val.data_type()); + let dtype = min_val.dtype(); if Self::use_min_max(dtype) { - let arr = concatenate(&[min_val, max_val]).unwrap(); - let s = Series::try_from(("", arr)).unwrap(); - if s.null_count() > 0 { + let mut min_max_values = min_val.clone(); + min_max_values.append(max_val).unwrap(); + if min_max_values.null_count() > 0 { None } else { - Some(s) + Some(min_max_values) } } else { None } } - pub fn to_min(&self) -> Option { - let min_val = self.0.min_value.clone(); - let dtype = DataType::from(min_val.data_type()); + pub fn to_min(&self) -> Option<&Series> { + let min_val = self.min_value.as_ref()?; + let dtype = min_val.dtype(); if !Self::use_min_max(dtype) || min_val.len() != 1 { return None; } - let s = Series::try_from(("", min_val)).unwrap(); - if s.null_count() > 0 { + if min_val.null_count() > 0 { None } else { - Some(s) + Some(min_val) } } - pub fn to_max(&self) -> Option { - let max_val = self.0.max_value.clone(); - let dtype = DataType::from(max_val.data_type()); + pub fn to_max(&self) -> Option<&Series> { + let max_val = self.max_value.as_ref()?; + let dtype = max_val.dtype(); if !Self::use_min_max(dtype) || max_val.len() != 1 { return None; } - let s = Series::try_from(("", max_val)).unwrap(); - if s.null_count() > 0 { + if max_val.null_count() > 0 { None } else { - Some(s) + Some(max_val) } } - #[cfg(feature = "dtype-binary")] - fn use_min_max(dtype: DataType) -> bool { - dtype.is_numeric() || matches!(dtype, DataType::Utf8) || matches!(dtype, DataType::Binary) - } - - #[cfg(not(feature = "dtype-binary"))] - fn use_min_max(dtype: DataType) -> bool { - dtype.is_numeric() || matches!(dtype, DataType::Utf8) + fn use_min_max(dtype: &DataType) -> bool { + dtype.is_numeric() || matches!(dtype, DataType::Utf8 | DataType::Binary) } } @@ -133,7 +137,7 @@ pub(crate) fn collect_statistics( Some(rg) => deserialize(fld, &md[rg..rg + 1])?, }; schema.with_column((&fld.name).into(), (&fld.data_type).into()); - stats.push(ColumnStats(st, fld.into())); + stats.push(ColumnStats::from_arrow_stats(st, fld)); } Ok(if stats.is_empty() { diff --git a/crates/polars-lazy/src/physical_plan/expressions/apply.rs b/crates/polars-lazy/src/physical_plan/expressions/apply.rs index aa6597ddcc7b..eec11625bd2d 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/apply.rs @@ -467,8 +467,8 @@ impl ApplyExpr { let min = st.to_min()?; let max = st.to_max()?; - let all_smaller = || Some(ChunkCompare::lt(input, &min).ok()?.all()); - let all_bigger = || Some(ChunkCompare::gt(input, &max).ok()?.all()); + let all_smaller = || Some(ChunkCompare::lt(input, min).ok()?.all()); + let all_bigger = || Some(ChunkCompare::gt(input, max).ok()?.all()); Some(!all_smaller()? && !all_bigger()?) };